This commit is contained in:
Jehiah Czebotar 2012-12-17 13:38:33 -05:00
parent 42359333b2
commit 42f539109e
7 changed files with 51 additions and 15 deletions

View File

@ -4,6 +4,7 @@ import (
"crypto/sha1" "crypto/sha1"
"encoding/base64" "encoding/base64"
"encoding/csv" "encoding/csv"
"io"
"log" "log"
"os" "os"
) )
@ -15,26 +16,30 @@ type HtpasswdFile struct {
Users map[string]string Users map[string]string
} }
func NewHtpasswdFile(path string) *HtpasswdFile { func NewHtpasswdFromFile(path string) (*HtpasswdFile, error) {
log.Printf("using htpasswd file %s", path) log.Printf("using htpasswd file %s", path)
r, err := os.Open(path) r, err := os.Open(path)
if err != nil { if err != nil {
log.Fatalf("failed opening %v, %s", path, err.Error()) return nil, err
} }
csv_reader := csv.NewReader(r) return NewHtpasswd(r)
}
func NewHtpasswd(file io.Reader) (*HtpasswdFile, error) {
csv_reader := csv.NewReader(file)
csv_reader.Comma = ':' csv_reader.Comma = ':'
csv_reader.Comment = '#' csv_reader.Comment = '#'
csv_reader.TrimLeadingSpace = true csv_reader.TrimLeadingSpace = true
records, err := csv_reader.ReadAll() records, err := csv_reader.ReadAll()
if err != nil { if err != nil {
log.Fatalf("Failed reading file %s", err.Error()) return nil, err
} }
h := &HtpasswdFile{Users: make(map[string]string)} h := &HtpasswdFile{Users: make(map[string]string)}
for _, record := range records { for _, record := range records {
h.Users[record[0]] = record[1] h.Users[record[0]] = record[1]
} }
return h return h, nil
} }
func (h *HtpasswdFile) Validate(user string, password string) bool { func (h *HtpasswdFile) Validate(user string, password string) bool {

16
htpasswd_test.go Normal file
View File

@ -0,0 +1,16 @@
package main
import (
"bytes"
"github.com/bmizerany/assert"
"testing"
)
func TestHtpasswd(t *testing.T) {
file := bytes.NewBuffer([]byte("testuser:{SHA}PaVBVZkYqAjCQCu6UBL2xgsnZhw=\n"))
h, err := NewHtpasswd(file)
assert.Equal(t, err, nil)
valid := h.Validate("testuser", "asdf")
assert.Equal(t, valid, true)
}

View File

@ -2,12 +2,12 @@ package main
import ( import (
"flag" "flag"
"fmt"
"log" "log"
"net" "net"
"net/http" "net/http"
"net/url" "net/url"
"strings" "strings"
"fmt"
) )
const VERSION = "0.0.1" const VERSION = "0.0.1"
@ -72,7 +72,10 @@ func main() {
oauthproxy.SignInMessage = fmt.Sprintf("using a %s email address", *googleAppsDomain) oauthproxy.SignInMessage = fmt.Sprintf("using a %s email address", *googleAppsDomain)
} }
if *htpasswdFile != "" { if *htpasswdFile != "" {
oauthproxy.HtpasswdFile = NewHtpasswdFile(*htpasswdFile) oauthproxy.HtpasswdFile, err = NewHtpasswdFromFile(*htpasswdFile)
if err != nil {
log.Fatalf("FATAL: unable to open %s %s", *htpasswdFile, err.Error())
}
} }
listener, err := net.Listen("tcp", *httpAddr) listener, err := net.Listen("tcp", *httpAddr)
if err != nil { if err != nil {

12
templates_test.go Normal file
View File

@ -0,0 +1,12 @@
package main
import (
"github.com/bmizerany/assert"
"testing"
)
func TestTemplatesCompile(t *testing.T) {
templates := getTemplates()
assert.NotEqual(t, templates, nil)
}

View File

@ -1,10 +1,10 @@
package main package main
import ( import (
"os"
"log"
"encoding/csv" "encoding/csv"
"fmt" "fmt"
"log"
"os"
"strings" "strings"
) )