testing
This commit is contained in:
parent
42359333b2
commit
42f539109e
15
htpasswd.go
15
htpasswd.go
@ -4,6 +4,7 @@ import (
|
|||||||
"crypto/sha1"
|
"crypto/sha1"
|
||||||
"encoding/base64"
|
"encoding/base64"
|
||||||
"encoding/csv"
|
"encoding/csv"
|
||||||
|
"io"
|
||||||
"log"
|
"log"
|
||||||
"os"
|
"os"
|
||||||
)
|
)
|
||||||
@ -15,26 +16,30 @@ type HtpasswdFile struct {
|
|||||||
Users map[string]string
|
Users map[string]string
|
||||||
}
|
}
|
||||||
|
|
||||||
func NewHtpasswdFile(path string) *HtpasswdFile {
|
func NewHtpasswdFromFile(path string) (*HtpasswdFile, error) {
|
||||||
log.Printf("using htpasswd file %s", path)
|
log.Printf("using htpasswd file %s", path)
|
||||||
r, err := os.Open(path)
|
r, err := os.Open(path)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.Fatalf("failed opening %v, %s", path, err.Error())
|
return nil, err
|
||||||
}
|
}
|
||||||
csv_reader := csv.NewReader(r)
|
return NewHtpasswd(r)
|
||||||
|
}
|
||||||
|
|
||||||
|
func NewHtpasswd(file io.Reader) (*HtpasswdFile, error) {
|
||||||
|
csv_reader := csv.NewReader(file)
|
||||||
csv_reader.Comma = ':'
|
csv_reader.Comma = ':'
|
||||||
csv_reader.Comment = '#'
|
csv_reader.Comment = '#'
|
||||||
csv_reader.TrimLeadingSpace = true
|
csv_reader.TrimLeadingSpace = true
|
||||||
|
|
||||||
records, err := csv_reader.ReadAll()
|
records, err := csv_reader.ReadAll()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.Fatalf("Failed reading file %s", err.Error())
|
return nil, err
|
||||||
}
|
}
|
||||||
h := &HtpasswdFile{Users: make(map[string]string)}
|
h := &HtpasswdFile{Users: make(map[string]string)}
|
||||||
for _, record := range records {
|
for _, record := range records {
|
||||||
h.Users[record[0]] = record[1]
|
h.Users[record[0]] = record[1]
|
||||||
}
|
}
|
||||||
return h
|
return h, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (h *HtpasswdFile) Validate(user string, password string) bool {
|
func (h *HtpasswdFile) Validate(user string, password string) bool {
|
||||||
|
16
htpasswd_test.go
Normal file
16
htpasswd_test.go
Normal file
@ -0,0 +1,16 @@
|
|||||||
|
package main
|
||||||
|
|
||||||
|
import (
|
||||||
|
"bytes"
|
||||||
|
"github.com/bmizerany/assert"
|
||||||
|
"testing"
|
||||||
|
)
|
||||||
|
|
||||||
|
func TestHtpasswd(t *testing.T) {
|
||||||
|
file := bytes.NewBuffer([]byte("testuser:{SHA}PaVBVZkYqAjCQCu6UBL2xgsnZhw=\n"))
|
||||||
|
h, err := NewHtpasswd(file)
|
||||||
|
assert.Equal(t, err, nil)
|
||||||
|
|
||||||
|
valid := h.Validate("testuser", "asdf")
|
||||||
|
assert.Equal(t, valid, true)
|
||||||
|
}
|
7
main.go
7
main.go
@ -2,12 +2,12 @@ package main
|
|||||||
|
|
||||||
import (
|
import (
|
||||||
"flag"
|
"flag"
|
||||||
|
"fmt"
|
||||||
"log"
|
"log"
|
||||||
"net"
|
"net"
|
||||||
"net/http"
|
"net/http"
|
||||||
"net/url"
|
"net/url"
|
||||||
"strings"
|
"strings"
|
||||||
"fmt"
|
|
||||||
)
|
)
|
||||||
|
|
||||||
const VERSION = "0.0.1"
|
const VERSION = "0.0.1"
|
||||||
@ -72,7 +72,10 @@ func main() {
|
|||||||
oauthproxy.SignInMessage = fmt.Sprintf("using a %s email address", *googleAppsDomain)
|
oauthproxy.SignInMessage = fmt.Sprintf("using a %s email address", *googleAppsDomain)
|
||||||
}
|
}
|
||||||
if *htpasswdFile != "" {
|
if *htpasswdFile != "" {
|
||||||
oauthproxy.HtpasswdFile = NewHtpasswdFile(*htpasswdFile)
|
oauthproxy.HtpasswdFile, err = NewHtpasswdFromFile(*htpasswdFile)
|
||||||
|
if err != nil {
|
||||||
|
log.Fatalf("FATAL: unable to open %s %s", *htpasswdFile, err.Error())
|
||||||
|
}
|
||||||
}
|
}
|
||||||
listener, err := net.Listen("tcp", *httpAddr)
|
listener, err := net.Listen("tcp", *httpAddr)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
|
@ -169,11 +169,11 @@ func (p *OauthProxy) ErrorPage(rw http.ResponseWriter, code int, title string, m
|
|||||||
rw.WriteHeader(code)
|
rw.WriteHeader(code)
|
||||||
templates := getTemplates()
|
templates := getTemplates()
|
||||||
t := struct {
|
t := struct {
|
||||||
Title string
|
Title string
|
||||||
Message string
|
Message string
|
||||||
}{
|
}{
|
||||||
Title: fmt.Sprintf("%d %s", code, title),
|
Title: fmt.Sprintf("%d %s", code, title),
|
||||||
Message: message,
|
Message: message,
|
||||||
}
|
}
|
||||||
templates.ExecuteTemplate(rw, "error.html", t)
|
templates.ExecuteTemplate(rw, "error.html", t)
|
||||||
}
|
}
|
||||||
|
12
templates_test.go
Normal file
12
templates_test.go
Normal file
@ -0,0 +1,12 @@
|
|||||||
|
package main
|
||||||
|
|
||||||
|
import (
|
||||||
|
"github.com/bmizerany/assert"
|
||||||
|
"testing"
|
||||||
|
)
|
||||||
|
|
||||||
|
func TestTemplatesCompile(t *testing.T) {
|
||||||
|
templates := getTemplates()
|
||||||
|
assert.NotEqual(t, templates, nil)
|
||||||
|
|
||||||
|
}
|
@ -1,10 +1,10 @@
|
|||||||
package main
|
package main
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"os"
|
|
||||||
"log"
|
|
||||||
"encoding/csv"
|
"encoding/csv"
|
||||||
"fmt"
|
"fmt"
|
||||||
|
"log"
|
||||||
|
"os"
|
||||||
"strings"
|
"strings"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
Loading…
Reference in New Issue
Block a user