diff --git a/api/api.go b/api/api.go new file mode 100644 index 0000000..3cd88d2 --- /dev/null +++ b/api/api.go @@ -0,0 +1,32 @@ +package api + +import ( + "errors" + "io/ioutil" + "log" + "net/http" + + "github.com/bitly/go-simplejson" +) + +func Request(req *http.Request) (*simplejson.Json, error) { + httpclient := &http.Client{} + resp, err := httpclient.Do(req) + if err != nil { + return nil, err + } + body, err := ioutil.ReadAll(resp.Body) + resp.Body.Close() + if err != nil { + return nil, err + } + if resp.StatusCode != 200 { + log.Printf("got response code %d - %s", resp.StatusCode, body) + return nil, errors.New("api request returned non 200 status code") + } + data, err := simplejson.NewJson(body) + if err != nil { + return nil, err + } + return data, nil +} diff --git a/api/api_test.go b/api/api_test.go new file mode 100644 index 0000000..2494131 --- /dev/null +++ b/api/api_test.go @@ -0,0 +1,68 @@ +package api + +import ( + "github.com/bitly/go-simplejson" + "github.com/bmizerany/assert" + "net/http" + "net/http/httptest" + "strings" + "testing" +) + +func testBackend(response_code int, payload string) *httptest.Server { + return httptest.NewServer(http.HandlerFunc( + func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(response_code) + w.Write([]byte(payload)) + })) +} + +func TestRequest(t *testing.T) { + backend := testBackend(200, "{\"foo\": \"bar\"}") + defer backend.Close() + + req, _ := http.NewRequest("GET", backend.URL, nil) + response, err := Request(req) + assert.Equal(t, nil, err) + result, err := response.Get("foo").String() + assert.Equal(t, nil, err) + assert.Equal(t, "bar", result) +} + +func TestRequestFailure(t *testing.T) { + // Create a backend to generate a test URL, then close it to cause a + // connection error. + backend := testBackend(200, "{\"foo\": \"bar\"}") + backend.Close() + + req, err := http.NewRequest("GET", backend.URL, nil) + assert.Equal(t, nil, err) + resp, err := Request(req) + assert.Equal(t, (*simplejson.Json)(nil), resp) + assert.NotEqual(t, nil, err) + if !strings.HasSuffix(err.Error(), "connection refused") { + t.Error("expected error when a connection fails") + } +} + +func TestHttpErrorCode(t *testing.T) { + backend := testBackend(404, "{\"foo\": \"bar\"}") + defer backend.Close() + + req, err := http.NewRequest("GET", backend.URL, nil) + assert.Equal(t, nil, err) + resp, err := Request(req) + assert.Equal(t, (*simplejson.Json)(nil), resp) + assert.NotEqual(t, nil, err) +} + +func TestJsonParsingError(t *testing.T) { + backend := testBackend(200, "not well-formed JSON") + defer backend.Close() + + req, err := http.NewRequest("GET", backend.URL, nil) + assert.Equal(t, nil, err) + resp, err := Request(req) + assert.Equal(t, (*simplejson.Json)(nil), resp) + assert.NotEqual(t, nil, err) +} diff --git a/oauthproxy.go b/oauthproxy.go index e089b55..a5bd71c 100644 --- a/oauthproxy.go +++ b/oauthproxy.go @@ -6,7 +6,6 @@ import ( "errors" "fmt" "html/template" - "io/ioutil" "log" "net" "net/http" @@ -16,6 +15,7 @@ import ( "strings" "time" + "github.com/18F/google_auth_proxy/api" "github.com/bitly/go-simplejson" ) @@ -171,28 +171,6 @@ func (p *OauthProxy) GetLoginURL(host, redirect string) string { return fmt.Sprintf("%s?%s", p.oauthLoginUrl, params.Encode()) } -func apiRequest(req *http.Request) (*simplejson.Json, error) { - httpclient := &http.Client{} - resp, err := httpclient.Do(req) - if err != nil { - return nil, err - } - body, err := ioutil.ReadAll(resp.Body) - resp.Body.Close() - if err != nil { - return nil, err - } - if resp.StatusCode != 200 { - log.Printf("got response code %d - %s", resp.StatusCode, body) - return nil, errors.New("api request returned non 200 status code") - } - data, err := simplejson.NewJson(body) - if err != nil { - return nil, err - } - return data, nil -} - func (p *OauthProxy) displayCustomLoginForm() bool { return p.HtpasswdFile != nil && p.DisplayHtpasswdForm } @@ -213,7 +191,7 @@ func (p *OauthProxy) redeemCode(host, code string) (string, string, error) { return "", "", err } req.Header.Set("Content-Type", "application/x-www-form-urlencoded") - json, err := apiRequest(req) + json, err := api.Request(req) if err != nil { log.Printf("failed making request %s", err) return "", "", err