From 37b38dd2f44023d25687b8cf9ec139f6d85ce012 Mon Sep 17 00:00:00 2001 From: Jehiah Czebotar Date: Wed, 20 May 2015 23:23:48 -0400 Subject: [PATCH] Github provider --- README.md | 6 ++ api/api.go | 3 +- main.go | 3 + oauthproxy.go | 54 ++++--------- oauthproxy_test.go | 20 ++--- options.go | 9 ++- providers/github.go | 136 ++++++++++++++++++++++++++++++++ providers/google.go | 30 ++++--- providers/google_test.go | 54 +++++++++---- providers/internal_util_test.go | 5 +- providers/linkedin.go | 4 +- providers/linkedin_test.go | 12 +-- providers/myusa.go | 4 +- providers/myusa_test.go | 13 +-- providers/provider_data.go | 2 + providers/provider_default.go | 51 ++++++++++++ providers/providers.go | 10 +-- 17 files changed, 304 insertions(+), 112 deletions(-) create mode 100644 providers/github.go create mode 100644 providers/provider_default.go diff --git a/README.md b/README.md index 2c05885..0fc96c6 100644 --- a/README.md +++ b/README.md @@ -101,6 +101,8 @@ Usage of google_auth_proxy: -version=false: print version string ``` +See below for provider specific options + ### Environment variables The environment variables `GOOGLE_AUTH_PROXY_CLIENT_ID`, `GOOGLE_AUTH_PROXY_CLIENT_SECRET`, `GOOGLE_AUTH_PROXY_COOKIE_SECRET`, `GOOGLE_AUTH_PROXY_COOKIE_DOMAIN` and `GOOGLE_AUTH_PROXY_COOKIE_EXPIRE` can be used in place of the corresponding command-line arguments. @@ -173,6 +175,10 @@ directive. Right now this includes: * `myusa` - The [MyUSA](https://alpha.my.usa.gov) authentication service ([GitHub](https://github.com/18F/myusa)) * `linkedin` - The [LinkedIn](https://developer.linkedin.com/docs/signin-with-linkedin) Sign In service. +* `github` - Via [Github][https://github.com/settings/developers] OAuth App. Also supports restricting via org and team. + + -github-org="": restrict logins to members of this organisation + -github-team="": restrict logins to members of this team ## Adding a new Provider diff --git a/api/api.go b/api/api.go index d2de0ce..19e75e9 100644 --- a/api/api.go +++ b/api/api.go @@ -10,8 +10,7 @@ import ( ) func Request(req *http.Request) (*simplejson.Json, error) { - httpclient := &http.Client{} - resp, err := httpclient.Do(req) + resp, err := http.DefaultClient.Do(req) if err != nil { return nil, err } diff --git a/main.go b/main.go index 9880342..ca6d0c7 100644 --- a/main.go +++ b/main.go @@ -17,6 +17,7 @@ import ( ) func main() { + log.SetFlags(log.Ldate | log.Ltime | log.Lshortfile) flagSet := flag.NewFlagSet("google_auth_proxy", flag.ExitOnError) googleAppsDomains := StringArray{} @@ -35,6 +36,8 @@ func main() { flagSet.Var(&skipAuthRegex, "skip-auth-regex", "bypass authentication for requests path's that match (may be given multiple times)") flagSet.Var(&googleAppsDomains, "google-apps-domain", "authenticate against the given Google apps domain (may be given multiple times)") + flagSet.String("github-org", "", "restrict logins to members of this organisation") + flagSet.String("github-team", "", "restrict logins to members of this team") flagSet.String("client-id", "", "the Google OAuth Client ID: ie: \"123456.apps.googleusercontent.com\"") flagSet.String("client-secret", "", "the OAuth Client Secret") flagSet.String("authenticated-emails-file", "", "authenticate against emails via file (one per line)") diff --git a/oauthproxy.go b/oauthproxy.go index 7738c35..3ecdc98 100644 --- a/oauthproxy.go +++ b/oauthproxy.go @@ -1,7 +1,6 @@ package main import ( - "bytes" "crypto/aes" "crypto/cipher" "encoding/base64" @@ -17,7 +16,6 @@ import ( "strings" "time" - "github.com/bitly/google_auth_proxy/api" "github.com/bitly/google_auth_proxy/providers" ) @@ -39,7 +37,6 @@ type OauthProxy struct { redirectUrl *url.URL // the url to receive requests at provider providers.Provider - oauthRedemptionUrl *url.URL // endpoint to redeem the code oauthLoginUrl *url.URL // to redirect the user to oauthValidateUrl *url.URL // to validate the access token oauthScope string @@ -143,21 +140,20 @@ func NewOauthProxy(opts *Options, validator func(string) bool) *OauthProxy { CookieRefresh: opts.CookieRefresh, Validator: validator, - clientID: opts.ClientID, - clientSecret: opts.ClientSecret, - oauthScope: opts.provider.Data().Scope, - provider: opts.provider, - oauthRedemptionUrl: opts.provider.Data().RedeemUrl, - oauthLoginUrl: opts.provider.Data().LoginUrl, - oauthValidateUrl: opts.provider.Data().ValidateUrl, - serveMux: serveMux, - redirectUrl: redirectUrl, - skipAuthRegex: opts.SkipAuthRegex, - compiledRegex: opts.CompiledRegex, - PassBasicAuth: opts.PassBasicAuth, - PassAccessToken: opts.PassAccessToken, - AesCipher: aes_cipher, - templates: loadTemplates(opts.CustomTemplatesDir), + clientID: opts.ClientID, + clientSecret: opts.ClientSecret, + oauthScope: opts.provider.Data().Scope, + provider: opts.provider, + oauthLoginUrl: opts.provider.Data().LoginUrl, + oauthValidateUrl: opts.provider.Data().ValidateUrl, + serveMux: serveMux, + redirectUrl: redirectUrl, + skipAuthRegex: opts.SkipAuthRegex, + compiledRegex: opts.CompiledRegex, + PassBasicAuth: opts.PassBasicAuth, + PassAccessToken: opts.PassAccessToken, + AesCipher: aes_cipher, + templates: loadTemplates(opts.CustomTemplatesDir), } } @@ -200,29 +196,13 @@ func (p *OauthProxy) redeemCode(host, code string) (string, string, error) { if code == "" { return "", "", errors.New("missing code") } - params := url.Values{} - params.Add("redirect_uri", p.GetRedirectUrl(host)) - params.Add("client_id", p.clientID) - params.Add("client_secret", p.clientSecret) - params.Add("code", code) - params.Add("grant_type", "authorization_code") - req, err := http.NewRequest("POST", p.oauthRedemptionUrl.String(), bytes.NewBufferString(params.Encode())) - if err != nil { - log.Printf("failed building request %s", err.Error()) - return "", "", err - } - req.Header.Set("Content-Type", "application/x-www-form-urlencoded") - json, err := api.Request(req) - if err != nil { - log.Printf("failed making request %s", err) - return "", "", err - } - access_token, err := json.Get("access_token").String() + redirectUri := p.GetRedirectUrl(host) + body, access_token, err := p.provider.Redeem(redirectUri, code) if err != nil { return "", "", err } - email, err := p.provider.GetEmailAddress(json, access_token) + email, err := p.provider.GetEmailAddress(body, access_token) if err != nil { return "", "", err } diff --git a/oauthproxy_test.go b/oauthproxy_test.go index 409f1d5..5dcb2ec 100644 --- a/oauthproxy_test.go +++ b/oauthproxy_test.go @@ -1,10 +1,10 @@ package main import ( - "github.com/bitly/go-simplejson" "github.com/bitly/google_auth_proxy/providers" "github.com/bmizerany/assert" "io/ioutil" + "log" "net" "net/http" "net/http/httptest" @@ -15,6 +15,11 @@ import ( "time" ) +func init() { + log.SetFlags(log.Ldate | log.Ltime | log.Lshortfile) + +} + func TestNewReverseProxy(t *testing.T) { backend := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { w.WriteHeader(200) @@ -89,8 +94,7 @@ type TestProvider struct { ValidToken bool } -func (tp *TestProvider) GetEmailAddress(unused_auth_response *simplejson.Json, - unused_access_token string) (string, error) { +func (tp *TestProvider) GetEmailAddress(body []byte, access_token string) (string, error) { return tp.EmailAddress, nil } @@ -113,16 +117,15 @@ func NewPassAccessTokenTest(opts PassAccessTokenTestOptions) *PassAccessTokenTes t.provider_server = httptest.NewServer( http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + log.Printf("%#v", r) url := r.URL payload := "" switch url.Path { case "/oauth/token": payload = `{"access_token": "my_auth_token"}` default: - token_header := r.Header["X-Forwarded-Access-Token"] - if len(token_header) != 0 { - payload = token_header[0] - } else { + payload = r.Header.Get("X-Forwarded-Access-Token") + if payload == "" { payload = "No access token found." } } @@ -189,8 +192,7 @@ func (pat_test *PassAccessTokenTest) getCallbackEndpoint() (http_code int, return rw.Code, rw.HeaderMap["Set-Cookie"][0] } -func (pat_test *PassAccessTokenTest) getRootEndpoint( - cookie string) (http_code int, access_token string) { +func (pat_test *PassAccessTokenTest) getRootEndpoint(cookie string) (http_code int, access_token string) { cookie_key := pat_test.proxy.CookieKey var value string key_prefix := cookie_key + "=" diff --git a/options.go b/options.go index 3767546..fcf725d 100644 --- a/options.go +++ b/options.go @@ -19,6 +19,8 @@ type Options struct { AuthenticatedEmailsFile string `flag:"authenticated-emails-file" cfg:"authenticated_emails_file"` GoogleAppsDomains []string `flag:"google-apps-domain" cfg:"google_apps_domains"` + GitHubOrg string `flag:"github-org" cfg:"github_org"` + GitHubTeam string `flag:"github-team" cfg:"github_team"` HtpasswdFile string `flag:"htpasswd-file" cfg:"htpasswd_file"` DisplayHtpasswdForm bool `flag:"display-htpasswd-form" cfg:"display_htpasswd_form"` CustomTemplatesDir string `flag:"custom-templates-dir" cfg:"custom_templates_dir"` @@ -153,11 +155,16 @@ func (o *Options) Validate() error { } func parseProviderInfo(o *Options, msgs []string) []string { - p := &providers.ProviderData{Scope: o.Scope} + p := &providers.ProviderData{Scope: o.Scope, ClientID: o.ClientID, ClientSecret: o.ClientSecret} p.LoginUrl, msgs = parseUrl(o.LoginUrl, "login", msgs) p.RedeemUrl, msgs = parseUrl(o.RedeemUrl, "redeem", msgs) p.ProfileUrl, msgs = parseUrl(o.ProfileUrl, "profile", msgs) p.ValidateUrl, msgs = parseUrl(o.ValidateUrl, "validate", msgs) + o.provider = providers.New(o.Provider, p) + switch p := o.provider.(type) { + case *providers.GitHubProvider: + p.SetOrgTeam(o.GitHubOrg, o.GitHubTeam) + } return msgs } diff --git a/providers/github.go b/providers/github.go new file mode 100644 index 0000000..4aaa2fe --- /dev/null +++ b/providers/github.go @@ -0,0 +1,136 @@ +package providers + +import ( + "encoding/json" + "io/ioutil" + "net/http" + "net/url" +) + +type GitHubProvider struct { + *ProviderData + Org string + Team string +} + +func NewGitHubProvider(p *ProviderData) *GitHubProvider { + p.ProviderName = "GitHub" + if p.LoginUrl.String() == "" { + p.LoginUrl = &url.URL{ + Scheme: "https", + Host: "github.com", + Path: "/login/oauth/authorize", + } + } + if p.RedeemUrl.String() == "" { + p.RedeemUrl = &url.URL{ + Scheme: "https", + Host: "github.com", + Path: "/login/oauth/access_token", + } + } + if p.ValidateUrl.String() == "" { + p.ValidateUrl = &url.URL{ + Scheme: "https", + Host: "api.github.com", + Path: "/user/emails", + } + } + if p.Scope == "" { + p.Scope = "user:email" + } + return &GitHubProvider{ProviderData: p} +} +func (p *GitHubProvider) SetOrgTeam(org, team string) { + p.Org = org + p.Team = team + if org != "" || team != "" { + p.Scope += " read:org" + } +} + +func (p *GitHubProvider) hasOrgAndTeam(accessToken string) (bool, error) { + + var teams []struct { + Name string `json:"name"` + Slug string `json:"slug"` + Org struct { + Login string `json:"login"` + } `json:"organization"` + } + + params := url.Values{ + "access_token": {accessToken}, + } + + req, _ := http.NewRequest("GET", "https://api.github.com/user/teams?"+params.Encode(), nil) + req.Header.Set("Accept", "application/vnd.github.moondragon+json") + resp, err := http.DefaultClient.Do(req) + if err != nil { + return false, err + } + + body, err := ioutil.ReadAll(resp.Body) + resp.Body.Close() + if err != nil { + return false, err + } + + if err := json.Unmarshal(body, &teams); err != nil { + return false, err + } + + for _, team := range teams { + if p.Org == team.Org.Login { + if p.Team == "" || p.Team == team.Slug { + return true, nil + } + } + } + return false, nil +} + +func (p *GitHubProvider) GetEmailAddress(body []byte, access_token string) (string, error) { + + var emails []struct { + Email string `json:"email"` + Primary bool `json:"primary"` + } + + params := url.Values{ + "access_token": {access_token}, + } + + // if we require an Org or Team, check that first + if p.Org != "" || p.Team != "" { + if ok, err := p.hasOrgAndTeam(access_token); err != nil || !ok { + return "", err + } + } + + resp, err := http.DefaultClient.Get("https://api.github.com/user/emails?" + params.Encode()) + if err != nil { + return "", err + } + body, err = ioutil.ReadAll(resp.Body) + resp.Body.Close() + if err != nil { + return "", err + } + + if err := json.Unmarshal(body, &emails); err != nil { + return "", err + } + + for _, email := range emails { + if email.Primary { + return email.Email, nil + } + } + + return "", nil +} + +func (p *GitHubProvider) ValidateToken(access_token string) bool { + return validateToken(p, access_token, nil) +} diff --git a/providers/google.go b/providers/google.go index aa162d9..265eaa6 100644 --- a/providers/google.go +++ b/providers/google.go @@ -2,10 +2,10 @@ package providers import ( "encoding/base64" + "encoding/json" + "errors" "net/url" "strings" - - "github.com/bitly/go-simplejson" ) type GoogleProvider struct { @@ -35,28 +35,34 @@ func NewGoogleProvider(p *ProviderData) *GoogleProvider { return &GoogleProvider{ProviderData: p} } -func (s *GoogleProvider) GetEmailAddress(auth_response *simplejson.Json, - unused_access_token string) (string, error) { - idToken, err := auth_response.Get("id_token").String() - if err != nil { +func (s *GoogleProvider) GetEmailAddress(body []byte, access_token string) (string, error) { + var response struct { + IdToken string `json:"id_token"` + } + + if err := json.Unmarshal(body, &response); err != nil { return "", err } + // id_token is a base64 encode ID token payload // https://developers.google.com/accounts/docs/OAuth2Login#obtainuserinfo - jwt := strings.Split(idToken, ".") + jwt := strings.Split(response.IdToken, ".") b, err := jwtDecodeSegment(jwt[1]) if err != nil { return "", err } - data, err := simplejson.NewJson(b) + + var email struct { + Email string `json:"email"` + } + err = json.Unmarshal(b, &email) if err != nil { return "", err } - email, err := data.Get("email").String() - if err != nil { - return "", err + if email.Email == "" { + return "", errors.New("missing email") } - return email, nil + return email.Email, nil } func jwtDecodeSegment(seg string) ([]byte, error) { diff --git a/providers/google_test.go b/providers/google_test.go index 532199c..7456d3b 100644 --- a/providers/google_test.go +++ b/providers/google_test.go @@ -2,7 +2,7 @@ package providers import ( "encoding/base64" - "github.com/bitly/go-simplejson" + "encoding/json" "github.com/bmizerany/assert" "net/url" "testing" @@ -68,39 +68,61 @@ func TestGoogleProviderOverrides(t *testing.T) { func TestGoogleProviderGetEmailAddress(t *testing.T) { p := newGoogleProvider() - j := simplejson.New() - j.Set("id_token", "ignored prefix."+base64.URLEncoding.EncodeToString( - []byte("{\"email\": \"michael.bland@gsa.gov\"}"))) - email, err := p.GetEmailAddress(j, "ignored access_token") + body, err := json.Marshal( + struct { + IdToken string `json:"id_token"` + }{ + IdToken: "ignored prefix." + base64.URLEncoding.EncodeToString([]byte(`{"email": "michael.bland@gsa.gov"}`)), + }, + ) + assert.Equal(t, nil, err) + email, err := p.GetEmailAddress(body, "ignored access_token") assert.Equal(t, "michael.bland@gsa.gov", email) assert.Equal(t, nil, err) } func TestGoogleProviderGetEmailAddressInvalidEncoding(t *testing.T) { p := newGoogleProvider() - j := simplejson.New() - j.Set("id_token", "ignored prefix.{\"email\": \"michael.bland@gsa.gov\"}") - email, err := p.GetEmailAddress(j, "ignored access_token") + body, err := json.Marshal( + struct { + IdToken string `json:"id_token"` + }{ + IdToken: "ignored prefix." + `{"email": "michael.bland@gsa.gov"}`, + }, + ) + assert.Equal(t, nil, err) + email, err := p.GetEmailAddress(body, "ignored access_token") assert.Equal(t, "", email) assert.NotEqual(t, nil, err) } func TestGoogleProviderGetEmailAddressInvalidJson(t *testing.T) { p := newGoogleProvider() - j := simplejson.New() - j.Set("id_token", "ignored prefix."+base64.URLEncoding.EncodeToString( - []byte("{email: michael.bland@gsa.gov}"))) - email, err := p.GetEmailAddress(j, "ignored access_token") + + body, err := json.Marshal( + struct { + IdToken string `json:"id_token"` + }{ + IdToken: "ignored prefix." + base64.URLEncoding.EncodeToString([]byte(`{"email": michael.bland@gsa.gov}`)), + }, + ) + assert.Equal(t, nil, err) + email, err := p.GetEmailAddress(body, "ignored access_token") assert.Equal(t, "", email) assert.NotEqual(t, nil, err) } func TestGoogleProviderGetEmailAddressEmailMissing(t *testing.T) { p := newGoogleProvider() - j := simplejson.New() - j.Set("id_token", "ignored prefix."+base64.URLEncoding.EncodeToString( - []byte("{\"not_email\": \"missing!\"}"))) - email, err := p.GetEmailAddress(j, "ignored access_token") + body, err := json.Marshal( + struct { + IdToken string `json:"id_token"` + }{ + IdToken: "ignored prefix." + base64.URLEncoding.EncodeToString([]byte(`{"not_email": "missing"}`)), + }, + ) + assert.Equal(t, nil, err) + email, err := p.GetEmailAddress(body, "ignored access_token") assert.Equal(t, "", email) assert.NotEqual(t, nil, err) } diff --git a/providers/internal_util_test.go b/providers/internal_util_test.go index 39b8ab3..36a1d37 100644 --- a/providers/internal_util_test.go +++ b/providers/internal_util_test.go @@ -1,7 +1,6 @@ package providers import ( - "github.com/bitly/go-simplejson" "github.com/bmizerany/assert" "net/http" "net/http/httptest" @@ -13,9 +12,7 @@ type ValidateTokenTestProvider struct { *ProviderData } -func (tp *ValidateTokenTestProvider) GetEmailAddress( - unused_auth_response *simplejson.Json, - unused_access_token string) (string, error) { +func (tp *ValidateTokenTestProvider) GetEmailAddress(body []byte, access_token string) (string, error) { return "", nil } diff --git a/providers/linkedin.go b/providers/linkedin.go index 4eea3f7..ae43d0a 100644 --- a/providers/linkedin.go +++ b/providers/linkedin.go @@ -8,7 +8,6 @@ import ( "net/http" "net/url" - "github.com/bitly/go-simplejson" "github.com/bitly/google_auth_proxy/api" ) @@ -50,8 +49,7 @@ func getLinkedInHeader(access_token string) http.Header { return header } -func (p *LinkedInProvider) GetEmailAddress(unused_auth_response *simplejson.Json, - access_token string) (string, error) { +func (p *LinkedInProvider) GetEmailAddress(body []byte, access_token string) (string, error) { if access_token == "" { return "", errors.New("missing access token") } diff --git a/providers/linkedin_test.go b/providers/linkedin_test.go index be8d05e..08b3e47 100644 --- a/providers/linkedin_test.go +++ b/providers/linkedin_test.go @@ -1,7 +1,6 @@ package providers import ( - "github.com/bitly/go-simplejson" "github.com/bmizerany/assert" "net/http" "net/http/httptest" @@ -97,9 +96,8 @@ func TestLinkedInProviderGetEmailAddress(t *testing.T) { b_url, _ := url.Parse(b.URL) p := testLinkedInProvider(b_url.Host) - unused_auth_response := simplejson.New() - email, err := p.GetEmailAddress(unused_auth_response, + email, err := p.GetEmailAddress([]byte{}, "imaginary_access_token") assert.Equal(t, nil, err) assert.Equal(t, "user@linkedin.com", email) @@ -111,13 +109,11 @@ func TestLinkedInProviderGetEmailAddressFailedRequest(t *testing.T) { b_url, _ := url.Parse(b.URL) p := testLinkedInProvider(b_url.Host) - unused_auth_response := simplejson.New() // We'll trigger a request failure by using an unexpected access // token. Alternatively, we could allow the parsing of the payload as // JSON to fail. - email, err := p.GetEmailAddress(unused_auth_response, - "unexpected_access_token") + email, err := p.GetEmailAddress([]byte{}, "unexpected_access_token") assert.NotEqual(t, nil, err) assert.Equal(t, "", email) } @@ -128,10 +124,8 @@ func TestLinkedInProviderGetEmailAddressEmailNotPresentInPayload(t *testing.T) { b_url, _ := url.Parse(b.URL) p := testLinkedInProvider(b_url.Host) - unused_auth_response := simplejson.New() - email, err := p.GetEmailAddress(unused_auth_response, - "imaginary_access_token") + email, err := p.GetEmailAddress([]byte{}, "imaginary_access_token") assert.NotEqual(t, nil, err) assert.Equal(t, "", email) } diff --git a/providers/myusa.go b/providers/myusa.go index 05ba33f..83fc37f 100644 --- a/providers/myusa.go +++ b/providers/myusa.go @@ -5,7 +5,6 @@ import ( "net/http" "net/url" - "github.com/bitly/go-simplejson" "github.com/bitly/google_auth_proxy/api" ) @@ -43,8 +42,7 @@ func NewMyUsaProvider(p *ProviderData) *MyUsaProvider { return &MyUsaProvider{ProviderData: p} } -func (p *MyUsaProvider) GetEmailAddress(auth_response *simplejson.Json, - access_token string) (string, error) { +func (p *MyUsaProvider) GetEmailAddress(body []byte, access_token string) (string, error) { req, err := http.NewRequest("GET", p.ProfileUrl.String()+"?access_token="+access_token, nil) if err != nil { diff --git a/providers/myusa_test.go b/providers/myusa_test.go index 20df092..32e8520 100644 --- a/providers/myusa_test.go +++ b/providers/myusa_test.go @@ -1,7 +1,6 @@ package providers import ( - "github.com/bitly/go-simplejson" "github.com/bmizerany/assert" "net/http" "net/http/httptest" @@ -102,10 +101,8 @@ func TestMyUsaProviderGetEmailAddress(t *testing.T) { b_url, _ := url.Parse(b.URL) p := testMyUsaProvider(b_url.Host) - unused_auth_response := simplejson.New() - email, err := p.GetEmailAddress(unused_auth_response, - "imaginary_access_token") + email, err := p.GetEmailAddress([]byte{}, "imaginary_access_token") assert.Equal(t, nil, err) assert.Equal(t, "michael.bland@gsa.gov", email) } @@ -118,13 +115,11 @@ func TestMyUsaProviderGetEmailAddressFailedRequest(t *testing.T) { b_url, _ := url.Parse(b.URL) p := testMyUsaProvider(b_url.Host) - unused_auth_response := simplejson.New() // We'll trigger a request failure by using an unexpected access // token. Alternatively, we could allow the parsing of the payload as // JSON to fail. - email, err := p.GetEmailAddress(unused_auth_response, - "unexpected_access_token") + email, err := p.GetEmailAddress([]byte{}, "unexpected_access_token") assert.NotEqual(t, nil, err) assert.Equal(t, "", email) } @@ -135,10 +130,8 @@ func TestMyUsaProviderGetEmailAddressEmailNotPresentInPayload(t *testing.T) { b_url, _ := url.Parse(b.URL) p := testMyUsaProvider(b_url.Host) - unused_auth_response := simplejson.New() - email, err := p.GetEmailAddress(unused_auth_response, - "imaginary_access_token") + email, err := p.GetEmailAddress([]byte{}, "imaginary_access_token") assert.NotEqual(t, nil, err) assert.Equal(t, "", email) } diff --git a/providers/provider_data.go b/providers/provider_data.go index 097f065..40cda04 100644 --- a/providers/provider_data.go +++ b/providers/provider_data.go @@ -6,6 +6,8 @@ import ( type ProviderData struct { ProviderName string + ClientID string + ClientSecret string LoginUrl *url.URL RedeemUrl *url.URL ProfileUrl *url.URL diff --git a/providers/provider_default.go b/providers/provider_default.go new file mode 100644 index 0000000..d962fd9 --- /dev/null +++ b/providers/provider_default.go @@ -0,0 +1,51 @@ +package providers + +import ( + "bytes" + "encoding/json" + "errors" + "io/ioutil" + "net/http" + "net/url" +) + +func (p *ProviderData) Redeem(redirectUrl, code string) (body []byte, token string, err error) { + if code == "" { + err = errors.New("missing code") + return + } + + params := url.Values{} + params.Add("redirect_uri", redirectUrl) + params.Add("client_id", p.ClientID) + params.Add("client_secret", p.ClientSecret) + params.Add("code", code) + params.Add("grant_type", "authorization_code") + req, err := http.NewRequest("POST", p.RedeemUrl.String(), bytes.NewBufferString(params.Encode())) + if err != nil { + return nil, "", err + } + req.Header.Set("Content-Type", "application/x-www-form-urlencoded") + + resp, err := http.DefaultClient.Do(req) + if err != nil { + return nil, "", err + } + body, err = ioutil.ReadAll(resp.Body) + resp.Body.Close() + if err != nil { + return nil, "", err + } + + // blindly try json and x-www-form-urlencoded + var jsonResponse struct { + AccessToken string `json:"access_token"` + } + err = json.Unmarshal(body, &jsonResponse) + if err == nil { + return body, jsonResponse.AccessToken, nil + } + + v, err := url.ParseQuery(string(body)) + return body, v.Get("access_token"), err +} diff --git a/providers/providers.go b/providers/providers.go index 8f5091a..6c7f592 100644 --- a/providers/providers.go +++ b/providers/providers.go @@ -1,13 +1,9 @@ package providers -import ( - "github.com/bitly/go-simplejson" -) - type Provider interface { Data() *ProviderData - GetEmailAddress(auth_response *simplejson.Json, - access_token string) (string, error) + GetEmailAddress(body []byte, access_token string) (string, error) + Redeem(string, string) ([]byte, string, error) ValidateToken(access_token string) bool } @@ -17,6 +13,8 @@ func New(provider string, p *ProviderData) Provider { return NewMyUsaProvider(p) case "linkedin": return NewLinkedInProvider(p) + case "github": + return NewGitHubProvider(p) default: return NewGoogleProvider(p) }