diff --git a/providers/oidc.go b/providers/oidc.go index 364879c..36b87f6 100644 --- a/providers/oidc.go +++ b/providers/oidc.go @@ -38,7 +38,61 @@ func (p *OIDCProvider) Redeem(redirectURL, code string) (s *SessionState, err er if err != nil { return nil, fmt.Errorf("token exchange: %v", err) } + s, err = p.createSessionState(ctx, token) + if err != nil { + return nil, fmt.Errorf("unable to update session: %v", err) + } + return +} +// RefreshSessionIfNeeded checks if the session has expired and uses the +// RefreshToken to fetch a new ID token if required +func (p *OIDCProvider) RefreshSessionIfNeeded(s *SessionState) (bool, error) { + if s == nil || s.ExpiresOn.After(time.Now()) || s.RefreshToken == "" { + return false, nil + } + + origExpiration := s.ExpiresOn + + err := p.redeemRefreshToken(s) + if err != nil { + return false, fmt.Errorf("unable to redeem refresh token: %v", err) + } + + fmt.Printf("refreshed id token %s (expired on %s)\n", s, origExpiration) + return true, nil +} + +func (p *OIDCProvider) redeemRefreshToken(s *SessionState) (err error) { + c := oauth2.Config{ + ClientID: p.ClientID, + ClientSecret: p.ClientSecret, + Endpoint: oauth2.Endpoint{ + TokenURL: p.RedeemURL.String(), + }, + } + ctx := context.Background() + t := &oauth2.Token{ + RefreshToken: s.RefreshToken, + Expiry: time.Now().Add(-time.Hour), + } + token, err := c.TokenSource(ctx, t).Token() + if err != nil { + return fmt.Errorf("failed to get token: %v", err) + } + newSession, err := p.createSessionState(ctx, token) + if err != nil { + return fmt.Errorf("unable to update session: %v", err) + } + s.AccessToken = newSession.AccessToken + s.IDToken = newSession.IDToken + s.RefreshToken = newSession.RefreshToken + s.ExpiresOn = newSession.ExpiresOn + s.Email = newSession.Email + return +} + +func (p *OIDCProvider) createSessionState(ctx context.Context, token *oauth2.Token) (*SessionState, error) { rawIDToken, ok := token.Extra("id_token").(string) if !ok { return nil, fmt.Errorf("token response did not contain an id_token") @@ -66,29 +120,11 @@ func (p *OIDCProvider) Redeem(redirectURL, code string) (s *SessionState, err er return nil, fmt.Errorf("email in id_token (%s) isn't verified", claims.Email) } - s = &SessionState{ + return &SessionState{ AccessToken: token.AccessToken, IDToken: rawIDToken, RefreshToken: token.RefreshToken, ExpiresOn: token.Expiry, Email: claims.Email, - } - - return -} - -// RefreshSessionIfNeeded checks if the session has expired and uses the -// RefreshToken to fetch a new ID token if required -// -// WARNGING: This implementation is broken and does not check with the upstream -// OIDC provider before refreshing the session -func (p *OIDCProvider) RefreshSessionIfNeeded(s *SessionState) (bool, error) { - if s == nil || s.ExpiresOn.After(time.Now()) || s.RefreshToken == "" { - return false, nil - } - - origExpiration := s.ExpiresOn - s.ExpiresOn = time.Now().Add(time.Second).Truncate(time.Second) - fmt.Printf("refreshed access token %s (expired on %s)\n", s, origExpiration) - return false, nil + }, nil }