This commit is contained in:
Jehiah Czebotar 2012-12-17 13:38:33 -05:00
parent 42359333b2
commit 42f539109e
7 changed files with 51 additions and 15 deletions

View File

@ -4,6 +4,7 @@ import (
"crypto/sha1" "crypto/sha1"
"encoding/base64" "encoding/base64"
"encoding/csv" "encoding/csv"
"io"
"log" "log"
"os" "os"
) )
@ -15,26 +16,30 @@ type HtpasswdFile struct {
Users map[string]string Users map[string]string
} }
func NewHtpasswdFile(path string) *HtpasswdFile { func NewHtpasswdFromFile(path string) (*HtpasswdFile, error) {
log.Printf("using htpasswd file %s", path) log.Printf("using htpasswd file %s", path)
r, err := os.Open(path) r, err := os.Open(path)
if err != nil { if err != nil {
log.Fatalf("failed opening %v, %s", path, err.Error()) return nil, err
} }
csv_reader := csv.NewReader(r) return NewHtpasswd(r)
}
func NewHtpasswd(file io.Reader) (*HtpasswdFile, error) {
csv_reader := csv.NewReader(file)
csv_reader.Comma = ':' csv_reader.Comma = ':'
csv_reader.Comment = '#' csv_reader.Comment = '#'
csv_reader.TrimLeadingSpace = true csv_reader.TrimLeadingSpace = true
records, err := csv_reader.ReadAll() records, err := csv_reader.ReadAll()
if err != nil { if err != nil {
log.Fatalf("Failed reading file %s", err.Error()) return nil, err
} }
h := &HtpasswdFile{Users: make(map[string]string)} h := &HtpasswdFile{Users: make(map[string]string)}
for _, record := range records { for _, record := range records {
h.Users[record[0]] = record[1] h.Users[record[0]] = record[1]
} }
return h return h, nil
} }
func (h *HtpasswdFile) Validate(user string, password string) bool { func (h *HtpasswdFile) Validate(user string, password string) bool {

16
htpasswd_test.go Normal file
View File

@ -0,0 +1,16 @@
package main
import (
"bytes"
"github.com/bmizerany/assert"
"testing"
)
func TestHtpasswd(t *testing.T) {
file := bytes.NewBuffer([]byte("testuser:{SHA}PaVBVZkYqAjCQCu6UBL2xgsnZhw=\n"))
h, err := NewHtpasswd(file)
assert.Equal(t, err, nil)
valid := h.Validate("testuser", "asdf")
assert.Equal(t, valid, true)
}

View File

@ -2,12 +2,12 @@ package main
import ( import (
"flag" "flag"
"fmt"
"log" "log"
"net" "net"
"net/http" "net/http"
"net/url" "net/url"
"strings" "strings"
"fmt"
) )
const VERSION = "0.0.1" const VERSION = "0.0.1"
@ -72,7 +72,10 @@ func main() {
oauthproxy.SignInMessage = fmt.Sprintf("using a %s email address", *googleAppsDomain) oauthproxy.SignInMessage = fmt.Sprintf("using a %s email address", *googleAppsDomain)
} }
if *htpasswdFile != "" { if *htpasswdFile != "" {
oauthproxy.HtpasswdFile = NewHtpasswdFile(*htpasswdFile) oauthproxy.HtpasswdFile, err = NewHtpasswdFromFile(*htpasswdFile)
if err != nil {
log.Fatalf("FATAL: unable to open %s %s", *htpasswdFile, err.Error())
}
} }
listener, err := net.Listen("tcp", *httpAddr) listener, err := net.Listen("tcp", *httpAddr)
if err != nil { if err != nil {

View File

@ -169,11 +169,11 @@ func (p *OauthProxy) ErrorPage(rw http.ResponseWriter, code int, title string, m
rw.WriteHeader(code) rw.WriteHeader(code)
templates := getTemplates() templates := getTemplates()
t := struct { t := struct {
Title string Title string
Message string Message string
}{ }{
Title: fmt.Sprintf("%d %s", code, title), Title: fmt.Sprintf("%d %s", code, title),
Message: message, Message: message,
} }
templates.ExecuteTemplate(rw, "error.html", t) templates.ExecuteTemplate(rw, "error.html", t)
} }
@ -254,7 +254,7 @@ func (p *OauthProxy) ServeHTTP(rw http.ResponseWriter, req *http.Request) {
return return
} }
} }
cookie, err := req.Cookie(p.CookieKey) cookie, err := req.Cookie(p.CookieKey)
var ok bool var ok bool
var email string var email string

View File

@ -18,7 +18,7 @@ func getTemplates() *template.Template {
if err != nil { if err != nil {
log.Fatalf("failed parsing template %s", err.Error()) log.Fatalf("failed parsing template %s", err.Error())
} }
t, err = t.Parse(`{{define "error.html"}} t, err = t.Parse(`{{define "error.html"}}
<html><head><title>{{.Title}}</title></head> <html><head><title>{{.Title}}</title></head>
<body> <body>

12
templates_test.go Normal file
View File

@ -0,0 +1,12 @@
package main
import (
"github.com/bmizerany/assert"
"testing"
)
func TestTemplatesCompile(t *testing.T) {
templates := getTemplates()
assert.NotEqual(t, templates, nil)
}

View File

@ -1,10 +1,10 @@
package main package main
import ( import (
"os"
"log"
"encoding/csv" "encoding/csv"
"fmt" "fmt"
"log"
"os"
"strings" "strings"
) )