// go-bouquins/bouquins/auth.go

package bouquins
import (
	crand "crypto/rand"
	"fmt"
	"log"
	"math/rand"
	"net/http"

	"github.com/gorilla/sessions"
	"golang.org/x/oauth2"
)
// Constants shared by the authentication handlers.
const (
	// alphanums is the alphabet securedRandString draws its characters from.
	alphanums = "0123456789" +
		"ABCDEFGHIJKLMNOPQRSTUVWXYZ" +
		"abcdefghijklmnopqrstuvwxyz"

	// sessionName is the session name passed to app.Cookies.Get.
	sessionName = "bouquins"

	// Keys of the values stored in that session.
	sessionOAuthState    = "oauthState" // random state of the pending OAuth flow
	sessionOAuthProvider = "provider"   // name of the provider of the pending OAuth flow
	sessionUser          = "username"   // display name of the logged-in user

	// pProvider is the query parameter LoginPage reads the provider name from.
	pProvider = "provider"
)
var (
2017-09-08 18:41:30 +00:00
// Providers contains OAuth2 providers implementations
Providers []OAuth2Provider
)
// LoginModel is the model of the page on which the user chooses an OAuth2
// provider (rendered by LoginPage when no usable provider was requested).
type LoginModel struct {
	// Model is the common page model, built by app.NewModel.
	Model
	// Providers are the OAuth2 providers offered to the user.
	Providers []OAuth2Provider
}
// NewLoginModel constructor for LoginModel
func (app *Bouquins) NewLoginModel(req *http.Request) *LoginModel {
2019-09-08 09:23:22 +00:00
// TODO filter configured providers
2017-09-09 07:16:46 +00:00
return &LoginModel{*app.NewModel("Authentification", "provider", req), Providers}
}
// OAuth2Provider allows to get a user from an OAuth2 token
2017-09-08 18:29:01 +00:00
type OAuth2Provider interface {
2019-09-08 09:23:22 +00:00
GetUser(app *Bouquins, token *oauth2.Token) (string, error)
2017-09-09 11:27:07 +00:00
Config(conf *Conf) *oauth2.Config
2017-09-08 18:29:01 +00:00
Name() string
2017-09-09 07:16:46 +00:00
Label() string
Icon() string
2017-09-08 18:29:01 +00:00
}
// findProvider returns the first element of Providers whose Name equals name,
// or nil when there is none.
func findProvider(name string) OAuth2Provider {
	for i := 0; i < len(Providers); i++ {
		candidate := Providers[i]
		if candidate.Name() != name {
			continue
		}
		return candidate
	}
	return nil
}
// securedRandString returns a 16-character random string drawn from
// alphanums. It is used as the OAuth2 state value, which must be
// unpredictable for the state check to protect the login against CSRF.
//
// Characters come from crypto/rand. math/rand is not a cryptographically
// secure generator, and before Go 1.20 its top-level functions start from a
// fixed seed unless rand.Seed is called, which made the state predictable.
// Random bytes that would bias the modulo reduction are discarded, so every
// character of alphanums is equally likely.
//
// If the system random source fails, the error is logged and the remaining
// characters come from math/rand so login keeps working; a state produced
// that way is NOT unpredictable.
func securedRandString() string {
	const length = 16
	n := len(alphanums)
	// Largest multiple of n that fits in a byte's range: bytes at or above it
	// are rejected to avoid modulo bias.
	limit := 256 - 256%n
	out := make([]byte, 0, length)
	buf := make([]byte, length)
	for len(out) < length {
		if _, err := crand.Read(buf); err != nil {
			log.Println("crypto/rand failed, falling back to math/rand for random string:", err)
			for len(out) < length {
				out = append(out, alphanums[rand.Intn(n)])
			}
			return string(out)
		}
		for _, c := range buf {
			if int(c) >= limit {
				continue
			}
			out = append(out, alphanums[int(c)%n])
			if len(out) == length {
				break
			}
		}
	}
	return string(out)
}
// Session returns the session named sessionName for req, as obtained from
// app.Cookies.
func (app *Bouquins) Session(req *http.Request) *sessions.Session {
	// The Get error is deliberately discarded (best effort). NOTE(review):
	// gorilla/sessions stores normally return a usable new session alongside a
	// decode error, which callers here rely on — confirm for the configured store.
	current, _ := app.Cookies.Get(req, sessionName)
	return current
}
// Username returns the name of the logged-in user stored in the session, or
// "" when none is stored. The stored value is expected to be a string.
func (app *Bouquins) Username(req *http.Request) string {
	value, present := app.Session(req).Values[sessionUser]
	if !present || value == nil {
		return ""
	}
	return value.(string)
}
// SessionSet stores value under name in the current session and saves the
// session through res.
//
// Save failures used to be silently dropped, which lost the stored value (a
// pending OAuth state, the logged-in user) with no trace. They are now logged;
// the signature has no error result, so callers are unchanged.
func (app *Bouquins) SessionSet(name string, value string, res http.ResponseWriter, req *http.Request) {
	session := app.Session(req)
	session.Values[name] = value
	if err := session.Save(req, res); err != nil {
		log.Printf("saving session value %q: %v", name, err)
	}
}
// LoginPage redirects to OAuth login page (github)
func (app *Bouquins) LoginPage(res http.ResponseWriter, req *http.Request) error {
2017-09-08 18:29:01 +00:00
provider := req.URL.Query().Get(pProvider)
oauth := app.OAuthConf[provider]
if oauth != nil {
app.SessionSet(sessionOAuthProvider, provider, res, req)
state := securedRandString()
app.SessionSet(sessionOAuthState, state, res, req)
url := oauth.AuthCodeURL(state)
2017-09-09 09:06:04 +00:00
log.Println("OAuth redirect", url)
2017-09-08 18:29:01 +00:00
http.Redirect(res, req, url, http.StatusTemporaryRedirect)
return nil
}
// choose provider
2017-09-09 07:16:46 +00:00
return app.render(res, tplProvider, app.NewLoginModel(req))
}
// LogoutPage logs the current user out by storing an empty username in the
// session, then hands the response over to RedirectHome.
func (app *Bouquins) LogoutPage(res http.ResponseWriter, req *http.Request) error {
	const noUser = ""
	app.SessionSet(sessionUser, noUser, res, req)
	return RedirectHome(res, req)
}
// CallbackPage handle OAuth 2 callback
func (app *Bouquins) CallbackPage(res http.ResponseWriter, req *http.Request) error {
savedState := app.Session(req).Values[sessionOAuthState]
2017-09-08 18:29:01 +00:00
providerParam := app.Session(req).Values[sessionOAuthProvider]
if savedState == "" || providerParam == "" {
return fmt.Errorf("missing oauth data")
}
providerName := providerParam.(string)
oauth := app.OAuthConf[providerName]
provider := findProvider(providerName)
if oauth == nil || provider == nil {
return fmt.Errorf("missing oauth configuration")
}
app.SessionSet(sessionOAuthState, "", res, req)
2017-09-08 18:29:01 +00:00
app.SessionSet(sessionOAuthProvider, "", res, req)
state := req.FormValue("state")
if state != savedState {
return fmt.Errorf("invalid oauth state, expected '%s', got '%s'", "state", state)
}
code := req.FormValue("code")
2017-09-08 18:29:01 +00:00
token, err := oauth.Exchange(oauth2.NoContext, code)
if err != nil {
return fmt.Errorf("Code exchange failed with '%s'", err)
}
2019-09-08 09:23:22 +00:00
userEmail, err := provider.GetUser(app, token)
if err != nil {
2017-09-08 18:29:01 +00:00
return err
}
2017-09-09 15:12:37 +00:00
user, err := Account(userEmail)
if err != nil {
log.Println("Error loading user", err)
return fmt.Errorf("Unknown user")
}
2017-09-09 15:12:37 +00:00
app.SessionSet(sessionUser, user.DisplayName, res, req)
log.Println("User logged in", user.DisplayName)
return RedirectHome(res, req)
}