diff --git a/Godeps b/Godeps index 41bef4c..6a77c0f 100644 --- a/Godeps +++ b/Godeps @@ -2,3 +2,4 @@ github.com/BurntSushi/toml 3883ac1ce943878302255f538fce319d23226223 github.com/bitly/go-simplejson 3378bdcb5cebedcbf8b5750edee28010f128fe24 github.com/mreiferson/go-options ee94b57f2fbf116075426f853e5abbcdfeca8b3d github.com/bmizerany/assert e17e99893cb6509f428e1728281c2ad60a6b31e3 +gopkg.in/fsnotify.v1 v1.2.0 diff --git a/api/api_test.go b/api/api_test.go index 2494131..8327a28 100644 --- a/api/api_test.go +++ b/api/api_test.go @@ -40,8 +40,8 @@ func TestRequestFailure(t *testing.T) { resp, err := Request(req) assert.Equal(t, (*simplejson.Json)(nil), resp) assert.NotEqual(t, nil, err) - if !strings.HasSuffix(err.Error(), "connection refused") { - t.Error("expected error when a connection fails") + if !strings.Contains(err.Error(), "refused") { + t.Error("expected error when a connection fails: ", err) } } diff --git a/htpasswd.go b/htpasswd.go index 5988c7c..6fd888e 100644 --- a/htpasswd.go +++ b/htpasswd.go @@ -21,6 +21,7 @@ func NewHtpasswdFromFile(path string) (*HtpasswdFile, error) { if err != nil { return nil, err } + defer r.Close() return NewHtpasswd(r) } diff --git a/validator.go b/validator.go index f85e1d4..8c3e985 100644 --- a/validator.go +++ b/validator.go @@ -6,43 +6,84 @@ import ( "log" "os" "strings" + "sync/atomic" + "unsafe" ) -func NewValidator(domains []string, usersFile string) func(string) bool { - validUsers := make(map[string]bool) +type UserMap struct { + usersFile string + m unsafe.Pointer +} +func NewUserMap(usersFile string, onUpdate func()) *UserMap { + um := &UserMap{usersFile: usersFile} + m := make(map[string]bool) + atomic.StorePointer(&um.m, unsafe.Pointer(&m)) if usersFile != "" { log.Printf("using authenticated emails file %s", usersFile) - r, err := os.Open(usersFile) - if err != nil { - log.Fatalf("failed opening authenticated-emails-file=%q, %s", usersFile, err) - } - csv_reader := csv.NewReader(r) - csv_reader.Comma = ',' - csv_reader.Comment = '#' - csv_reader.TrimLeadingSpace = true - records, err := csv_reader.ReadAll() - for _, r := range records { - validUsers[strings.ToLower(r[0])] = true + started := WatchForUpdates(usersFile, func() { + um.LoadAuthenticatedEmailsFile() + onUpdate() + }) + if started { + log.Printf("watching %s for updates", usersFile) } + um.LoadAuthenticatedEmailsFile() } + return um +} + +func (um *UserMap) IsValid(email string) (result bool) { + m := *(*map[string]bool)(atomic.LoadPointer(&um.m)) + _, result = m[email] + return +} + +func (um *UserMap) LoadAuthenticatedEmailsFile() { + r, err := os.Open(um.usersFile) + if err != nil { + log.Fatalf("failed opening authenticated-emails-file=%q, %s", um.usersFile, err) + } + defer r.Close() + csv_reader := csv.NewReader(r) + csv_reader.Comma = ',' + csv_reader.Comment = '#' + csv_reader.TrimLeadingSpace = true + records, err := csv_reader.ReadAll() + if err != nil { + log.Printf("error reading authenticated-emails-file=%q, %s", um.usersFile, err) + return + } + updated := make(map[string]bool) + for _, r := range records { + updated[strings.ToLower(r[0])] = true + } + atomic.StorePointer(&um.m, unsafe.Pointer(&updated)) +} + +func newValidatorImpl(domains []string, usersFile string, + onUpdate func()) func(string) bool { + validUsers := NewUserMap(usersFile, onUpdate) for i, domain := range domains { - domains[i] = strings.ToLower(domain) + domains[i] = fmt.Sprintf("@%s", strings.ToLower(domain)) } validator := func(email string) bool { email = strings.ToLower(email) valid := false for _, domain := range domains { - emailSuffix := fmt.Sprintf("@%s", domain) - valid = valid || strings.HasSuffix(email, emailSuffix) + valid = valid || strings.HasSuffix(email, domain) } if !valid { - _, valid = validUsers[email] + valid = validUsers.IsValid(email) } log.Printf("validating: is %s valid? %v", email, valid) return valid } return validator } + +func NewValidator(domains []string, usersFile string) func(string) bool { + return newValidatorImpl(domains, usersFile, func() {}) +} diff --git a/validator_test.go b/validator_test.go index 3c223bd..0912b36 100644 --- a/validator_test.go +++ b/validator_test.go @@ -7,23 +7,117 @@ import ( "testing" ) -func TestValidatorComparisonsAreCaseInsensitive(t *testing.T) { - auth_email_file, err := ioutil.TempFile("", "test_auth_emails_") +type ValidatorTest struct { + auth_email_file *os.File +} + +func NewValidatorTest(t *testing.T) *ValidatorTest { + vt := &ValidatorTest{} + var err error + vt.auth_email_file, err = ioutil.TempFile("", "test_auth_emails_") if err != nil { t.Fatal("failed to create temp file: " + err.Error()) } - defer os.Remove(auth_email_file.Name()) + return vt +} - auth_email_file.WriteString( - strings.Join([]string{"Foo.Bar@Example.Com"}, "\n")) - err = auth_email_file.Close() - if err != nil { - t.Fatal("failed to close temp file " + auth_email_file.Name() + - ": " + err.Error()) +func (vt *ValidatorTest) TearDown() { + os.Remove(vt.auth_email_file.Name()) +} + +// This will close vt.auth_email_file. +func (vt *ValidatorTest) WriteEmails(t *testing.T, emails []string) { + defer vt.auth_email_file.Close() + vt.auth_email_file.WriteString(strings.Join(emails, "\n")) + if err := vt.auth_email_file.Close(); err != nil { + t.Fatal("failed to close temp file " + + vt.auth_email_file.Name() + ": " + err.Error()) } +} +func TestValidatorEmpty(t *testing.T) { + vt := NewValidatorTest(t) + defer vt.TearDown() + + vt.WriteEmails(t, []string(nil)) + domains := []string(nil) + validator := NewValidator(domains, vt.auth_email_file.Name()) + + if validator("foo.bar@example.com") { + t.Error("nothing should validate when the email and " + + "domain lists are empty") + } +} + +func TestValidatorSingleEmail(t *testing.T) { + vt := NewValidatorTest(t) + defer vt.TearDown() + + vt.WriteEmails(t, []string{"foo.bar@example.com"}) + domains := []string(nil) + validator := NewValidator(domains, vt.auth_email_file.Name()) + + if !validator("foo.bar@example.com") { + t.Error("email should validate") + } + if validator("baz.quux@example.com") { + t.Error("email from same domain but not in list " + + "should not validate when domain list is empty") + } +} + +func TestValidatorSingleDomain(t *testing.T) { + vt := NewValidatorTest(t) + defer vt.TearDown() + + vt.WriteEmails(t, []string(nil)) + domains := []string{"example.com"} + validator := NewValidator(domains, vt.auth_email_file.Name()) + + if !validator("foo.bar@example.com") { + t.Error("email should validate") + } + if !validator("baz.quux@example.com") { + t.Error("email from same domain should validate") + } +} + +func TestValidatorMultipleEmailsMultipleDomains(t *testing.T) { + vt := NewValidatorTest(t) + defer vt.TearDown() + + vt.WriteEmails(t, []string{ + "xyzzy@example.com", + "plugh@example.com", + }) + domains := []string{"example0.com", "example1.com"} + validator := NewValidator(domains, vt.auth_email_file.Name()) + + if !validator("foo.bar@example0.com") { + t.Error("email from first domain should validate") + } + if !validator("baz.quux@example1.com") { + t.Error("email from second domain should validate") + } + if !validator("xyzzy@example.com") { + t.Error("first email in list should validate") + } + if !validator("plugh@example.com") { + t.Error("second email in list should validate") + } + if validator("xyzzy.plugh@example.com") { + t.Error("email not in list that matches no domains " + + "should not validate") + } +} + +func TestValidatorComparisonsAreCaseInsensitive(t *testing.T) { + vt := NewValidatorTest(t) + defer vt.TearDown() + + vt.WriteEmails(t, []string{"Foo.Bar@Example.Com"}) domains := []string{"Frobozz.Com"} - validator := NewValidator(domains, auth_email_file.Name()) + validator := NewValidator(domains, vt.auth_email_file.Name()) if !validator("foo.bar@example.com") { t.Error("loaded email addresses are not lower-cased") diff --git a/validator_watcher_copy_test.go b/validator_watcher_copy_test.go new file mode 100644 index 0000000..35bed3a --- /dev/null +++ b/validator_watcher_copy_test.go @@ -0,0 +1,50 @@ +// +build go1.3 +// +build !plan9,!solaris,!windows + +// Turns out you can't copy over an existing file on Windows. + +package main + +import ( + "io/ioutil" + "os" + "testing" +) + +func (vt *ValidatorTest) UpdateEmailFileViaCopyingOver( + t *testing.T, emails []string) { + orig_file := vt.auth_email_file + var err error + vt.auth_email_file, err = ioutil.TempFile("", "test_auth_emails_") + if err != nil { + t.Fatal("failed to create temp file for copy: " + err.Error()) + } + vt.WriteEmails(t, emails) + err = os.Rename(vt.auth_email_file.Name(), orig_file.Name()) + if err != nil { + t.Fatal("failed to copy over temp file: " + err.Error()) + } + vt.auth_email_file = orig_file +} + +func TestValidatorOverwriteEmailListViaCopyingOver(t *testing.T) { + vt := NewValidatorTest(t) + defer vt.TearDown() + + vt.WriteEmails(t, []string{"xyzzy@example.com"}) + domains := []string(nil) + updated := make(chan bool) + validator := newValidatorImpl(domains, vt.auth_email_file.Name(), + func() { updated <- true }) + + if !validator("xyzzy@example.com") { + t.Error("email in list should validate") + } + + vt.UpdateEmailFileViaCopyingOver(t, []string{"plugh@example.com"}) + <-updated + + if validator("xyzzy@example.com") { + t.Error("email removed from list should not validate") + } +} diff --git a/validator_watcher_test.go b/validator_watcher_test.go new file mode 100644 index 0000000..691644f --- /dev/null +++ b/validator_watcher_test.go @@ -0,0 +1,105 @@ +// +build go1.3 +// +build !plan9,!solaris + +package main + +import ( + "io/ioutil" + "os" + "testing" +) + +func (vt *ValidatorTest) UpdateEmailFile(t *testing.T, emails []string) { + var err error + vt.auth_email_file, err = os.OpenFile( + vt.auth_email_file.Name(), os.O_WRONLY|os.O_CREATE, 0600) + if err != nil { + t.Fatal("failed to re-open temp file for updates") + } + vt.WriteEmails(t, emails) +} + +func (vt *ValidatorTest) UpdateEmailFileViaRenameAndReplace( + t *testing.T, emails []string) { + orig_file := vt.auth_email_file + var err error + vt.auth_email_file, err = ioutil.TempFile("", "test_auth_emails_") + if err != nil { + t.Fatal("failed to create temp file for rename and replace: " + + err.Error()) + } + vt.WriteEmails(t, emails) + + moved_name := orig_file.Name() + "-moved" + err = os.Rename(orig_file.Name(), moved_name) + err = os.Rename(vt.auth_email_file.Name(), orig_file.Name()) + if err != nil { + t.Fatal("failed to rename and replace temp file: " + + err.Error()) + } + vt.auth_email_file = orig_file + os.Remove(moved_name) +} + +func TestValidatorOverwriteEmailListDirectly(t *testing.T) { + vt := NewValidatorTest(t) + defer vt.TearDown() + + vt.WriteEmails(t, []string{ + "xyzzy@example.com", + "plugh@example.com", + }) + domains := []string(nil) + updated := make(chan bool) + validator := newValidatorImpl(domains, vt.auth_email_file.Name(), + func() { updated <- true }) + + if !validator("xyzzy@example.com") { + t.Error("first email in list should validate") + } + if !validator("plugh@example.com") { + t.Error("second email in list should validate") + } + if validator("xyzzy.plugh@example.com") { + t.Error("email not in list that matches no domains " + + "should not validate") + } + + vt.UpdateEmailFile(t, []string{ + "xyzzy.plugh@example.com", + "plugh@example.com", + }) + <-updated + + if validator("xyzzy@example.com") { + t.Error("email removed from list should not validate") + } + if !validator("plugh@example.com") { + t.Error("email retained in list should validate") + } + if !validator("xyzzy.plugh@example.com") { + t.Error("email added to list should validate") + } +} + +func TestValidatorOverwriteEmailListViaRenameAndReplace(t *testing.T) { + vt := NewValidatorTest(t) + defer vt.TearDown() + + vt.WriteEmails(t, []string{"xyzzy@example.com"}) + domains := []string(nil) + updated := make(chan bool) + validator := newValidatorImpl(domains, vt.auth_email_file.Name(), + func() { updated <- true }) + + if !validator("xyzzy@example.com") { + t.Error("email in list should validate") + } + + vt.UpdateEmailFileViaRenameAndReplace(t, []string{"plugh@example.com"}) + <-updated + + if validator("xyzzy@example.com") { + t.Error("email removed from list should not validate") + } +} diff --git a/watcher.go b/watcher.go new file mode 100644 index 0000000..32c91f2 --- /dev/null +++ b/watcher.go @@ -0,0 +1,64 @@ +// +build go1.3 +// +build !plan9,!solaris + +package main + +import ( + "log" + "os" + "path/filepath" + "time" + + "gopkg.in/fsnotify.v1" +) + +func WaitForReplacement(event fsnotify.Event, watcher *fsnotify.Watcher) { + const sleep_interval = 50 * time.Millisecond + + // Avoid a race when fsnofity.Remove is preceded by fsnotify.Chmod. + if event.Op&fsnotify.Chmod != 0 { + time.Sleep(sleep_interval) + } + for { + if _, err := os.Stat(event.Name); err == nil { + if err := watcher.Add(event.Name); err == nil { + log.Printf("watching resumed for %s", event.Name) + return + } + } + time.Sleep(sleep_interval) + } +} + +func WatchForUpdates(filename string, action func()) bool { + filename = filepath.Clean(filename) + watcher, err := fsnotify.NewWatcher() + if err != nil { + log.Fatal("failed to create watcher for ", filename, ": ", err) + } + go func() { + defer watcher.Close() + for { + select { + case event := <-watcher.Events: + // On Arch Linux, it appears Chmod events precede Remove events, + // which causes a race between action() and the coming Remove event. + // If the Remove wins, the action() (which calls + // UserMap.LoadAuthenticatedEmailsFile()) crashes when the file + // can't be opened. + if event.Op&(fsnotify.Remove|fsnotify.Rename|fsnotify.Chmod) != 0 { + log.Printf("watching interrupted on event: %s", event) + WaitForReplacement(event, watcher) + } + log.Printf("reloading after event: %s", event) + action() + case err := <-watcher.Errors: + log.Printf("error watching %s: %s", filename, err) + } + } + }() + if err = watcher.Add(filename); err != nil { + log.Fatal("failed to add ", filename, " to watcher: ", err) + } + return true +} diff --git a/watcher_unsupported.go b/watcher_unsupported.go new file mode 100644 index 0000000..d94e11e --- /dev/null +++ b/watcher_unsupported.go @@ -0,0 +1,13 @@ +// +build go1.1 +// +build plan9,solaris + +package main + +import ( + "log" +) + +func WatchForUpdates(filename string, action func()) bool { + log.Printf("file watching not implemented on this platform") + return false +}