Github provider

This commit is contained in:
Jehiah Czebotar 2015-05-20 23:23:48 -04:00
parent 8471f972e1
commit 37b38dd2f4
17 changed files with 304 additions and 112 deletions

View File

@ -101,6 +101,8 @@ Usage of google_auth_proxy:
-version=false: print version string -version=false: print version string
``` ```
See below for provider specific options
### Environment variables ### Environment variables
The environment variables `GOOGLE_AUTH_PROXY_CLIENT_ID`, `GOOGLE_AUTH_PROXY_CLIENT_SECRET`, `GOOGLE_AUTH_PROXY_COOKIE_SECRET`, `GOOGLE_AUTH_PROXY_COOKIE_DOMAIN` and `GOOGLE_AUTH_PROXY_COOKIE_EXPIRE` can be used in place of the corresponding command-line arguments. The environment variables `GOOGLE_AUTH_PROXY_CLIENT_ID`, `GOOGLE_AUTH_PROXY_CLIENT_SECRET`, `GOOGLE_AUTH_PROXY_COOKIE_SECRET`, `GOOGLE_AUTH_PROXY_COOKIE_DOMAIN` and `GOOGLE_AUTH_PROXY_COOKIE_EXPIRE` can be used in place of the corresponding command-line arguments.
@ -173,6 +175,10 @@ directive. Right now this includes:
* `myusa` - The [MyUSA](https://alpha.my.usa.gov) authentication service * `myusa` - The [MyUSA](https://alpha.my.usa.gov) authentication service
([GitHub](https://github.com/18F/myusa)) ([GitHub](https://github.com/18F/myusa))
* `linkedin` - The [LinkedIn](https://developer.linkedin.com/docs/signin-with-linkedin) Sign In service. * `linkedin` - The [LinkedIn](https://developer.linkedin.com/docs/signin-with-linkedin) Sign In service.
* `github` - Via [Github][https://github.com/settings/developers] OAuth App. Also supports restricting via org and team.
-github-org="": restrict logins to members of this organisation
-github-team="": restrict logins to members of this team
## Adding a new Provider ## Adding a new Provider

View File

@ -10,8 +10,7 @@ import (
) )
func Request(req *http.Request) (*simplejson.Json, error) { func Request(req *http.Request) (*simplejson.Json, error) {
httpclient := &http.Client{} resp, err := http.DefaultClient.Do(req)
resp, err := httpclient.Do(req)
if err != nil { if err != nil {
return nil, err return nil, err
} }

View File

@ -17,6 +17,7 @@ import (
) )
func main() { func main() {
log.SetFlags(log.Ldate | log.Ltime | log.Lshortfile)
flagSet := flag.NewFlagSet("google_auth_proxy", flag.ExitOnError) flagSet := flag.NewFlagSet("google_auth_proxy", flag.ExitOnError)
googleAppsDomains := StringArray{} googleAppsDomains := StringArray{}
@ -35,6 +36,8 @@ func main() {
flagSet.Var(&skipAuthRegex, "skip-auth-regex", "bypass authentication for requests path's that match (may be given multiple times)") flagSet.Var(&skipAuthRegex, "skip-auth-regex", "bypass authentication for requests path's that match (may be given multiple times)")
flagSet.Var(&googleAppsDomains, "google-apps-domain", "authenticate against the given Google apps domain (may be given multiple times)") flagSet.Var(&googleAppsDomains, "google-apps-domain", "authenticate against the given Google apps domain (may be given multiple times)")
flagSet.String("github-org", "", "restrict logins to members of this organisation")
flagSet.String("github-team", "", "restrict logins to members of this team")
flagSet.String("client-id", "", "the Google OAuth Client ID: ie: \"123456.apps.googleusercontent.com\"") flagSet.String("client-id", "", "the Google OAuth Client ID: ie: \"123456.apps.googleusercontent.com\"")
flagSet.String("client-secret", "", "the OAuth Client Secret") flagSet.String("client-secret", "", "the OAuth Client Secret")
flagSet.String("authenticated-emails-file", "", "authenticate against emails via file (one per line)") flagSet.String("authenticated-emails-file", "", "authenticate against emails via file (one per line)")

View File

@ -1,7 +1,6 @@
package main package main
import ( import (
"bytes"
"crypto/aes" "crypto/aes"
"crypto/cipher" "crypto/cipher"
"encoding/base64" "encoding/base64"
@ -17,7 +16,6 @@ import (
"strings" "strings"
"time" "time"
"github.com/bitly/google_auth_proxy/api"
"github.com/bitly/google_auth_proxy/providers" "github.com/bitly/google_auth_proxy/providers"
) )
@ -39,7 +37,6 @@ type OauthProxy struct {
redirectUrl *url.URL // the url to receive requests at redirectUrl *url.URL // the url to receive requests at
provider providers.Provider provider providers.Provider
oauthRedemptionUrl *url.URL // endpoint to redeem the code
oauthLoginUrl *url.URL // to redirect the user to oauthLoginUrl *url.URL // to redirect the user to
oauthValidateUrl *url.URL // to validate the access token oauthValidateUrl *url.URL // to validate the access token
oauthScope string oauthScope string
@ -143,21 +140,20 @@ func NewOauthProxy(opts *Options, validator func(string) bool) *OauthProxy {
CookieRefresh: opts.CookieRefresh, CookieRefresh: opts.CookieRefresh,
Validator: validator, Validator: validator,
clientID: opts.ClientID, clientID: opts.ClientID,
clientSecret: opts.ClientSecret, clientSecret: opts.ClientSecret,
oauthScope: opts.provider.Data().Scope, oauthScope: opts.provider.Data().Scope,
provider: opts.provider, provider: opts.provider,
oauthRedemptionUrl: opts.provider.Data().RedeemUrl, oauthLoginUrl: opts.provider.Data().LoginUrl,
oauthLoginUrl: opts.provider.Data().LoginUrl, oauthValidateUrl: opts.provider.Data().ValidateUrl,
oauthValidateUrl: opts.provider.Data().ValidateUrl, serveMux: serveMux,
serveMux: serveMux, redirectUrl: redirectUrl,
redirectUrl: redirectUrl, skipAuthRegex: opts.SkipAuthRegex,
skipAuthRegex: opts.SkipAuthRegex, compiledRegex: opts.CompiledRegex,
compiledRegex: opts.CompiledRegex, PassBasicAuth: opts.PassBasicAuth,
PassBasicAuth: opts.PassBasicAuth, PassAccessToken: opts.PassAccessToken,
PassAccessToken: opts.PassAccessToken, AesCipher: aes_cipher,
AesCipher: aes_cipher, templates: loadTemplates(opts.CustomTemplatesDir),
templates: loadTemplates(opts.CustomTemplatesDir),
} }
} }
@ -200,29 +196,13 @@ func (p *OauthProxy) redeemCode(host, code string) (string, string, error) {
if code == "" { if code == "" {
return "", "", errors.New("missing code") return "", "", errors.New("missing code")
} }
params := url.Values{} redirectUri := p.GetRedirectUrl(host)
params.Add("redirect_uri", p.GetRedirectUrl(host)) body, access_token, err := p.provider.Redeem(redirectUri, code)
params.Add("client_id", p.clientID)
params.Add("client_secret", p.clientSecret)
params.Add("code", code)
params.Add("grant_type", "authorization_code")
req, err := http.NewRequest("POST", p.oauthRedemptionUrl.String(), bytes.NewBufferString(params.Encode()))
if err != nil {
log.Printf("failed building request %s", err.Error())
return "", "", err
}
req.Header.Set("Content-Type", "application/x-www-form-urlencoded")
json, err := api.Request(req)
if err != nil {
log.Printf("failed making request %s", err)
return "", "", err
}
access_token, err := json.Get("access_token").String()
if err != nil { if err != nil {
return "", "", err return "", "", err
} }
email, err := p.provider.GetEmailAddress(json, access_token) email, err := p.provider.GetEmailAddress(body, access_token)
if err != nil { if err != nil {
return "", "", err return "", "", err
} }

View File

@ -1,10 +1,10 @@
package main package main
import ( import (
"github.com/bitly/go-simplejson"
"github.com/bitly/google_auth_proxy/providers" "github.com/bitly/google_auth_proxy/providers"
"github.com/bmizerany/assert" "github.com/bmizerany/assert"
"io/ioutil" "io/ioutil"
"log"
"net" "net"
"net/http" "net/http"
"net/http/httptest" "net/http/httptest"
@ -15,6 +15,11 @@ import (
"time" "time"
) )
func init() {
log.SetFlags(log.Ldate | log.Ltime | log.Lshortfile)
}
func TestNewReverseProxy(t *testing.T) { func TestNewReverseProxy(t *testing.T) {
backend := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { backend := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.WriteHeader(200) w.WriteHeader(200)
@ -89,8 +94,7 @@ type TestProvider struct {
ValidToken bool ValidToken bool
} }
func (tp *TestProvider) GetEmailAddress(unused_auth_response *simplejson.Json, func (tp *TestProvider) GetEmailAddress(body []byte, access_token string) (string, error) {
unused_access_token string) (string, error) {
return tp.EmailAddress, nil return tp.EmailAddress, nil
} }
@ -113,16 +117,15 @@ func NewPassAccessTokenTest(opts PassAccessTokenTestOptions) *PassAccessTokenTes
t.provider_server = httptest.NewServer( t.provider_server = httptest.NewServer(
http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
log.Printf("%#v", r)
url := r.URL url := r.URL
payload := "" payload := ""
switch url.Path { switch url.Path {
case "/oauth/token": case "/oauth/token":
payload = `{"access_token": "my_auth_token"}` payload = `{"access_token": "my_auth_token"}`
default: default:
token_header := r.Header["X-Forwarded-Access-Token"] payload = r.Header.Get("X-Forwarded-Access-Token")
if len(token_header) != 0 { if payload == "" {
payload = token_header[0]
} else {
payload = "No access token found." payload = "No access token found."
} }
} }
@ -189,8 +192,7 @@ func (pat_test *PassAccessTokenTest) getCallbackEndpoint() (http_code int,
return rw.Code, rw.HeaderMap["Set-Cookie"][0] return rw.Code, rw.HeaderMap["Set-Cookie"][0]
} }
func (pat_test *PassAccessTokenTest) getRootEndpoint( func (pat_test *PassAccessTokenTest) getRootEndpoint(cookie string) (http_code int, access_token string) {
cookie string) (http_code int, access_token string) {
cookie_key := pat_test.proxy.CookieKey cookie_key := pat_test.proxy.CookieKey
var value string var value string
key_prefix := cookie_key + "=" key_prefix := cookie_key + "="

View File

@ -19,6 +19,8 @@ type Options struct {
AuthenticatedEmailsFile string `flag:"authenticated-emails-file" cfg:"authenticated_emails_file"` AuthenticatedEmailsFile string `flag:"authenticated-emails-file" cfg:"authenticated_emails_file"`
GoogleAppsDomains []string `flag:"google-apps-domain" cfg:"google_apps_domains"` GoogleAppsDomains []string `flag:"google-apps-domain" cfg:"google_apps_domains"`
GitHubOrg string `flag:"github-org" cfg:"github_org"`
GitHubTeam string `flag:"github-team" cfg:"github_team"`
HtpasswdFile string `flag:"htpasswd-file" cfg:"htpasswd_file"` HtpasswdFile string `flag:"htpasswd-file" cfg:"htpasswd_file"`
DisplayHtpasswdForm bool `flag:"display-htpasswd-form" cfg:"display_htpasswd_form"` DisplayHtpasswdForm bool `flag:"display-htpasswd-form" cfg:"display_htpasswd_form"`
CustomTemplatesDir string `flag:"custom-templates-dir" cfg:"custom_templates_dir"` CustomTemplatesDir string `flag:"custom-templates-dir" cfg:"custom_templates_dir"`
@ -153,11 +155,16 @@ func (o *Options) Validate() error {
} }
func parseProviderInfo(o *Options, msgs []string) []string { func parseProviderInfo(o *Options, msgs []string) []string {
p := &providers.ProviderData{Scope: o.Scope} p := &providers.ProviderData{Scope: o.Scope, ClientID: o.ClientID, ClientSecret: o.ClientSecret}
p.LoginUrl, msgs = parseUrl(o.LoginUrl, "login", msgs) p.LoginUrl, msgs = parseUrl(o.LoginUrl, "login", msgs)
p.RedeemUrl, msgs = parseUrl(o.RedeemUrl, "redeem", msgs) p.RedeemUrl, msgs = parseUrl(o.RedeemUrl, "redeem", msgs)
p.ProfileUrl, msgs = parseUrl(o.ProfileUrl, "profile", msgs) p.ProfileUrl, msgs = parseUrl(o.ProfileUrl, "profile", msgs)
p.ValidateUrl, msgs = parseUrl(o.ValidateUrl, "validate", msgs) p.ValidateUrl, msgs = parseUrl(o.ValidateUrl, "validate", msgs)
o.provider = providers.New(o.Provider, p) o.provider = providers.New(o.Provider, p)
switch p := o.provider.(type) {
case *providers.GitHubProvider:
p.SetOrgTeam(o.GitHubOrg, o.GitHubTeam)
}
return msgs return msgs
} }

136
providers/github.go Normal file
View File

@ -0,0 +1,136 @@
package providers
import (
"encoding/json"
"io/ioutil"
"net/http"
"net/url"
)
type GitHubProvider struct {
*ProviderData
Org string
Team string
}
func NewGitHubProvider(p *ProviderData) *GitHubProvider {
p.ProviderName = "GitHub"
if p.LoginUrl.String() == "" {
p.LoginUrl = &url.URL{
Scheme: "https",
Host: "github.com",
Path: "/login/oauth/authorize",
}
}
if p.RedeemUrl.String() == "" {
p.RedeemUrl = &url.URL{
Scheme: "https",
Host: "github.com",
Path: "/login/oauth/access_token",
}
}
if p.ValidateUrl.String() == "" {
p.ValidateUrl = &url.URL{
Scheme: "https",
Host: "api.github.com",
Path: "/user/emails",
}
}
if p.Scope == "" {
p.Scope = "user:email"
}
return &GitHubProvider{ProviderData: p}
}
func (p *GitHubProvider) SetOrgTeam(org, team string) {
p.Org = org
p.Team = team
if org != "" || team != "" {
p.Scope += " read:org"
}
}
func (p *GitHubProvider) hasOrgAndTeam(accessToken string) (bool, error) {
var teams []struct {
Name string `json:"name"`
Slug string `json:"slug"`
Org struct {
Login string `json:"login"`
} `json:"organization"`
}
params := url.Values{
"access_token": {accessToken},
}
req, _ := http.NewRequest("GET", "https://api.github.com/user/teams?"+params.Encode(), nil)
req.Header.Set("Accept", "application/vnd.github.moondragon+json")
resp, err := http.DefaultClient.Do(req)
if err != nil {
return false, err
}
body, err := ioutil.ReadAll(resp.Body)
resp.Body.Close()
if err != nil {
return false, err
}
if err := json.Unmarshal(body, &teams); err != nil {
return false, err
}
for _, team := range teams {
if p.Org == team.Org.Login {
if p.Team == "" || p.Team == team.Slug {
return true, nil
}
}
}
return false, nil
}
func (p *GitHubProvider) GetEmailAddress(body []byte, access_token string) (string, error) {
var emails []struct {
Email string `json:"email"`
Primary bool `json:"primary"`
}
params := url.Values{
"access_token": {access_token},
}
// if we require an Org or Team, check that first
if p.Org != "" || p.Team != "" {
if ok, err := p.hasOrgAndTeam(access_token); err != nil || !ok {
return "", err
}
}
resp, err := http.DefaultClient.Get("https://api.github.com/user/emails?" + params.Encode())
if err != nil {
return "", err
}
body, err = ioutil.ReadAll(resp.Body)
resp.Body.Close()
if err != nil {
return "", err
}
if err := json.Unmarshal(body, &emails); err != nil {
return "", err
}
for _, email := range emails {
if email.Primary {
return email.Email, nil
}
}
return "", nil
}
func (p *GitHubProvider) ValidateToken(access_token string) bool {
return validateToken(p, access_token, nil)
}

View File

@ -2,10 +2,10 @@ package providers
import ( import (
"encoding/base64" "encoding/base64"
"encoding/json"
"errors"
"net/url" "net/url"
"strings" "strings"
"github.com/bitly/go-simplejson"
) )
type GoogleProvider struct { type GoogleProvider struct {
@ -35,28 +35,34 @@ func NewGoogleProvider(p *ProviderData) *GoogleProvider {
return &GoogleProvider{ProviderData: p} return &GoogleProvider{ProviderData: p}
} }
func (s *GoogleProvider) GetEmailAddress(auth_response *simplejson.Json, func (s *GoogleProvider) GetEmailAddress(body []byte, access_token string) (string, error) {
unused_access_token string) (string, error) { var response struct {
idToken, err := auth_response.Get("id_token").String() IdToken string `json:"id_token"`
if err != nil { }
if err := json.Unmarshal(body, &response); err != nil {
return "", err return "", err
} }
// id_token is a base64 encode ID token payload // id_token is a base64 encode ID token payload
// https://developers.google.com/accounts/docs/OAuth2Login#obtainuserinfo // https://developers.google.com/accounts/docs/OAuth2Login#obtainuserinfo
jwt := strings.Split(idToken, ".") jwt := strings.Split(response.IdToken, ".")
b, err := jwtDecodeSegment(jwt[1]) b, err := jwtDecodeSegment(jwt[1])
if err != nil { if err != nil {
return "", err return "", err
} }
data, err := simplejson.NewJson(b)
var email struct {
Email string `json:"email"`
}
err = json.Unmarshal(b, &email)
if err != nil { if err != nil {
return "", err return "", err
} }
email, err := data.Get("email").String() if email.Email == "" {
if err != nil { return "", errors.New("missing email")
return "", err
} }
return email, nil return email.Email, nil
} }
func jwtDecodeSegment(seg string) ([]byte, error) { func jwtDecodeSegment(seg string) ([]byte, error) {

View File

@ -2,7 +2,7 @@ package providers
import ( import (
"encoding/base64" "encoding/base64"
"github.com/bitly/go-simplejson" "encoding/json"
"github.com/bmizerany/assert" "github.com/bmizerany/assert"
"net/url" "net/url"
"testing" "testing"
@ -68,39 +68,61 @@ func TestGoogleProviderOverrides(t *testing.T) {
func TestGoogleProviderGetEmailAddress(t *testing.T) { func TestGoogleProviderGetEmailAddress(t *testing.T) {
p := newGoogleProvider() p := newGoogleProvider()
j := simplejson.New() body, err := json.Marshal(
j.Set("id_token", "ignored prefix."+base64.URLEncoding.EncodeToString( struct {
[]byte("{\"email\": \"michael.bland@gsa.gov\"}"))) IdToken string `json:"id_token"`
email, err := p.GetEmailAddress(j, "ignored access_token") }{
IdToken: "ignored prefix." + base64.URLEncoding.EncodeToString([]byte(`{"email": "michael.bland@gsa.gov"}`)),
},
)
assert.Equal(t, nil, err)
email, err := p.GetEmailAddress(body, "ignored access_token")
assert.Equal(t, "michael.bland@gsa.gov", email) assert.Equal(t, "michael.bland@gsa.gov", email)
assert.Equal(t, nil, err) assert.Equal(t, nil, err)
} }
func TestGoogleProviderGetEmailAddressInvalidEncoding(t *testing.T) { func TestGoogleProviderGetEmailAddressInvalidEncoding(t *testing.T) {
p := newGoogleProvider() p := newGoogleProvider()
j := simplejson.New() body, err := json.Marshal(
j.Set("id_token", "ignored prefix.{\"email\": \"michael.bland@gsa.gov\"}") struct {
email, err := p.GetEmailAddress(j, "ignored access_token") IdToken string `json:"id_token"`
}{
IdToken: "ignored prefix." + `{"email": "michael.bland@gsa.gov"}`,
},
)
assert.Equal(t, nil, err)
email, err := p.GetEmailAddress(body, "ignored access_token")
assert.Equal(t, "", email) assert.Equal(t, "", email)
assert.NotEqual(t, nil, err) assert.NotEqual(t, nil, err)
} }
func TestGoogleProviderGetEmailAddressInvalidJson(t *testing.T) { func TestGoogleProviderGetEmailAddressInvalidJson(t *testing.T) {
p := newGoogleProvider() p := newGoogleProvider()
j := simplejson.New()
j.Set("id_token", "ignored prefix."+base64.URLEncoding.EncodeToString( body, err := json.Marshal(
[]byte("{email: michael.bland@gsa.gov}"))) struct {
email, err := p.GetEmailAddress(j, "ignored access_token") IdToken string `json:"id_token"`
}{
IdToken: "ignored prefix." + base64.URLEncoding.EncodeToString([]byte(`{"email": michael.bland@gsa.gov}`)),
},
)
assert.Equal(t, nil, err)
email, err := p.GetEmailAddress(body, "ignored access_token")
assert.Equal(t, "", email) assert.Equal(t, "", email)
assert.NotEqual(t, nil, err) assert.NotEqual(t, nil, err)
} }
func TestGoogleProviderGetEmailAddressEmailMissing(t *testing.T) { func TestGoogleProviderGetEmailAddressEmailMissing(t *testing.T) {
p := newGoogleProvider() p := newGoogleProvider()
j := simplejson.New() body, err := json.Marshal(
j.Set("id_token", "ignored prefix."+base64.URLEncoding.EncodeToString( struct {
[]byte("{\"not_email\": \"missing!\"}"))) IdToken string `json:"id_token"`
email, err := p.GetEmailAddress(j, "ignored access_token") }{
IdToken: "ignored prefix." + base64.URLEncoding.EncodeToString([]byte(`{"not_email": "missing"}`)),
},
)
assert.Equal(t, nil, err)
email, err := p.GetEmailAddress(body, "ignored access_token")
assert.Equal(t, "", email) assert.Equal(t, "", email)
assert.NotEqual(t, nil, err) assert.NotEqual(t, nil, err)
} }

View File

@ -1,7 +1,6 @@
package providers package providers
import ( import (
"github.com/bitly/go-simplejson"
"github.com/bmizerany/assert" "github.com/bmizerany/assert"
"net/http" "net/http"
"net/http/httptest" "net/http/httptest"
@ -13,9 +12,7 @@ type ValidateTokenTestProvider struct {
*ProviderData *ProviderData
} }
func (tp *ValidateTokenTestProvider) GetEmailAddress( func (tp *ValidateTokenTestProvider) GetEmailAddress(body []byte, access_token string) (string, error) {
unused_auth_response *simplejson.Json,
unused_access_token string) (string, error) {
return "", nil return "", nil
} }

View File

@ -8,7 +8,6 @@ import (
"net/http" "net/http"
"net/url" "net/url"
"github.com/bitly/go-simplejson"
"github.com/bitly/google_auth_proxy/api" "github.com/bitly/google_auth_proxy/api"
) )
@ -50,8 +49,7 @@ func getLinkedInHeader(access_token string) http.Header {
return header return header
} }
func (p *LinkedInProvider) GetEmailAddress(unused_auth_response *simplejson.Json, func (p *LinkedInProvider) GetEmailAddress(body []byte, access_token string) (string, error) {
access_token string) (string, error) {
if access_token == "" { if access_token == "" {
return "", errors.New("missing access token") return "", errors.New("missing access token")
} }

View File

@ -1,7 +1,6 @@
package providers package providers
import ( import (
"github.com/bitly/go-simplejson"
"github.com/bmizerany/assert" "github.com/bmizerany/assert"
"net/http" "net/http"
"net/http/httptest" "net/http/httptest"
@ -97,9 +96,8 @@ func TestLinkedInProviderGetEmailAddress(t *testing.T) {
b_url, _ := url.Parse(b.URL) b_url, _ := url.Parse(b.URL)
p := testLinkedInProvider(b_url.Host) p := testLinkedInProvider(b_url.Host)
unused_auth_response := simplejson.New()
email, err := p.GetEmailAddress(unused_auth_response, email, err := p.GetEmailAddress([]byte{},
"imaginary_access_token") "imaginary_access_token")
assert.Equal(t, nil, err) assert.Equal(t, nil, err)
assert.Equal(t, "user@linkedin.com", email) assert.Equal(t, "user@linkedin.com", email)
@ -111,13 +109,11 @@ func TestLinkedInProviderGetEmailAddressFailedRequest(t *testing.T) {
b_url, _ := url.Parse(b.URL) b_url, _ := url.Parse(b.URL)
p := testLinkedInProvider(b_url.Host) p := testLinkedInProvider(b_url.Host)
unused_auth_response := simplejson.New()
// We'll trigger a request failure by using an unexpected access // We'll trigger a request failure by using an unexpected access
// token. Alternatively, we could allow the parsing of the payload as // token. Alternatively, we could allow the parsing of the payload as
// JSON to fail. // JSON to fail.
email, err := p.GetEmailAddress(unused_auth_response, email, err := p.GetEmailAddress([]byte{}, "unexpected_access_token")
"unexpected_access_token")
assert.NotEqual(t, nil, err) assert.NotEqual(t, nil, err)
assert.Equal(t, "", email) assert.Equal(t, "", email)
} }
@ -128,10 +124,8 @@ func TestLinkedInProviderGetEmailAddressEmailNotPresentInPayload(t *testing.T) {
b_url, _ := url.Parse(b.URL) b_url, _ := url.Parse(b.URL)
p := testLinkedInProvider(b_url.Host) p := testLinkedInProvider(b_url.Host)
unused_auth_response := simplejson.New()
email, err := p.GetEmailAddress(unused_auth_response, email, err := p.GetEmailAddress([]byte{}, "imaginary_access_token")
"imaginary_access_token")
assert.NotEqual(t, nil, err) assert.NotEqual(t, nil, err)
assert.Equal(t, "", email) assert.Equal(t, "", email)
} }

View File

@ -5,7 +5,6 @@ import (
"net/http" "net/http"
"net/url" "net/url"
"github.com/bitly/go-simplejson"
"github.com/bitly/google_auth_proxy/api" "github.com/bitly/google_auth_proxy/api"
) )
@ -43,8 +42,7 @@ func NewMyUsaProvider(p *ProviderData) *MyUsaProvider {
return &MyUsaProvider{ProviderData: p} return &MyUsaProvider{ProviderData: p}
} }
func (p *MyUsaProvider) GetEmailAddress(auth_response *simplejson.Json, func (p *MyUsaProvider) GetEmailAddress(body []byte, access_token string) (string, error) {
access_token string) (string, error) {
req, err := http.NewRequest("GET", req, err := http.NewRequest("GET",
p.ProfileUrl.String()+"?access_token="+access_token, nil) p.ProfileUrl.String()+"?access_token="+access_token, nil)
if err != nil { if err != nil {

View File

@ -1,7 +1,6 @@
package providers package providers
import ( import (
"github.com/bitly/go-simplejson"
"github.com/bmizerany/assert" "github.com/bmizerany/assert"
"net/http" "net/http"
"net/http/httptest" "net/http/httptest"
@ -102,10 +101,8 @@ func TestMyUsaProviderGetEmailAddress(t *testing.T) {
b_url, _ := url.Parse(b.URL) b_url, _ := url.Parse(b.URL)
p := testMyUsaProvider(b_url.Host) p := testMyUsaProvider(b_url.Host)
unused_auth_response := simplejson.New()
email, err := p.GetEmailAddress(unused_auth_response, email, err := p.GetEmailAddress([]byte{}, "imaginary_access_token")
"imaginary_access_token")
assert.Equal(t, nil, err) assert.Equal(t, nil, err)
assert.Equal(t, "michael.bland@gsa.gov", email) assert.Equal(t, "michael.bland@gsa.gov", email)
} }
@ -118,13 +115,11 @@ func TestMyUsaProviderGetEmailAddressFailedRequest(t *testing.T) {
b_url, _ := url.Parse(b.URL) b_url, _ := url.Parse(b.URL)
p := testMyUsaProvider(b_url.Host) p := testMyUsaProvider(b_url.Host)
unused_auth_response := simplejson.New()
// We'll trigger a request failure by using an unexpected access // We'll trigger a request failure by using an unexpected access
// token. Alternatively, we could allow the parsing of the payload as // token. Alternatively, we could allow the parsing of the payload as
// JSON to fail. // JSON to fail.
email, err := p.GetEmailAddress(unused_auth_response, email, err := p.GetEmailAddress([]byte{}, "unexpected_access_token")
"unexpected_access_token")
assert.NotEqual(t, nil, err) assert.NotEqual(t, nil, err)
assert.Equal(t, "", email) assert.Equal(t, "", email)
} }
@ -135,10 +130,8 @@ func TestMyUsaProviderGetEmailAddressEmailNotPresentInPayload(t *testing.T) {
b_url, _ := url.Parse(b.URL) b_url, _ := url.Parse(b.URL)
p := testMyUsaProvider(b_url.Host) p := testMyUsaProvider(b_url.Host)
unused_auth_response := simplejson.New()
email, err := p.GetEmailAddress(unused_auth_response, email, err := p.GetEmailAddress([]byte{}, "imaginary_access_token")
"imaginary_access_token")
assert.NotEqual(t, nil, err) assert.NotEqual(t, nil, err)
assert.Equal(t, "", email) assert.Equal(t, "", email)
} }

View File

@ -6,6 +6,8 @@ import (
type ProviderData struct { type ProviderData struct {
ProviderName string ProviderName string
ClientID string
ClientSecret string
LoginUrl *url.URL LoginUrl *url.URL
RedeemUrl *url.URL RedeemUrl *url.URL
ProfileUrl *url.URL ProfileUrl *url.URL

View File

@ -0,0 +1,51 @@
package providers
import (
"bytes"
"encoding/json"
"errors"
"io/ioutil"
"net/http"
"net/url"
)
func (p *ProviderData) Redeem(redirectUrl, code string) (body []byte, token string, err error) {
if code == "" {
err = errors.New("missing code")
return
}
params := url.Values{}
params.Add("redirect_uri", redirectUrl)
params.Add("client_id", p.ClientID)
params.Add("client_secret", p.ClientSecret)
params.Add("code", code)
params.Add("grant_type", "authorization_code")
req, err := http.NewRequest("POST", p.RedeemUrl.String(), bytes.NewBufferString(params.Encode()))
if err != nil {
return nil, "", err
}
req.Header.Set("Content-Type", "application/x-www-form-urlencoded")
resp, err := http.DefaultClient.Do(req)
if err != nil {
return nil, "", err
}
body, err = ioutil.ReadAll(resp.Body)
resp.Body.Close()
if err != nil {
return nil, "", err
}
// blindly try json and x-www-form-urlencoded
var jsonResponse struct {
AccessToken string `json:"access_token"`
}
err = json.Unmarshal(body, &jsonResponse)
if err == nil {
return body, jsonResponse.AccessToken, nil
}
v, err := url.ParseQuery(string(body))
return body, v.Get("access_token"), err
}

View File

@ -1,13 +1,9 @@
package providers package providers
import (
"github.com/bitly/go-simplejson"
)
type Provider interface { type Provider interface {
Data() *ProviderData Data() *ProviderData
GetEmailAddress(auth_response *simplejson.Json, GetEmailAddress(body []byte, access_token string) (string, error)
access_token string) (string, error) Redeem(string, string) ([]byte, string, error)
ValidateToken(access_token string) bool ValidateToken(access_token string) bool
} }
@ -17,6 +13,8 @@ func New(provider string, p *ProviderData) Provider {
return NewMyUsaProvider(p) return NewMyUsaProvider(p)
case "linkedin": case "linkedin":
return NewLinkedInProvider(p) return NewLinkedInProvider(p)
case "github":
return NewGitHubProvider(p)
default: default:
return NewGoogleProvider(p) return NewGoogleProvider(p)
} }