Merge pull request #108 from jehiah/unmarshal_error_108

Page defaults to Google sign in
This commit is contained in:
Jehiah Czebotar 2015-06-07 21:06:50 -04:00
commit 5a5d6dff7e
7 changed files with 61 additions and 61 deletions

View File

@ -184,7 +184,7 @@ OAuth2 Proxy responds directly to the following endpoints. All other endpoints w
## Logging Format ## Logging Format
OAuth2 Proxy Proxy logs requests to stdout in a format similar to Apache Combined Log. OAuth2 Proxy logs requests to stdout in a format similar to Apache Combined Log.
``` ```
<REMOTE_ADDRESS> - <user@domain.com> [19/Mar/2015:17:20:19 -0400] <HOST_HEADER> GET <UPSTREAM_HOST> "/path/" HTTP/1.1 "<USER_AGENT>" <RESPONSE_CODE> <RESPONSE_BYTES> <REQUEST_DURATION> <REMOTE_ADDRESS> - <user@domain.com> [19/Mar/2015:17:20:19 -0400] <HOST_HEADER> GET <UPSTREAM_HOST> "/path/" HTTP/1.1 "<USER_AGENT>" <RESPONSE_CODE> <RESPONSE_BYTES> <REQUEST_DURATION>

View File

@ -1,9 +1,8 @@
package api package api
import ( import (
"errors" "fmt"
"io/ioutil" "io/ioutil"
"log"
"net/http" "net/http"
"github.com/bitly/go-simplejson" "github.com/bitly/go-simplejson"
@ -20,8 +19,7 @@ func Request(req *http.Request) (*simplejson.Json, error) {
return nil, err return nil, err
} }
if resp.StatusCode != 200 { if resp.StatusCode != 200 {
log.Printf("got response code %d - %s", resp.StatusCode, body) return nil, fmt.Errorf("got %d %s", resp.StatusCode, body)
return nil, errors.New("api request returned non 200 status code")
} }
data, err := simplejson.NewJson(body) data, err := simplejson.NewJson(body)
if err != nil { if err != nil {
@ -30,19 +28,12 @@ func Request(req *http.Request) (*simplejson.Json, error) {
return data, nil return data, nil
} }
func RequestUnparsedResponse(url string, header http.Header) ( func RequestUnparsedResponse(url string, header http.Header) (resp *http.Response, err error) {
response *http.Response, err error) {
req, err := http.NewRequest("GET", url, nil) req, err := http.NewRequest("GET", url, nil)
if err != nil { if err != nil {
return nil, errors.New("failed building request for " + return nil, err
url + ": " + err.Error())
} }
req.Header = header req.Header = header
httpclient := &http.Client{} return http.DefaultClient.Do(req)
if response, err = httpclient.Do(req); err != nil {
return nil, errors.New("request failed for " +
url + ": " + err.Error())
}
return
} }

View File

@ -37,11 +37,6 @@ type OauthProxy struct {
redirectUrl *url.URL // the url to receive requests at redirectUrl *url.URL // the url to receive requests at
provider providers.Provider provider providers.Provider
oauthLoginUrl *url.URL // to redirect the user to
oauthValidateUrl *url.URL // to validate the access token
oauthScope string
clientID string
clientSecret string
ProxyPrefix string ProxyPrefix string
SignInMessage string SignInMessage string
HtpasswdFile *HtpasswdFile HtpasswdFile *HtpasswdFile
@ -147,25 +142,20 @@ func NewOauthProxy(opts *Options, validator func(string) bool) *OauthProxy {
OauthStartPath: fmt.Sprintf("%s/start", opts.ProxyPrefix), OauthStartPath: fmt.Sprintf("%s/start", opts.ProxyPrefix),
OauthCallbackPath: fmt.Sprintf("%s/callback", opts.ProxyPrefix), OauthCallbackPath: fmt.Sprintf("%s/callback", opts.ProxyPrefix),
clientID: opts.ClientID, ProxyPrefix: opts.ProxyPrefix,
clientSecret: opts.ClientSecret, provider: opts.provider,
ProxyPrefix: opts.ProxyPrefix, serveMux: serveMux,
oauthScope: opts.provider.Data().Scope, redirectUrl: redirectUrl,
provider: opts.provider, skipAuthRegex: opts.SkipAuthRegex,
oauthLoginUrl: opts.provider.Data().LoginUrl, compiledRegex: opts.CompiledRegex,
oauthValidateUrl: opts.provider.Data().ValidateUrl, PassBasicAuth: opts.PassBasicAuth,
serveMux: serveMux, PassAccessToken: opts.PassAccessToken,
redirectUrl: redirectUrl, AesCipher: aes_cipher,
skipAuthRegex: opts.SkipAuthRegex, templates: loadTemplates(opts.CustomTemplatesDir),
compiledRegex: opts.CompiledRegex,
PassBasicAuth: opts.PassBasicAuth,
PassAccessToken: opts.PassAccessToken,
AesCipher: aes_cipher,
templates: loadTemplates(opts.CustomTemplatesDir),
} }
} }
func (p *OauthProxy) GetRedirectUrl(host string) string { func (p *OauthProxy) GetRedirectURI(host string) string {
// default to the request Host if not set // default to the request Host if not set
if p.redirectUrl.Host != "" { if p.redirectUrl.Host != "" {
return p.redirectUrl.String() return p.redirectUrl.String()
@ -183,19 +173,6 @@ func (p *OauthProxy) GetRedirectUrl(host string) string {
return u.String() return u.String()
} }
func (p *OauthProxy) GetLoginURL(host, redirect string) string {
params := url.Values{}
params.Add("redirect_uri", p.GetRedirectUrl(host))
params.Add("approval_prompt", "force")
params.Add("scope", p.oauthScope)
params.Add("client_id", p.clientID)
params.Add("response_type", "code")
if strings.HasPrefix(redirect, "/") {
params.Add("state", redirect)
}
return fmt.Sprintf("%s?%s", p.oauthLoginUrl, params.Encode())
}
func (p *OauthProxy) displayCustomLoginForm() bool { func (p *OauthProxy) displayCustomLoginForm() bool {
return p.HtpasswdFile != nil && p.DisplayHtpasswdForm return p.HtpasswdFile != nil && p.DisplayHtpasswdForm
} }
@ -204,7 +181,7 @@ func (p *OauthProxy) redeemCode(host, code string) (string, string, error) {
if code == "" { if code == "" {
return "", "", errors.New("missing code") return "", "", errors.New("missing code")
} }
redirectUri := p.GetRedirectUrl(host) redirectUri := p.GetRedirectURI(host)
body, access_token, err := p.provider.Redeem(redirectUri, code) body, access_token, err := p.provider.Redeem(redirectUri, code)
if err != nil { if err != nil {
return "", "", err return "", "", err
@ -416,7 +393,8 @@ func (p *OauthProxy) ServeHTTP(rw http.ResponseWriter, req *http.Request) {
p.ErrorPage(rw, 500, "Internal Error", err.Error()) p.ErrorPage(rw, 500, "Internal Error", err.Error())
return return
} }
http.Redirect(rw, req, p.GetLoginURL(req.Host, redirect), 302) redirectURI := p.GetRedirectURI(req.Host)
http.Redirect(rw, req, p.provider.GetLoginURL(redirectURI, redirect), 302)
return return
} }
if req.URL.Path == p.OauthCallbackPath { if req.URL.Path == p.OauthCallbackPath {

View File

@ -2,6 +2,7 @@ package providers
import ( import (
"encoding/json" "encoding/json"
"fmt"
"io/ioutil" "io/ioutil"
"net/http" "net/http"
"net/url" "net/url"
@ -58,10 +59,11 @@ func (p *GitHubProvider) hasOrg(accessToken string) (bool, error) {
params := url.Values{ params := url.Values{
"access_token": {accessToken}, "access_token": {accessToken},
"limit": {"100"}, "limit": {"100"},
} }
req, _ := http.NewRequest("GET", "https://api.github.com/user/orgs?"+params.Encode(), nil) endpoint := "https://api.github.com/user/orgs?" + params.Encode()
req, _ := http.NewRequest("GET", endpoint, nil)
req.Header.Set("Accept", "application/vnd.github.moondragon+json") req.Header.Set("Accept", "application/vnd.github.moondragon+json")
resp, err := http.DefaultClient.Do(req) resp, err := http.DefaultClient.Do(req)
if err != nil { if err != nil {
@ -73,6 +75,9 @@ func (p *GitHubProvider) hasOrg(accessToken string) (bool, error) {
if err != nil { if err != nil {
return false, err return false, err
} }
if resp.StatusCode != 200 {
return false, fmt.Errorf("got %d from %q %s", resp.StatusCode, endpoint, body)
}
if err := json.Unmarshal(body, &orgs); err != nil { if err := json.Unmarshal(body, &orgs); err != nil {
return false, err return false, err
@ -99,10 +104,11 @@ func (p *GitHubProvider) hasOrgAndTeam(accessToken string) (bool, error) {
params := url.Values{ params := url.Values{
"access_token": {accessToken}, "access_token": {accessToken},
"limit": {"100"}, "limit": {"100"},
} }
req, _ := http.NewRequest("GET", "https://api.github.com/user/teams?"+params.Encode(), nil) endpoint := "https://api.github.com/user/teams?" + params.Encode()
req, _ := http.NewRequest("GET", endpoint, nil)
req.Header.Set("Accept", "application/vnd.github.moondragon+json") req.Header.Set("Accept", "application/vnd.github.moondragon+json")
resp, err := http.DefaultClient.Do(req) resp, err := http.DefaultClient.Do(req)
if err != nil { if err != nil {
@ -114,9 +120,12 @@ func (p *GitHubProvider) hasOrgAndTeam(accessToken string) (bool, error) {
if err != nil { if err != nil {
return false, err return false, err
} }
if resp.StatusCode != 200 {
return false, fmt.Errorf("got %d from %q %s", resp.StatusCode, endpoint, body)
}
if err := json.Unmarshal(body, &teams); err != nil { if err := json.Unmarshal(body, &teams); err != nil {
return false, err return false, fmt.Errorf("%s unmarshaling %s", err, body)
} }
for _, team := range teams { for _, team := range teams {
@ -136,7 +145,6 @@ func (p *GitHubProvider) GetEmailAddress(body []byte, access_token string) (stri
Primary bool `json:"primary"` Primary bool `json:"primary"`
} }
// if we require an Org or Team, check that first // if we require an Org or Team, check that first
if p.Org != "" { if p.Org != "" {
if p.Team != "" { if p.Team != "" {
@ -153,7 +161,8 @@ func (p *GitHubProvider) GetEmailAddress(body []byte, access_token string) (stri
params := url.Values{ params := url.Values{
"access_token": {access_token}, "access_token": {access_token},
} }
resp, err := http.DefaultClient.Get("https://api.github.com/user/emails?" + params.Encode()) endpoint := "https://api.github.com/user/emails?" + params.Encode()
resp, err := http.DefaultClient.Get(endpoint)
if err != nil { if err != nil {
return "", err return "", err
} }
@ -162,9 +171,12 @@ func (p *GitHubProvider) GetEmailAddress(body []byte, access_token string) (stri
if err != nil { if err != nil {
return "", err return "", err
} }
if resp.StatusCode != 200 {
return "", fmt.Errorf("got %d from %q %s", resp.StatusCode, endpoint, body)
}
if err := json.Unmarshal(body, &emails); err != nil { if err := json.Unmarshal(body, &emails); err != nil {
return "", err return "", fmt.Errorf("%s unmarshaling %s", err, body)
} }
for _, email := range emails { for _, email := range emails {

View File

@ -6,8 +6,7 @@ import (
"net/http" "net/http"
) )
func validateToken(p Provider, access_token string, func validateToken(p Provider, access_token string, header http.Header) bool {
header http.Header) bool {
if access_token == "" || p.Data().ValidateUrl == nil { if access_token == "" || p.Data().ValidateUrl == nil {
return false return false
} }

View File

@ -4,9 +4,11 @@ import (
"bytes" "bytes"
"encoding/json" "encoding/json"
"errors" "errors"
"fmt"
"io/ioutil" "io/ioutil"
"net/http" "net/http"
"net/url" "net/url"
"strings"
) )
func (p *ProviderData) Redeem(redirectUrl, code string) (body []byte, token string, err error) { func (p *ProviderData) Redeem(redirectUrl, code string) (body []byte, token string, err error) {
@ -37,6 +39,10 @@ func (p *ProviderData) Redeem(redirectUrl, code string) (body []byte, token stri
return nil, "", err return nil, "", err
} }
if resp.StatusCode != 200 {
return body, "", fmt.Errorf("got %d from %q %s", resp.StatusCode, p.RedeemUrl.String(), body)
}
// blindly try json and x-www-form-urlencoded // blindly try json and x-www-form-urlencoded
var jsonResponse struct { var jsonResponse struct {
AccessToken string `json:"access_token"` AccessToken string `json:"access_token"`
@ -49,3 +55,16 @@ func (p *ProviderData) Redeem(redirectUrl, code string) (body []byte, token stri
v, err := url.ParseQuery(string(body)) v, err := url.ParseQuery(string(body))
return body, v.Get("access_token"), err return body, v.Get("access_token"), err
} }
func (p *ProviderData) GetLoginURL(redirectURI, finalRedirect string) string {
params := url.Values{}
params.Add("redirect_uri", redirectURI)
params.Add("approval_prompt", "force")
params.Add("scope", p.Scope)
params.Add("client_id", p.ClientID)
params.Add("response_type", "code")
if strings.HasPrefix(finalRedirect, "/") {
params.Add("state", finalRedirect)
}
return fmt.Sprintf("%s?%s", p.LoginUrl, params.Encode())
}

View File

@ -5,6 +5,7 @@ type Provider interface {
GetEmailAddress(body []byte, access_token string) (string, error) GetEmailAddress(body []byte, access_token string) (string, error)
Redeem(string, string) ([]byte, string, error) Redeem(string, string) ([]byte, string, error)
ValidateToken(access_token string) bool ValidateToken(access_token string) bool
GetLoginURL(redirectURI, finalRedirect string) string
} }
func New(provider string, p *ProviderData) Provider { func New(provider string, p *ProviderData) Provider {