diff --git a/htpasswd.go b/htpasswd.go index 8525c71..a620d58 100644 --- a/htpasswd.go +++ b/htpasswd.go @@ -4,6 +4,7 @@ import ( "crypto/sha1" "encoding/base64" "encoding/csv" + "io" "log" "os" ) @@ -15,26 +16,30 @@ type HtpasswdFile struct { Users map[string]string } -func NewHtpasswdFile(path string) *HtpasswdFile { +func NewHtpasswdFromFile(path string) (*HtpasswdFile, error) { log.Printf("using htpasswd file %s", path) r, err := os.Open(path) if err != nil { - log.Fatalf("failed opening %v, %s", path, err.Error()) + return nil, err } - csv_reader := csv.NewReader(r) + return NewHtpasswd(r) +} + +func NewHtpasswd(file io.Reader) (*HtpasswdFile, error) { + csv_reader := csv.NewReader(file) csv_reader.Comma = ':' csv_reader.Comment = '#' csv_reader.TrimLeadingSpace = true records, err := csv_reader.ReadAll() if err != nil { - log.Fatalf("Failed reading file %s", err.Error()) + return nil, err } h := &HtpasswdFile{Users: make(map[string]string)} for _, record := range records { h.Users[record[0]] = record[1] } - return h + return h, nil } func (h *HtpasswdFile) Validate(user string, password string) bool { diff --git a/htpasswd_test.go b/htpasswd_test.go new file mode 100644 index 0000000..5cfc9e6 --- /dev/null +++ b/htpasswd_test.go @@ -0,0 +1,16 @@ +package main + +import ( + "bytes" + "github.com/bmizerany/assert" + "testing" +) + +func TestHtpasswd(t *testing.T) { + file := bytes.NewBuffer([]byte("testuser:{SHA}PaVBVZkYqAjCQCu6UBL2xgsnZhw=\n")) + h, err := NewHtpasswd(file) + assert.Equal(t, err, nil) + + valid := h.Validate("testuser", "asdf") + assert.Equal(t, valid, true) +} diff --git a/main.go b/main.go index 92726ad..e02bd3e 100644 --- a/main.go +++ b/main.go @@ -2,12 +2,12 @@ package main import ( "flag" + "fmt" "log" "net" "net/http" "net/url" "strings" - "fmt" ) const VERSION = "0.0.1" @@ -72,7 +72,10 @@ func main() { oauthproxy.SignInMessage = fmt.Sprintf("using a %s email address", *googleAppsDomain) } if *htpasswdFile != "" { - oauthproxy.HtpasswdFile = NewHtpasswdFile(*htpasswdFile) + oauthproxy.HtpasswdFile, err = NewHtpasswdFromFile(*htpasswdFile) + if err != nil { + log.Fatalf("FATAL: unable to open %s %s", *htpasswdFile, err.Error()) + } } listener, err := net.Listen("tcp", *httpAddr) if err != nil { diff --git a/oauthproxy.go b/oauthproxy.go index ffdda9e..44ede97 100644 --- a/oauthproxy.go +++ b/oauthproxy.go @@ -169,11 +169,11 @@ func (p *OauthProxy) ErrorPage(rw http.ResponseWriter, code int, title string, m rw.WriteHeader(code) templates := getTemplates() t := struct { - Title string - Message string + Title string + Message string }{ - Title: fmt.Sprintf("%d %s", code, title), - Message: message, + Title: fmt.Sprintf("%d %s", code, title), + Message: message, } templates.ExecuteTemplate(rw, "error.html", t) } @@ -254,7 +254,7 @@ func (p *OauthProxy) ServeHTTP(rw http.ResponseWriter, req *http.Request) { return } } - + cookie, err := req.Cookie(p.CookieKey) var ok bool var email string diff --git a/templates.go b/templates.go index ee28eac..3a88d36 100644 --- a/templates.go +++ b/templates.go @@ -18,7 +18,7 @@ func getTemplates() *template.Template { if err != nil { log.Fatalf("failed parsing template %s", err.Error()) } - + t, err = t.Parse(`{{define "error.html"}}