This commit is contained in:
Jehiah Czebotar 2012-12-17 13:38:33 -05:00
parent 42359333b2
commit 42f539109e
7 changed files with 51 additions and 15 deletions

View File

@ -4,6 +4,7 @@ import (
"crypto/sha1"
"encoding/base64"
"encoding/csv"
"io"
"log"
"os"
)
@ -15,26 +16,30 @@ type HtpasswdFile struct {
Users map[string]string
}
func NewHtpasswdFile(path string) *HtpasswdFile {
func NewHtpasswdFromFile(path string) (*HtpasswdFile, error) {
log.Printf("using htpasswd file %s", path)
r, err := os.Open(path)
if err != nil {
log.Fatalf("failed opening %v, %s", path, err.Error())
return nil, err
}
csv_reader := csv.NewReader(r)
return NewHtpasswd(r)
}
func NewHtpasswd(file io.Reader) (*HtpasswdFile, error) {
csv_reader := csv.NewReader(file)
csv_reader.Comma = ':'
csv_reader.Comment = '#'
csv_reader.TrimLeadingSpace = true
records, err := csv_reader.ReadAll()
if err != nil {
log.Fatalf("Failed reading file %s", err.Error())
return nil, err
}
h := &HtpasswdFile{Users: make(map[string]string)}
for _, record := range records {
h.Users[record[0]] = record[1]
}
return h
return h, nil
}
func (h *HtpasswdFile) Validate(user string, password string) bool {

16
htpasswd_test.go Normal file
View File

@ -0,0 +1,16 @@
package main
import (
"bytes"
"github.com/bmizerany/assert"
"testing"
)
func TestHtpasswd(t *testing.T) {
file := bytes.NewBuffer([]byte("testuser:{SHA}PaVBVZkYqAjCQCu6UBL2xgsnZhw=\n"))
h, err := NewHtpasswd(file)
assert.Equal(t, err, nil)
valid := h.Validate("testuser", "asdf")
assert.Equal(t, valid, true)
}

View File

@ -2,12 +2,12 @@ package main
import (
"flag"
"fmt"
"log"
"net"
"net/http"
"net/url"
"strings"
"fmt"
)
const VERSION = "0.0.1"
@ -72,7 +72,10 @@ func main() {
oauthproxy.SignInMessage = fmt.Sprintf("using a %s email address", *googleAppsDomain)
}
if *htpasswdFile != "" {
oauthproxy.HtpasswdFile = NewHtpasswdFile(*htpasswdFile)
oauthproxy.HtpasswdFile, err = NewHtpasswdFromFile(*htpasswdFile)
if err != nil {
log.Fatalf("FATAL: unable to open %s %s", *htpasswdFile, err.Error())
}
}
listener, err := net.Listen("tcp", *httpAddr)
if err != nil {

View File

@ -169,11 +169,11 @@ func (p *OauthProxy) ErrorPage(rw http.ResponseWriter, code int, title string, m
rw.WriteHeader(code)
templates := getTemplates()
t := struct {
Title string
Message string
Title string
Message string
}{
Title: fmt.Sprintf("%d %s", code, title),
Message: message,
Title: fmt.Sprintf("%d %s", code, title),
Message: message,
}
templates.ExecuteTemplate(rw, "error.html", t)
}

12
templates_test.go Normal file
View File

@ -0,0 +1,12 @@
package main
import (
"github.com/bmizerany/assert"
"testing"
)
func TestTemplatesCompile(t *testing.T) {
templates := getTemplates()
assert.NotEqual(t, templates, nil)
}

View File

@ -1,10 +1,10 @@
package main
import (
"os"
"log"
"encoding/csv"
"fmt"
"log"
"os"
"strings"
)