Merge pull request #7 from pusher/migration

Migration from Bitly to Pusher
This commit is contained in:
Joel Speed 2019-01-14 09:54:05 +00:00 committed by GitHub
commit e1f45dd941
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
55 changed files with 958 additions and 444 deletions

3
.github/CODEOWNERS vendored Normal file
View File

@ -0,0 +1,3 @@
# Default owner should be a Pusher cloud-team member unless overridden by later
# rules in this file
* @pusher/cloud-team

37
.github/ISSUE_TEMPLATE.md vendored Normal file
View File

@ -0,0 +1,37 @@
<!--- Provide a general summary of the issue in the Title above -->
## Expected Behavior
<!--- If you're describing a bug, tell us what should happen -->
<!--- If you're suggesting a change/improvement, tell us how it should work -->
## Current Behavior
<!--- If describing a bug, tell us what happens instead of the expected behavior -->
<!--- If suggesting a change/improvement, explain the difference from current behavior -->
## Possible Solution
<!--- Not obligatory, but suggest a fix/reason for the bug, -->
<!--- or ideas how to implement the addition or change -->
## Steps to Reproduce (for bugs)
<!--- Provide a link to a live example, or an unambiguous set of steps to -->
<!--- reproduce this bug. Include code to reproduce, if relevant -->
1. <!--- Step 1 --->
2. <!--- Step 2 --->
3. <!--- Step 3 --->
4. <!--- Step 4 --->
## Context
<!--- How has this issue affected you? What are you trying to accomplish? -->
<!--- Providing context helps us come up with a solution that is most useful in the real world -->
## Your Environment
<!--- Include as many relevant details about the environment you experienced the bug in -->
- Version used:

25
.github/PULL_REQUEST_TEMPLATE.md vendored Normal file
View File

@ -0,0 +1,25 @@
<!--- Provide a general summary of your changes in the Title above -->
## Description
<!--- Describe your changes in detail -->
## Motivation and Context
<!--- Why is this change required? What problem does it solve? -->
<!--- If it fixes an open issue, please link to the issue here. -->
## How Has This Been Tested?
<!--- Please describe in detail how you tested your changes. -->
<!--- Include details of your testing environment, and the tests you ran to -->
<!--- see how your change affects other areas of the code, etc. -->
## Checklist:
<!--- Go over all the following points, and put an `x` in all the boxes that apply. -->
<!--- If you're unsure about any of these, don't hesitate to ask. We're here to help! -->
- [ ] My change requires a change to the documentation or CHANGELOG.
- [ ] I have updated the documentation/CHANGELOG accordingly.
- [ ] I have created a feature (non-master) branch for my PR.

2
.gitignore vendored
View File

@ -3,7 +3,7 @@ vendor
dist dist
.godeps .godeps
*.exe *.exe
.env
# Go.gitignore # Go.gitignore
# Compiled Object files, Static and Dynamic libs (Shared Objects) # Compiled Object files, Static and Dynamic libs (Shared Objects)

View File

@ -1,12 +1,16 @@
language: go language: go
go: go:
- 1.8.x
- 1.9.x - 1.9.x
script: - 1.10.x
install:
# Fetch dependencies
- wget -O dep https://github.com/golang/dep/releases/download/v0.3.2/dep-linux-amd64 - wget -O dep https://github.com/golang/dep/releases/download/v0.3.2/dep-linux-amd64
- chmod +x dep - chmod +x dep
- ./dep ensure - mv dep $GOPATH/bin/dep
- ./test.sh script:
- ./configure
# Run tests
- make test
sudo: false sudo: false
notifications: notifications:
email: false email: false

23
CHANGELOG.md Normal file
View File

@ -0,0 +1,23 @@
# Vx.x.x (Pre-release)
## Changes since v2.2:
- Move automated build to debian base image
- Add Makefile
- Update CI to run `make test`
- Update Dockerfile to use `make clean oauth2_proxy`
- Update `VERSION` parameter to be set by `ldflags` from Git Status
- Remove lint and test scripts
- Remove Go v1.8.x from Travis CI testing
- Add CODEOWNERS file
- Add CONTRIBUTING guide
- Add Issue and Pull Request templates
- Add Dockerfile
- Fix fsnotify import
- Update README to reflect new repository ownership
- Update CI scripts to separate linting and testing
- Now using `gometalinter` for linting
- Move Go import path from `github.com/bitly/oauth2_proxy` to `github.com/pusher/oauth2_proxy`
- Repository forked on 27/11/18
- README updated to include note that this repository is forked
- CHANGLOG created to track changes to repository from original fork

22
CONTRIBUTING.md Normal file
View File

@ -0,0 +1,22 @@
# Contributing
To develop on this project, please fork the repo and clone into your `$GOPATH`.
Dependencies are **not** checked in so please download those separately.
Download the dependencies using [`dep`](https://github.com/golang/dep).
```bash
cd $GOPATH/src/github.com # Create this directory if it doesn't exist
git clone git@github.com:<YOUR_FORK>/oauth2_proxy pusher/oauth2_proxy
make dep
```
## Pull Requests and Issues
We track bugs and issues using Github.
If you find a bug, please open an Issue.
If you want to fix a bug, please fork, create a feature branch, fix the bug and
open a PR back to this repo.
Please mention the open bug issue number within your PR if applicable.

16
Dockerfile Normal file
View File

@ -0,0 +1,16 @@
FROM golang:1.10 AS builder
WORKDIR /go/src/github.com/pusher/oauth2_proxy
COPY . .
# Fetch dependencies
RUN go get -u github.com/golang/dep/cmd/dep
RUN dep ensure --vendor-only
# Build image
RUN ./configure && make clean oauth2_proxy
# Copy binary to debian
FROM debian:stretch
COPY --from=builder /go/src/github.com/pusher/oauth2_proxy/oauth2_proxy /bin/oauth2_proxy
ENTRYPOINT ["/bin/oauth2_proxy"]

67
Gopkg.lock generated
View File

@ -2,118 +2,149 @@
[[projects]] [[projects]]
digest = "1:b24249f5a5e6fbe1eddc94b25973172339ccabeadef4779274f3ed0167c18812"
name = "cloud.google.com/go" name = "cloud.google.com/go"
packages = ["compute/metadata"] packages = ["compute/metadata"]
pruneopts = ""
revision = "2d3a6656c17a60b0815b7e06ab0be04eacb6e613" revision = "2d3a6656c17a60b0815b7e06ab0be04eacb6e613"
version = "v0.16.0" version = "v0.16.0"
[[projects]] [[projects]]
digest = "1:289dd4d7abfb3ad2b5f728fbe9b1d5c1bf7d265a3eb9ef92869af1f7baba4c7a"
name = "github.com/BurntSushi/toml" name = "github.com/BurntSushi/toml"
packages = ["."] packages = ["."]
pruneopts = ""
revision = "b26d9c308763d68093482582cea63d69be07a0f0" revision = "b26d9c308763d68093482582cea63d69be07a0f0"
version = "v0.3.0" version = "v0.3.0"
[[projects]] [[projects]]
digest = "1:512883404c2a99156e410e9880e3bb35ecccc0c07c1159eb204b5f3ef3c431b3"
name = "github.com/bitly/go-simplejson" name = "github.com/bitly/go-simplejson"
packages = ["."] packages = ["."]
pruneopts = ""
revision = "aabad6e819789e569bd6aabf444c935aa9ba1e44" revision = "aabad6e819789e569bd6aabf444c935aa9ba1e44"
version = "v0.5.0" version = "v0.5.0"
[[projects]] [[projects]]
branch = "v2" branch = "v2"
digest = "1:e5a238f8fa890e529d7e493849bbae8988c9e70344e4630cc4f9a11b00516afb"
name = "github.com/coreos/go-oidc" name = "github.com/coreos/go-oidc"
packages = ["."] packages = ["."]
pruneopts = ""
revision = "77e7f2010a464ade7338597afe650dfcffbe2ca8" revision = "77e7f2010a464ade7338597afe650dfcffbe2ca8"
[[projects]] [[projects]]
digest = "1:56c130d885a4aacae1dd9c7b71cfe39912c7ebc1ff7d2b46083c8812996dc43b"
name = "github.com/davecgh/go-spew" name = "github.com/davecgh/go-spew"
packages = ["spew"] packages = ["spew"]
pruneopts = ""
revision = "346938d642f2ec3594ed81d874461961cd0faa76" revision = "346938d642f2ec3594ed81d874461961cd0faa76"
version = "v1.1.0" version = "v1.1.0"
[[projects]] [[projects]]
branch = "master" branch = "master"
digest = "1:3b760d3b93f994df8eb1d9ebfad17d3e9e37edcb7f7efaa15b427c0d7a64f4e4"
name = "github.com/golang/protobuf" name = "github.com/golang/protobuf"
packages = ["proto"] packages = ["proto"]
pruneopts = ""
revision = "1e59b77b52bf8e4b449a57e6f79f21226d571845" revision = "1e59b77b52bf8e4b449a57e6f79f21226d571845"
[[projects]] [[projects]]
digest = "1:af67386ca553c04c6222f7b5b2f17bc97a5dfb3b81b706882c7fd8c72c30cf8f"
name = "github.com/mbland/hmacauth" name = "github.com/mbland/hmacauth"
packages = ["."] packages = ["."]
pruneopts = ""
revision = "107c17adcc5eccc9935cd67d9bc2feaf5255d2cb" revision = "107c17adcc5eccc9935cd67d9bc2feaf5255d2cb"
version = "1.0.2" version = "1.0.2"
[[projects]] [[projects]]
branch = "master" branch = "master"
digest = "1:9408fb9c637c103010e5147469c232ce6b68edc840879cc730a2a15918e6cae8"
name = "github.com/mreiferson/go-options" name = "github.com/mreiferson/go-options"
packages = ["."] packages = ["."]
pruneopts = ""
revision = "77551d20752b54535462404ad9d877ebdb26e53d" revision = "77551d20752b54535462404ad9d877ebdb26e53d"
[[projects]] [[projects]]
digest = "1:256484dbbcd271f9ecebc6795b2df8cad4c458dd0f5fd82a8c2fa0c29f233411"
name = "github.com/pmezard/go-difflib" name = "github.com/pmezard/go-difflib"
packages = ["difflib"] packages = ["difflib"]
pruneopts = ""
revision = "792786c7400a136282c1664665ae0a8db921c6c2" revision = "792786c7400a136282c1664665ae0a8db921c6c2"
version = "v1.0.0" version = "v1.0.0"
[[projects]] [[projects]]
branch = "master" branch = "master"
digest = "1:386e12afcfd8964907c92dffd106860c0dedd71dbefae14397b77b724a13343b"
name = "github.com/pquerna/cachecontrol" name = "github.com/pquerna/cachecontrol"
packages = [ packages = [
".", ".",
"cacheobject" "cacheobject",
] ]
pruneopts = ""
revision = "0dec1b30a0215bb68605dfc568e8855066c9202d" revision = "0dec1b30a0215bb68605dfc568e8855066c9202d"
[[projects]] [[projects]]
digest = "1:3926a4ec9a4ff1a072458451aa2d9b98acd059a45b38f7335d31e06c3d6a0159"
name = "github.com/stretchr/testify" name = "github.com/stretchr/testify"
packages = ["assert"] packages = ["assert"]
pruneopts = ""
revision = "69483b4bd14f5845b5a1e55bca19e954e827f1d0" revision = "69483b4bd14f5845b5a1e55bca19e954e827f1d0"
version = "v1.1.4" version = "v1.1.4"
[[projects]] [[projects]]
branch = "master" branch = "master"
digest = "1:f6a006d27619a4d93bf9b66fe1999b8c8d1fa62bdc63af14f10fbe6fcaa2aa1a"
name = "golang.org/x/crypto" name = "golang.org/x/crypto"
packages = [ packages = [
"bcrypt", "bcrypt",
"blowfish", "blowfish",
"ed25519", "ed25519",
"ed25519/internal/edwards25519" "ed25519/internal/edwards25519",
] ]
pruneopts = ""
revision = "9f005a07e0d31d45e6656d241bb5c0f2efd4bc94" revision = "9f005a07e0d31d45e6656d241bb5c0f2efd4bc94"
[[projects]] [[projects]]
branch = "master" branch = "master"
digest = "1:130b1bec86c62e121967ee0c69d9c263dc2d3ffe6c7c9a82aca4071c4d068861"
name = "golang.org/x/net" name = "golang.org/x/net"
packages = [ packages = [
"context", "context",
"context/ctxhttp" "context/ctxhttp",
] ]
pruneopts = ""
revision = "9dfe39835686865bff950a07b394c12a98ddc811" revision = "9dfe39835686865bff950a07b394c12a98ddc811"
[[projects]] [[projects]]
branch = "master" branch = "master"
digest = "1:4a61176e8386727e4847b21a5a2625ce56b9c518bc543a28226503e701265db0"
name = "golang.org/x/oauth2" name = "golang.org/x/oauth2"
packages = [ packages = [
".", ".",
"google", "google",
"internal", "internal",
"jws", "jws",
"jwt" "jwt",
] ]
pruneopts = ""
revision = "9ff8ebcc8e241d46f52ecc5bff0e5a2f2dbef402" revision = "9ff8ebcc8e241d46f52ecc5bff0e5a2f2dbef402"
[[projects]] [[projects]]
branch = "master" branch = "master"
digest = "1:dc1fb726dbbe79c86369941eae1e3b431b8fc6f11dbd37f7899dc758a43cc3ed"
name = "google.golang.org/api" name = "google.golang.org/api"
packages = [ packages = [
"admin/directory/v1", "admin/directory/v1",
"gensupport", "gensupport",
"googleapi", "googleapi",
"googleapi/internal/uritemplates" "googleapi/internal/uritemplates",
] ]
pruneopts = ""
revision = "8791354e7ab150705ede13637a18c1fcc16b62e8" revision = "8791354e7ab150705ede13637a18c1fcc16b62e8"
[[projects]] [[projects]]
digest = "1:934fb8966f303ede63aa405e2c8d7f0a427a05ea8df335dfdc1833dd4d40756f"
name = "google.golang.org/appengine" name = "google.golang.org/appengine"
packages = [ packages = [
".", ".",
@ -125,30 +156,48 @@
"internal/modules", "internal/modules",
"internal/remote_api", "internal/remote_api",
"internal/urlfetch", "internal/urlfetch",
"urlfetch" "urlfetch",
] ]
pruneopts = ""
revision = "150dc57a1b433e64154302bdc40b6bb8aefa313a" revision = "150dc57a1b433e64154302bdc40b6bb8aefa313a"
version = "v1.0.0" version = "v1.0.0"
[[projects]] [[projects]]
name = "gopkg.in/fsnotify.v1" digest = "1:cb5b2a45a3dd41c01ff779c54ae4c8aab0271d6d3b3f734c8a8bd2c890299ef2"
name = "gopkg.in/fsnotify/fsnotify.v1"
packages = ["."] packages = ["."]
pruneopts = ""
revision = "836bfd95fecc0f1511dd66bdbf2b5b61ab8b00b6" revision = "836bfd95fecc0f1511dd66bdbf2b5b61ab8b00b6"
version = "v1.2.11" version = "v1.2.11"
[[projects]] [[projects]]
digest = "1:be4ed0a2b15944dd777a663681a39260ed05f9c4e213017ed2e2255622c8820c"
name = "gopkg.in/square/go-jose.v2" name = "gopkg.in/square/go-jose.v2"
packages = [ packages = [
".", ".",
"cipher", "cipher",
"json" "json",
] ]
pruneopts = ""
revision = "f8f38de21b4dcd69d0413faf231983f5fd6634b1" revision = "f8f38de21b4dcd69d0413faf231983f5fd6634b1"
version = "v2.1.3" version = "v2.1.3"
[solve-meta] [solve-meta]
analyzer-name = "dep" analyzer-name = "dep"
analyzer-version = 1 analyzer-version = 1
inputs-digest = "b502c41a61115d14d6379be26b0300f65d173bdad852f0170d387ebf2d7ec173" input-imports = [
"github.com/BurntSushi/toml",
"github.com/bitly/go-simplejson",
"github.com/coreos/go-oidc",
"github.com/mbland/hmacauth",
"github.com/mreiferson/go-options",
"github.com/stretchr/testify/assert",
"golang.org/x/crypto/bcrypt",
"golang.org/x/oauth2",
"golang.org/x/oauth2/google",
"google.golang.org/api/admin/directory/v1",
"google.golang.org/api/googleapi",
"gopkg.in/fsnotify/fsnotify.v1",
]
solver-name = "gps-cdcl" solver-name = "gps-cdcl"
solver-version = 1 solver-version = 1

View File

@ -3,10 +3,6 @@
# for detailed Gopkg.toml documentation. # for detailed Gopkg.toml documentation.
# #
[[constraint]]
name = "github.com/18F/hmacauth"
version = "~1.0.1"
[[constraint]] [[constraint]]
name = "github.com/BurntSushi/toml" name = "github.com/BurntSushi/toml"
version = "~0.3.0" version = "~0.3.0"
@ -36,7 +32,7 @@
name = "google.golang.org/api" name = "google.golang.org/api"
[[constraint]] [[constraint]]
name = "gopkg.in/fsnotify.v1" name = "gopkg.in/fsnotify/fsnotify.v1"
version = "~1.2.0" version = "~1.2.0"
[[constraint]] [[constraint]]

56
Makefile Normal file
View File

@ -0,0 +1,56 @@
include .env
BINARY := oauth2_proxy
VERSION := $(shell git describe --always --long --dirty --tags 2>/dev/null || echo "undefined")
.NOTPARALLEL:
.PHONY: all
all: dep lint $(BINARY)
.PHONY: clean
clean:
rm -rf release
rm -f $(BINARY)
.PHONY: distclean
distclean: clean
rm -rf vendor
BIN_DIR := $(GOPATH)/bin
GOMETALINTER := $(BIN_DIR)/gometalinter
$(GOMETALINTER):
$(GO) get -u github.com/alecthomas/gometalinter
gometalinter --install %> /dev/null
.PHONY: lint
lint: $(GOMETALINTER)
$(GOMETALINTER) --vendor --disable-all \
--enable=vet \
--enable=vetshadow \
--enable=golint \
--enable=ineffassign \
--enable=goconst \
--enable=deadcode \
--enable=gofmt \
--enable=goimports \
--tests ./...
.PHONY: dep
dep:
$(DEP) ensure --vendor-only
.PHONY: build
build: clean $(BINARY)
$(BINARY):
$(GO) build -ldflags="-X main.VERSION=${VERSION}" -o $(BINARY) github.com/pusher/oauth2_proxy
.PHONY: test
test: dep lint
$(GO) test -v -race $(go list ./... | grep -v /vendor/)
.PHONY: release
release: lint test
mkdir release
GOOS=darwin GOARCH=amd64 go build -ldflags="-X main.VERSION=${VERSION}" -o release/$(BINARY)-darwin-amd64 github.com/pusher/oauth2_proxy
GOOS=linux GOARCH=amd64 go build -ldflags="-X main.VERSION=${VERSION}" -o release/$(BINARY)-linux-amd64 github.com/pusher/oauth2_proxy

View File

@ -1,11 +1,13 @@
oauth2_proxy # oauth2_proxy
=================
A reverse proxy and static file server that provides authentication using Providers (Google, GitHub, and others) A reverse proxy and static file server that provides authentication using Providers (Google, GitHub, and others)
to validate accounts by email, domain or group. to validate accounts by email, domain or group.
[![Build Status](https://secure.travis-ci.org/bitly/oauth2_proxy.svg?branch=master)](http://travis-ci.org/bitly/oauth2_proxy) **Note:** This repository was forked from [bitly/OAuth2_Proxy](https://github.com/bitly/oauth2_proxy) on 27/11/2018.
Versions v3.0.0 and up are from this fork and will have diverged from any changes in the original fork.
A list of changes can be seen in the [CHANGELOG](CHANGELOG.md).
[![Build Status](https://secure.travis-ci.org/pusher/oauth2_proxy.svg?branch=master)](http://travis-ci.org/pusher/oauth2_proxy)
![Sign In Page](https://cloud.githubusercontent.com/assets/45028/4970624/7feb7dd8-6886-11e4-93e0-c9904af44ea8.png) ![Sign In Page](https://cloud.githubusercontent.com/assets/45028/4970624/7feb7dd8-6886-11e4-93e0-c9904af44ea8.png)
@ -15,12 +17,21 @@ to validate accounts by email, domain or group.
## Installation ## Installation
1. Download [Prebuilt Binary](https://github.com/bitly/oauth2_proxy/releases) (current release is `v2.2`) or build with `$ go get github.com/bitly/oauth2_proxy` which will put the binary in `$GOROOT/bin` 1. Choose how to deploy:
a. Download [Prebuilt Binary](https://github.com/pusher/oauth2_proxy/releases) (current release is `v2.2`)
b. Build with `$ go get github.com/pusher/oauth2_proxy` which will put the binary in `$GOROOT/bin`
c. Using the prebuilt docker image [quay.io/pusher/oauth2_proxy](https://quay.io/pusher/oauth2_proxy)
Prebuilt binaries can be validated by extracting the file and verifying it against the `sha256sum.txt` checksum file provided for each release starting with version `v2.3`. Prebuilt binaries can be validated by extracting the file and verifying it against the `sha256sum.txt` checksum file provided for each release starting with version `v2.3`.
``` ```
sha256sum -c sha256sum.txt 2>&1 | grep OK sha256sum -c sha256sum.txt 2>&1 | grep OK
oauth2_proxy-2.3.linux-amd64: OK oauth2_proxy-2.3.linux-amd64: OK
``` ```
2. Select a Provider and Register an OAuth Application with a Provider 2. Select a Provider and Register an OAuth Application with a Provider
3. Configure OAuth2 Proxy using config file, command line options, or environment variables 3. Configure OAuth2 Proxy using config file, command line options, or environment variables
4. Configure SSL or Deploy behind a SSL endpoint (example provided for Nginx) 4. Configure SSL or Deploy behind a SSL endpoint (example provided for Nginx)
@ -31,12 +42,12 @@ You will need to register an OAuth application with a Provider (Google, GitHub o
Valid providers are : Valid providers are :
* [Google](#google-auth-provider) *default* - [Google](#google-auth-provider) _default_
* [Azure](#azure-auth-provider) - [Azure](#azure-auth-provider)
* [Facebook](#facebook-auth-provider) - [Facebook](#facebook-auth-provider)
* [GitHub](#github-auth-provider) - [GitHub](#github-auth-provider)
* [GitLab](#gitlab-auth-provider) - [GitLab](#gitlab-auth-provider)
* [LinkedIn](#linkedin-auth-provider) - [LinkedIn](#linkedin-auth-provider)
The provider can be selected using the `provider` configuration value. The provider can be selected using the `provider` configuration value.
@ -50,14 +61,14 @@ For Google, the registration steps are:
4. In the left Nav pane, choose **"Credentials"** 4. In the left Nav pane, choose **"Credentials"**
5. In the center pane, choose **"OAuth consent screen"** tab. Fill in **"Product name shown to users"** and hit save. 5. In the center pane, choose **"OAuth consent screen"** tab. Fill in **"Product name shown to users"** and hit save.
6. In the center pane, choose **"Credentials"** tab. 6. In the center pane, choose **"Credentials"** tab.
* Open the **"New credentials"** drop down - Open the **"New credentials"** drop down
* Choose **"OAuth client ID"** - Choose **"OAuth client ID"**
* Choose **"Web application"** - Choose **"Web application"**
* Application name is freeform, choose something appropriate - Application name is freeform, choose something appropriate
* Authorized JavaScript origins is your domain ex: `https://internal.yourcompany.com` - Authorized JavaScript origins is your domain ex: `https://internal.yourcompany.com`
* Authorized redirect URIs is the location of oauth2/callback ex: `https://internal.yourcompany.com/oauth2/callback` - Authorized redirect URIs is the location of oauth2/callback ex: `https://internal.yourcompany.com/oauth2/callback`
* Choose **"Create"** - Choose **"Create"**
4. Take note of the **Client ID** and **Client Secret** 7. Take note of the **Client ID** and **Client Secret**
It's recommended to refresh sessions on a short interval (1h) with `cookie-refresh` setting which validates that the account is still authorized. It's recommended to refresh sessions on a short interval (1h) with `cookie-refresh` setting which validates that the account is still authorized.
@ -68,15 +79,17 @@ It's recommended to refresh sessions on a short interval (1h) with `cookie-refre
3. Under "APIs & Auth", choose APIs. 3. Under "APIs & Auth", choose APIs.
4. Click on Admin SDK and then Enable API. 4. Click on Admin SDK and then Enable API.
5. Follow the steps on https://developers.google.com/admin-sdk/directory/v1/guides/delegation#delegate_domain-wide_authority_to_your_service_account and give the client id from step 2 the following oauth scopes: 5. Follow the steps on https://developers.google.com/admin-sdk/directory/v1/guides/delegation#delegate_domain-wide_authority_to_your_service_account and give the client id from step 2 the following oauth scopes:
``` ```
https://www.googleapis.com/auth/admin.directory.group.readonly https://www.googleapis.com/auth/admin.directory.group.readonly
https://www.googleapis.com/auth/admin.directory.user.readonly https://www.googleapis.com/auth/admin.directory.user.readonly
``` ```
6. Follow the steps on https://support.google.com/a/answer/60757 to enable Admin API access. 6. Follow the steps on https://support.google.com/a/answer/60757 to enable Admin API access.
7. Create or choose an existing administrative email address on the Gmail domain to assign to the ```google-admin-email``` flag. This email will be impersonated by this client to make calls to the Admin SDK. See the note on the link from step 5 for the reason why. 7. Create or choose an existing administrative email address on the Gmail domain to assign to the `google-admin-email` flag. This email will be impersonated by this client to make calls to the Admin SDK. See the note on the link from step 5 for the reason why.
8. Create or choose an existing email group and set that email to the ```google-group``` flag. You can pass multiple instances of this flag with different groups 8. Create or choose an existing email group and set that email to the `google-group` flag. You can pass multiple instances of this flag with different groups
and the user will be checked against all the provided groups. and the user will be checked against all the provided groups.
9. Lock down the permissions on the json file downloaded from step 1 so only oauth2_proxy is able to read the file and set the path to the file in the ```google-service-account-json``` flag. 9. Lock down the permissions on the json file downloaded from step 1 so only oauth2_proxy is able to read the file and set the path to the file in the `google-service-account-json` flag.
10. Restart oauth2_proxy. 10. Restart oauth2_proxy.
Note: The user is checked against the group members list on initial authentication and every time the token is refreshed ( about once an hour ). Note: The user is checked against the group members list on initial authentication and every time the token is refreshed ( about once an hour ).
@ -89,7 +102,6 @@ Note: The user is checked against the group members list on initial authenticati
The Azure AD auth provider uses `openid` as it default scope. It uses `https://graph.windows.net` as a default protected resource. It call to `https://graph.windows.net/me` to get the email address of the user that logs in. The Azure AD auth provider uses `openid` as it default scope. It uses `https://graph.windows.net` as a default protected resource. It call to `https://graph.windows.net/me` to get the email address of the user that logs in.
### Facebook Auth Provider ### Facebook Auth Provider
1. Create a new FB App from <https://developers.facebook.com/> 1. Create a new FB App from <https://developers.facebook.com/>
@ -121,15 +133,14 @@ If you are using self-hosted GitLab, make sure you set the following to the appr
-redeem-url="<your gitlab url>/oauth/token" -redeem-url="<your gitlab url>/oauth/token"
-validate-url="<your gitlab url>/api/v4/user" -validate-url="<your gitlab url>/api/v4/user"
### LinkedIn Auth Provider ### LinkedIn Auth Provider
For LinkedIn, the registration steps are: For LinkedIn, the registration steps are:
1. Create a new project: https://www.linkedin.com/secure/developer 1. Create a new project: https://www.linkedin.com/secure/developer
2. In the OAuth User Agreement section: 2. In the OAuth User Agreement section:
* In default scope, select r_basicprofile and r_emailaddress. - In default scope, select r_basicprofile and r_emailaddress.
* In "OAuth 2.0 Redirect URLs", enter `https://internal.yourcompany.com/oauth2/callback` - In "OAuth 2.0 Redirect URLs", enter `https://internal.yourcompany.com/oauth2/callback`
3. Fill in the remaining required fields and Save. 3. Fill in the remaining required fields and Save.
4. Take note of the **Consumer Key / API Key** and **Consumer Secret / Secret Key** 4. Take note of the **Consumer Key / API Key** and **Consumer Secret / Secret Key**
@ -253,7 +264,7 @@ The following environment variables can be used in place of the corresponding co
There are two recommended configurations. There are two recommended configurations.
1) Configure SSL Termination with OAuth2 Proxy by providing a `--tls-cert=/path/to/cert.pem` and `--tls-key=/path/to/cert.key`. 1. Configure SSL Termination with OAuth2 Proxy by providing a `--tls-cert=/path/to/cert.pem` and `--tls-key=/path/to/cert.key`.
The command line to run `oauth2_proxy` in this configuration would look like this: The command line to run `oauth2_proxy` in this configuration would look like this:
@ -270,8 +281,7 @@ The command line to run `oauth2_proxy` in this configuration would look like thi
--client-secret=... --client-secret=...
``` ```
2. Configure SSL Termination with [Nginx](http://nginx.org/) (example config below), Amazon ELB, Google Cloud Platform Load Balancing, or ....
2) Configure SSL Termination with [Nginx](http://nginx.org/) (example config below), Amazon ELB, Google Cloud Platform Load Balancing, or ....
Because `oauth2_proxy` listens on `127.0.0.1:4180` by default, to listen on all interfaces (needed when using an Because `oauth2_proxy` listens on `127.0.0.1:4180` by default, to listen on all interfaces (needed when using an
external load balancer like Amazon ELB or Google Platform Load Balancing) use `--http-address="0.0.0.0:4180"` or external load balancer like Amazon ELB or Google Platform Load Balancing) use `--http-address="0.0.0.0:4180"` or
@ -321,12 +331,12 @@ The command line to run `oauth2_proxy` in this configuration would look like thi
OAuth2 Proxy responds directly to the following endpoints. All other endpoints will be proxied upstream when authenticated. The `/oauth2` prefix can be changed with the `--proxy-prefix` config variable. OAuth2 Proxy responds directly to the following endpoints. All other endpoints will be proxied upstream when authenticated. The `/oauth2` prefix can be changed with the `--proxy-prefix` config variable.
* /robots.txt - returns a 200 OK response that disallows all User-agents from all paths; see [robotstxt.org](http://www.robotstxt.org/) for more info - /robots.txt - returns a 200 OK response that disallows all User-agents from all paths; see [robotstxt.org](http://www.robotstxt.org/) for more info
* /ping - returns an 200 OK response - /ping - returns an 200 OK response
* /oauth2/sign_in - the login page, which also doubles as a sign out page (it clears cookies) - /oauth2/sign_in - the login page, which also doubles as a sign out page (it clears cookies)
* /oauth2/start - a URL that will redirect to start the OAuth cycle - /oauth2/start - a URL that will redirect to start the OAuth cycle
* /oauth2/callback - the URL used at the end of the OAuth cycle. The oauth app will be configured with this as the callback url. - /oauth2/callback - the URL used at the end of the OAuth cycle. The oauth app will be configured with this as the callback url.
* /oauth2/auth - only returns a 202 Accepted response or a 401 Unauthorized response; for use with the [Nginx `auth_request` directive](#nginx-auth-request) - /oauth2/auth - only returns a 202 Accepted response or a 401 Unauthorized response; for use with the [Nginx `auth_request` directive](#nginx-auth-request)
## Request signatures ## Request signatures
@ -341,9 +351,9 @@ in `oauthproxy.go`](./oauthproxy.go).
For more information about HMAC request signature validation, read the For more information about HMAC request signature validation, read the
following: following:
* [Amazon Web Services: Signing and Authenticating REST - [Amazon Web Services: Signing and Authenticating REST
Requests](https://docs.aws.amazon.com/AmazonS3/latest/dev/RESTAuthentication.html) Requests](https://docs.aws.amazon.com/AmazonS3/latest/dev/RESTAuthentication.html)
* [rc3.org: Using HMAC to authenticate Web service - [rc3.org: Using HMAC to authenticate Web service
requests](http://rc3.org/2011/12/02/using-hmac-to-authenticate-web-service-requests/) requests](http://rc3.org/2011/12/02/using-hmac-to-authenticate-web-service-requests/)
## Logging Format ## Logging Format
@ -417,3 +427,7 @@ server {
} }
} }
``` ```
## Contributing
Please see our [Contributing](CONTRIBUTING.md) guidelines.

View File

@ -10,6 +10,7 @@ import (
"github.com/bitly/go-simplejson" "github.com/bitly/go-simplejson"
) )
// Request parses the request body into a simplejson.Json object
func Request(req *http.Request) (*simplejson.Json, error) { func Request(req *http.Request) (*simplejson.Json, error) {
resp, err := http.DefaultClient.Do(req) resp, err := http.DefaultClient.Do(req)
if err != nil { if err != nil {
@ -32,7 +33,8 @@ func Request(req *http.Request) (*simplejson.Json, error) {
return data, nil return data, nil
} }
func RequestJson(req *http.Request, v interface{}) error { // RequestJSON parses the request body into the given interface
func RequestJSON(req *http.Request, v interface{}) error {
resp, err := http.DefaultClient.Do(req) resp, err := http.DefaultClient.Do(req)
if err != nil { if err != nil {
log.Printf("%s %s %s", req.Method, req.URL, err) log.Printf("%s %s %s", req.Method, req.URL, err)
@ -50,6 +52,7 @@ func RequestJson(req *http.Request, v interface{}) error {
return json.Unmarshal(body, v) return json.Unmarshal(body, v)
} }
// RequestUnparsedResponse performs a GET and returns the raw response object
func RequestUnparsedResponse(url string, header http.Header) (resp *http.Response, err error) { func RequestUnparsedResponse(url string, header http.Header) (resp *http.Response, err error) {
req, err := http.NewRequest("GET", url, nil) req, err := http.NewRequest("GET", url, nil)
if err != nil { if err != nil {

View File

@ -1,20 +1,21 @@
package api package api
import ( import (
"github.com/bitly/go-simplejson"
"io/ioutil" "io/ioutil"
"net/http" "net/http"
"net/http/httptest" "net/http/httptest"
"strings" "strings"
"testing" "testing"
"github.com/bitly/go-simplejson"
"github.com/stretchr/testify/assert" "github.com/stretchr/testify/assert"
) )
func testBackend(response_code int, payload string) *httptest.Server { func testBackend(responseCode int, payload string) *httptest.Server {
return httptest.NewServer(http.HandlerFunc( return httptest.NewServer(http.HandlerFunc(
func(w http.ResponseWriter, r *http.Request) { func(w http.ResponseWriter, r *http.Request) {
w.WriteHeader(response_code) w.WriteHeader(responseCode)
w.Write([]byte(payload)) w.Write([]byte(payload))
})) }))
} }

137
configure vendored Executable file
View File

@ -0,0 +1,137 @@
#!/usr/bin/env bash
RED='\033[0;31m'
GREEN='\033[0;32m'
BLUE='\033[0;34m'
NC='\033[0m'
declare -A tools=()
declare -A desired=()
for arg in "$@"; do
case ${arg%%=*} in
"--with-go")
desired[go]="${arg##*=}"
;;
"--with-dep")
desired[dep]="${arg##*=}"
;;
"--help")
printf "${GREEN}$0${NC}\n"
printf " available options:\n"
printf " --with-dep=${BLUE}<path_to_dep_binary>${NC}\n"
printf " --with-go=${BLUE}<path_to_go_binary>${NC}\n"
exit 0
;;
*)
echo "Unknown option: $arg"
exit 2
;;
esac
done
vercomp () {
if [[ $1 == $2 ]]
then
return 0
fi
local IFS=.
local i ver1=($1) ver2=($2)
# fill empty fields in ver1 with zeros
for ((i=${#ver1[@]}; i<${#ver2[@]}; i++))
do
ver1[i]=0
done
for ((i=0; i<${#ver1[@]}; i++))
do
if [[ -z ${ver2[i]} ]]
then
# fill empty fields in ver2 with zeros
ver2[i]=0
fi
if ((10#${ver1[i]} > 10#${ver2[i]}))
then
return 1
fi
if ((10#${ver1[i]} < 10#${ver2[i]}))
then
return 2
fi
done
return 0
}
check_for() {
echo -n "Checking for $1... "
if ! [ -z "${desired[$1]}" ]; then
TOOL_PATH="${desired[$1]}"
else
TOOL_PATH=$(command -v $1)
fi
if ! [ -x "$TOOL_PATH" -a -f "$TOOL_PATH" ]; then
printf "${RED}not found${NC}\n"
cd -
exit 1
else
printf "${GREEN}found${NC}\n"
tools[$1]=$TOOL_PATH
fi
}
check_go_version() {
echo -n "Checking go version... "
GO_VERSION=$(${tools[go]} version | ${tools[awk]} '{where = match($0, /[0-9]\.[0-9]+\.[0-9]*/); if (where != 0) print substr($0, RSTART, RLENGTH)}')
vercomp $GO_VERSION 1.9
case $? in
0) ;&
1)
printf "${GREEN}"
echo $GO_VERSION
printf "${NC}"
;;
2)
printf "${RED}"
echo "$GO_VERSION < 1.9"
exit 1
;;
esac
}
check_docker_version() {
echo -n "Checking docker version... "
DOCKER_VERSION=$(${tools[docker]} version | ${tools[awk]})
}
check_go_env() {
echo -n "Checking \$GOPATH... "
if [ -z "$GOPATH" ]; then
printf "${RED}invalid${NC} - GOPATH not set\n"
exit 1
fi
printf "${GREEN}valid${NC} - $GOPATH\n"
}
cd ${0%/*}
if [ ! -f .env ]; then
rm .env
fi
check_for make
check_for awk
check_for go
check_go_version
check_go_env
check_for dep
echo
cat <<- EOF > .env
MAKE := "${tools[make]}"
GO := "${tools[go]}"
DEP := "${tools[dep]}"
EOF
echo "Environment configuration written to .env"
cd - > /dev/null

View File

@ -24,10 +24,11 @@ func TestEncodeAndDecodeAccessToken(t *testing.T) {
} }
func TestEncodeAndDecodeAccessTokenB64(t *testing.T) { func TestEncodeAndDecodeAccessTokenB64(t *testing.T) {
const secret_b64 = "A3Xbr6fu6Al0HkgrP1ztjb-mYiwmxgNPP-XbNsz1WBk=" const secretBase64 = "A3Xbr6fu6Al0HkgrP1ztjb-mYiwmxgNPP-XbNsz1WBk="
const token = "my access token" const token = "my access token"
secret, err := base64.URLEncoding.DecodeString(secret_b64) secret, err := base64.URLEncoding.DecodeString(secretBase64)
assert.Equal(t, nil, err)
c, err := NewCipher([]byte(secret)) c, err := NewCipher([]byte(secret))
assert.Equal(t, nil, err) assert.Equal(t, nil, err)

View File

@ -5,6 +5,7 @@ import (
"fmt" "fmt"
) )
// Nonce generates a random 16 byte string to be used as a nonce
func Nonce() (nonce string, err error) { func Nonce() (nonce string, err error) {
b := make([]byte, 16) b := make([]byte, 16)
_, err = rand.Read(b) _, err = rand.Read(b)

View File

@ -6,8 +6,14 @@ import (
"strings" "strings"
) )
// EnvOptions holds program options loaded from the process environment
type EnvOptions map[string]interface{} type EnvOptions map[string]interface{}
// LoadEnvForStruct loads environment variables for each field in an options
// struct passed into it.
//
// Fields in the options struct must have an `env` and `cfg` tag to be read
// from the environment
func (cfg EnvOptions) LoadEnvForStruct(options interface{}) { func (cfg EnvOptions) LoadEnvForStruct(options interface{}) {
val := reflect.ValueOf(options).Elem() val := reflect.ValueOf(options).Elem()
typ := val.Type() typ := val.Type()

View File

@ -14,10 +14,12 @@ import (
// Lookup passwords in a htpasswd file // Lookup passwords in a htpasswd file
// Passwords must be generated with -B for bcrypt or -s for SHA1. // Passwords must be generated with -B for bcrypt or -s for SHA1.
// HtpasswdFile represents the structure of an htpasswd file
type HtpasswdFile struct { type HtpasswdFile struct {
Users map[string]string Users map[string]string
} }
// NewHtpasswdFromFile constructs an HtpasswdFile from the file at the path given
func NewHtpasswdFromFile(path string) (*HtpasswdFile, error) { func NewHtpasswdFromFile(path string) (*HtpasswdFile, error) {
r, err := os.Open(path) r, err := os.Open(path)
if err != nil { if err != nil {
@ -27,13 +29,14 @@ func NewHtpasswdFromFile(path string) (*HtpasswdFile, error) {
return NewHtpasswd(r) return NewHtpasswd(r)
} }
// NewHtpasswd consctructs an HtpasswdFile from an io.Reader (opened file)
func NewHtpasswd(file io.Reader) (*HtpasswdFile, error) { func NewHtpasswd(file io.Reader) (*HtpasswdFile, error) {
csv_reader := csv.NewReader(file) csvReader := csv.NewReader(file)
csv_reader.Comma = ':' csvReader.Comma = ':'
csv_reader.Comment = '#' csvReader.Comment = '#'
csv_reader.TrimLeadingSpace = true csvReader.TrimLeadingSpace = true
records, err := csv_reader.ReadAll() records, err := csvReader.ReadAll()
if err != nil { if err != nil {
return nil, err return nil, err
} }
@ -44,6 +47,7 @@ func NewHtpasswd(file io.Reader) (*HtpasswdFile, error) {
return h, nil return h, nil
} }
// Validate checks a users password against the HtpasswdFile entries
func (h *HtpasswdFile) Validate(user string, password string) bool { func (h *HtpasswdFile) Validate(user string, password string) bool {
realPassword, exists := h.Users[user] realPassword, exists := h.Users[user]
if !exists { if !exists {

View File

@ -20,6 +20,7 @@ func TestSHA(t *testing.T) {
func TestBcrypt(t *testing.T) { func TestBcrypt(t *testing.T) {
hash1, err := bcrypt.GenerateFromPassword([]byte("password"), 1) hash1, err := bcrypt.GenerateFromPassword([]byte("password"), 1)
assert.Equal(t, err, nil)
hash2, err := bcrypt.GenerateFromPassword([]byte("top-secret"), 2) hash2, err := bcrypt.GenerateFromPassword([]byte("top-secret"), 2)
assert.Equal(t, err, nil) assert.Equal(t, err, nil)

16
http.go
View File

@ -9,11 +9,13 @@ import (
"time" "time"
) )
// Server represents an HTTP server
type Server struct { type Server struct {
Handler http.Handler Handler http.Handler
Opts *Options Opts *Options
} }
// ListenAndServe will serve traffic on HTTP or HTTPS depending on TLS options
func (s *Server) ListenAndServe() { func (s *Server) ListenAndServe() {
if s.Opts.TLSKeyFile != "" || s.Opts.TLSCertFile != "" { if s.Opts.TLSKeyFile != "" || s.Opts.TLSCertFile != "" {
s.ServeHTTPS() s.ServeHTTPS()
@ -22,13 +24,14 @@ func (s *Server) ListenAndServe() {
} }
} }
// ServeHTTP constructs a net.Listener and starts handling HTTP requests
func (s *Server) ServeHTTP() { func (s *Server) ServeHTTP() {
httpAddress := s.Opts.HttpAddress HTTPAddress := s.Opts.HTTPAddress
scheme := "" var scheme string
i := strings.Index(httpAddress, "://") i := strings.Index(HTTPAddress, "://")
if i > -1 { if i > -1 {
scheme = httpAddress[0:i] scheme = HTTPAddress[0:i]
} }
var networkType string var networkType string
@ -39,7 +42,7 @@ func (s *Server) ServeHTTP() {
networkType = scheme networkType = scheme
} }
slice := strings.SplitN(httpAddress, "//", 2) slice := strings.SplitN(HTTPAddress, "//", 2)
listenAddr := slice[len(slice)-1] listenAddr := slice[len(slice)-1]
listener, err := net.Listen(networkType, listenAddr) listener, err := net.Listen(networkType, listenAddr)
@ -57,8 +60,9 @@ func (s *Server) ServeHTTP() {
log.Printf("HTTP: closing %s", listener.Addr()) log.Printf("HTTP: closing %s", listener.Addr())
} }
// ServeHTTPS constructs a net.Listener and starts handling HTTPS requests
func (s *Server) ServeHTTPS() { func (s *Server) ServeHTTPS() {
addr := s.Opts.HttpsAddress addr := s.Opts.HTTPSAddress
config := &tls.Config{ config := &tls.Config{
MinVersion: tls.VersionTLS12, MinVersion: tls.VersionTLS12,
MaxVersion: tls.VersionTLS12, MaxVersion: tls.VersionTLS12,

View File

@ -27,10 +27,13 @@ type responseLogger struct {
authInfo string authInfo string
} }
// Header returns the ResponseWriter's Header
func (l *responseLogger) Header() http.Header { func (l *responseLogger) Header() http.Header {
return l.w.Header() return l.w.Header()
} }
// ExtractGAPMetadata extracts and removes GAP headers from the ResponseWriter's
// Header
func (l *responseLogger) ExtractGAPMetadata() { func (l *responseLogger) ExtractGAPMetadata() {
upstream := l.w.Header().Get("GAP-Upstream-Address") upstream := l.w.Header().Get("GAP-Upstream-Address")
if upstream != "" { if upstream != "" {
@ -44,6 +47,7 @@ func (l *responseLogger) ExtractGAPMetadata() {
} }
} }
// Write writes the response using the ResponseWriter
func (l *responseLogger) Write(b []byte) (int, error) { func (l *responseLogger) Write(b []byte) (int, error) {
if l.status == 0 { if l.status == 0 {
// The status will be StatusOK if WriteHeader has not been called yet // The status will be StatusOK if WriteHeader has not been called yet
@ -55,16 +59,19 @@ func (l *responseLogger) Write(b []byte) (int, error) {
return size, err return size, err
} }
// WriteHeader writes the status code for the Response
func (l *responseLogger) WriteHeader(s int) { func (l *responseLogger) WriteHeader(s int) {
l.ExtractGAPMetadata() l.ExtractGAPMetadata()
l.w.WriteHeader(s) l.w.WriteHeader(s)
l.status = s l.status = s
} }
// Status returns the response status code
func (l *responseLogger) Status() int { func (l *responseLogger) Status() int {
return l.status return l.status
} }
// Size returns teh response size
func (l *responseLogger) Size() int { func (l *responseLogger) Size() int {
return l.size return l.size
} }
@ -94,6 +101,7 @@ type loggingHandler struct {
logTemplate *template.Template logTemplate *template.Template
} }
// LoggingHandler provides an http.Handler which logs requests to the HTTP server
func LoggingHandler(out io.Writer, h http.Handler, v bool, requestLoggingTpl string) http.Handler { func LoggingHandler(out io.Writer, h http.Handler, v bool, requestLoggingTpl string) http.Handler {
return loggingHandler{ return loggingHandler{
writer: out, writer: out,

View File

@ -84,7 +84,7 @@ func main() {
flagSet.Parse(os.Args[1:]) flagSet.Parse(os.Args[1:])
if *showVersion { if *showVersion {
fmt.Printf("oauth2_proxy v%s (built with %s)\n", VERSION, runtime.Version()) fmt.Printf("oauth2_proxy %s (built with %s)\n", VERSION, runtime.Version())
return return
} }

View File

@ -14,14 +14,23 @@ import (
"strings" "strings"
"time" "time"
"github.com/bitly/oauth2_proxy/cookie"
"github.com/bitly/oauth2_proxy/providers"
"github.com/mbland/hmacauth" "github.com/mbland/hmacauth"
"github.com/pusher/oauth2_proxy/cookie"
"github.com/pusher/oauth2_proxy/providers"
) )
const SignatureHeader = "GAP-Signature" const (
// SignatureHeader is the name of the request header containing the GAP Signature
// Part of hmacauth
SignatureHeader = "GAP-Signature"
var SignatureHeaders []string = []string{ httpScheme = "http"
httpsScheme = "https"
)
// SignatureHeaders contains the headers to be signed by the hmac algorithm
// Part of hmacauth
var SignatureHeaders = []string{
"Content-Length", "Content-Length",
"Content-Md5", "Content-Md5",
"Content-Type", "Content-Type",
@ -34,13 +43,14 @@ var SignatureHeaders []string = []string{
"Gap-Auth", "Gap-Auth",
} }
// OAuthProxy is the main authentication proxy
type OAuthProxy struct { type OAuthProxy struct {
CookieSeed string CookieSeed string
CookieName string CookieName string
CSRFCookieName string CSRFCookieName string
CookieDomain string CookieDomain string
CookieSecure bool CookieSecure bool
CookieHttpOnly bool CookieHTTPOnly bool
CookieExpire time.Duration CookieExpire time.Duration
CookieRefresh time.Duration CookieRefresh time.Duration
Validator func(string) bool Validator func(string) bool
@ -74,12 +84,15 @@ type OAuthProxy struct {
Footer string Footer string
} }
// UpstreamProxy represents an upstream server to proxy to
type UpstreamProxy struct { type UpstreamProxy struct {
upstream string upstream string
handler http.Handler handler http.Handler
auth hmacauth.HmacAuth auth hmacauth.HmacAuth
} }
// ServeHTTP proxies requests to the upstream provider while signing the
// request headers
func (u *UpstreamProxy) ServeHTTP(w http.ResponseWriter, r *http.Request) { func (u *UpstreamProxy) ServeHTTP(w http.ResponseWriter, r *http.Request) {
w.Header().Set("GAP-Upstream-Address", u.upstream) w.Header().Set("GAP-Upstream-Address", u.upstream)
if u.auth != nil { if u.auth != nil {
@ -89,9 +102,12 @@ func (u *UpstreamProxy) ServeHTTP(w http.ResponseWriter, r *http.Request) {
u.handler.ServeHTTP(w, r) u.handler.ServeHTTP(w, r)
} }
// NewReverseProxy creates a new reverse proxy for proxying requests to upstream
// servers
func NewReverseProxy(target *url.URL) (proxy *httputil.ReverseProxy) { func NewReverseProxy(target *url.URL) (proxy *httputil.ReverseProxy) {
return httputil.NewSingleHostReverseProxy(target) return httputil.NewSingleHostReverseProxy(target)
} }
func setProxyUpstreamHostHeader(proxy *httputil.ReverseProxy, target *url.URL) { func setProxyUpstreamHostHeader(proxy *httputil.ReverseProxy, target *url.URL) {
director := proxy.Director director := proxy.Director
proxy.Director = func(req *http.Request) { proxy.Director = func(req *http.Request) {
@ -102,6 +118,7 @@ func setProxyUpstreamHostHeader(proxy *httputil.ReverseProxy, target *url.URL) {
req.URL.RawQuery = "" req.URL.RawQuery = ""
} }
} }
func setProxyDirector(proxy *httputil.ReverseProxy) { func setProxyDirector(proxy *httputil.ReverseProxy) {
director := proxy.Director director := proxy.Director
proxy.Director = func(req *http.Request) { proxy.Director = func(req *http.Request) {
@ -111,10 +128,13 @@ func setProxyDirector(proxy *httputil.ReverseProxy) {
req.URL.RawQuery = "" req.URL.RawQuery = ""
} }
} }
// NewFileServer creates a http.Handler to serve files from the filesystem
func NewFileServer(path string, filesystemPath string) (proxy http.Handler) { func NewFileServer(path string, filesystemPath string) (proxy http.Handler) {
return http.StripPrefix(path, http.FileServer(http.Dir(filesystemPath))) return http.StripPrefix(path, http.FileServer(http.Dir(filesystemPath)))
} }
// NewOAuthProxy creates a new instance of OOuthProxy from the options provided
func NewOAuthProxy(opts *Options, validator func(string) bool) *OAuthProxy { func NewOAuthProxy(opts *Options, validator func(string) bool) *OAuthProxy {
serveMux := http.NewServeMux() serveMux := http.NewServeMux()
var auth hmacauth.HmacAuth var auth hmacauth.HmacAuth
@ -125,7 +145,7 @@ func NewOAuthProxy(opts *Options, validator func(string) bool) *OAuthProxy {
for _, u := range opts.proxyURLs { for _, u := range opts.proxyURLs {
path := u.Path path := u.Path
switch u.Scheme { switch u.Scheme {
case "http", "https": case httpScheme, httpsScheme:
u.Path = "" u.Path = ""
log.Printf("mapping path %q => upstream %q", path, u) log.Printf("mapping path %q => upstream %q", path, u)
proxy := NewReverseProxy(u) proxy := NewReverseProxy(u)
@ -160,7 +180,7 @@ func NewOAuthProxy(opts *Options, validator func(string) bool) *OAuthProxy {
refresh = fmt.Sprintf("after %s", opts.CookieRefresh) refresh = fmt.Sprintf("after %s", opts.CookieRefresh)
} }
log.Printf("Cookie settings: name:%s secure(https):%v httponly:%v expiry:%s domain:%s refresh:%s", opts.CookieName, opts.CookieSecure, opts.CookieHttpOnly, opts.CookieExpire, opts.CookieDomain, refresh) log.Printf("Cookie settings: name:%s secure(https):%v httponly:%v expiry:%s domain:%s refresh:%s", opts.CookieName, opts.CookieSecure, opts.CookieHTTPOnly, opts.CookieExpire, opts.CookieDomain, refresh)
var cipher *cookie.Cipher var cipher *cookie.Cipher
if opts.PassAccessToken || (opts.CookieRefresh != time.Duration(0)) { if opts.PassAccessToken || (opts.CookieRefresh != time.Duration(0)) {
@ -177,7 +197,7 @@ func NewOAuthProxy(opts *Options, validator func(string) bool) *OAuthProxy {
CookieSeed: opts.CookieSecret, CookieSeed: opts.CookieSecret,
CookieDomain: opts.CookieDomain, CookieDomain: opts.CookieDomain,
CookieSecure: opts.CookieSecure, CookieSecure: opts.CookieSecure,
CookieHttpOnly: opts.CookieHttpOnly, CookieHTTPOnly: opts.CookieHTTPOnly,
CookieExpire: opts.CookieExpire, CookieExpire: opts.CookieExpire,
CookieRefresh: opts.CookieRefresh, CookieRefresh: opts.CookieRefresh,
Validator: validator, Validator: validator,
@ -209,6 +229,8 @@ func NewOAuthProxy(opts *Options, validator func(string) bool) *OAuthProxy {
} }
} }
// GetRedirectURI returns the redirectURL that the upstream OAuth Provider will
// redirect clients to once authenticated
func (p *OAuthProxy) GetRedirectURI(host string) string { func (p *OAuthProxy) GetRedirectURI(host string) string {
// default to the request Host if not set // default to the request Host if not set
if p.redirectURL.Host != "" { if p.redirectURL.Host != "" {
@ -218,9 +240,9 @@ func (p *OAuthProxy) GetRedirectURI(host string) string {
u = *p.redirectURL u = *p.redirectURL
if u.Scheme == "" { if u.Scheme == "" {
if p.CookieSecure { if p.CookieSecure {
u.Scheme = "https" u.Scheme = httpsScheme
} else { } else {
u.Scheme = "http" u.Scheme = httpScheme
} }
} }
u.Host = host u.Host = host
@ -254,6 +276,8 @@ func (p *OAuthProxy) redeemCode(host, code string) (s *providers.SessionState, e
return return
} }
// MakeSessionCookie creates an http.Cookie containing the authenticated user's
// authentication details
func (p *OAuthProxy) MakeSessionCookie(req *http.Request, value string, expiration time.Duration, now time.Time) *http.Cookie { func (p *OAuthProxy) MakeSessionCookie(req *http.Request, value string, expiration time.Duration, now time.Time) *http.Cookie {
if value != "" { if value != "" {
value = cookie.SignedValue(p.CookieSeed, p.CookieName, value, now) value = cookie.SignedValue(p.CookieSeed, p.CookieName, value, now)
@ -265,6 +289,7 @@ func (p *OAuthProxy) MakeSessionCookie(req *http.Request, value string, expirati
return p.makeCookie(req, p.CookieName, value, expiration, now) return p.makeCookie(req, p.CookieName, value, expiration, now)
} }
// MakeCSRFCookie creates a cookie for CSRF
func (p *OAuthProxy) MakeCSRFCookie(req *http.Request, value string, expiration time.Duration, now time.Time) *http.Cookie { func (p *OAuthProxy) MakeCSRFCookie(req *http.Request, value string, expiration time.Duration, now time.Time) *http.Cookie {
return p.makeCookie(req, p.CSRFCookieName, value, expiration, now) return p.makeCookie(req, p.CSRFCookieName, value, expiration, now)
} }
@ -285,20 +310,25 @@ func (p *OAuthProxy) makeCookie(req *http.Request, name string, value string, ex
Value: value, Value: value,
Path: "/", Path: "/",
Domain: p.CookieDomain, Domain: p.CookieDomain,
HttpOnly: p.CookieHttpOnly, HttpOnly: p.CookieHTTPOnly,
Secure: p.CookieSecure, Secure: p.CookieSecure,
Expires: now.Add(expiration), Expires: now.Add(expiration),
} }
} }
// ClearCSRFCookie creates a cookie to unset the CSRF cookie stored in the user's
// session
func (p *OAuthProxy) ClearCSRFCookie(rw http.ResponseWriter, req *http.Request) { func (p *OAuthProxy) ClearCSRFCookie(rw http.ResponseWriter, req *http.Request) {
http.SetCookie(rw, p.MakeCSRFCookie(req, "", time.Hour*-1, time.Now())) http.SetCookie(rw, p.MakeCSRFCookie(req, "", time.Hour*-1, time.Now()))
} }
// SetCSRFCookie adds a CSRF cookie to the response
func (p *OAuthProxy) SetCSRFCookie(rw http.ResponseWriter, req *http.Request, val string) { func (p *OAuthProxy) SetCSRFCookie(rw http.ResponseWriter, req *http.Request, val string) {
http.SetCookie(rw, p.MakeCSRFCookie(req, val, p.CookieExpire, time.Now())) http.SetCookie(rw, p.MakeCSRFCookie(req, val, p.CookieExpire, time.Now()))
} }
// ClearSessionCookie creates a cookie to unset the user's authentication cookie
// stored in the user's session
func (p *OAuthProxy) ClearSessionCookie(rw http.ResponseWriter, req *http.Request) { func (p *OAuthProxy) ClearSessionCookie(rw http.ResponseWriter, req *http.Request) {
clr := p.MakeSessionCookie(req, "", time.Hour*-1, time.Now()) clr := p.MakeSessionCookie(req, "", time.Hour*-1, time.Now())
http.SetCookie(rw, clr) http.SetCookie(rw, clr)
@ -311,10 +341,12 @@ func (p *OAuthProxy) ClearSessionCookie(rw http.ResponseWriter, req *http.Reques
} }
} }
// SetSessionCookie adds the user's session cookie to the response
func (p *OAuthProxy) SetSessionCookie(rw http.ResponseWriter, req *http.Request, val string) { func (p *OAuthProxy) SetSessionCookie(rw http.ResponseWriter, req *http.Request, val string) {
http.SetCookie(rw, p.MakeSessionCookie(req, val, p.CookieExpire, time.Now())) http.SetCookie(rw, p.MakeSessionCookie(req, val, p.CookieExpire, time.Now()))
} }
// LoadCookiedSession reads the user's authentication details from the request
func (p *OAuthProxy) LoadCookiedSession(req *http.Request) (*providers.SessionState, time.Duration, error) { func (p *OAuthProxy) LoadCookiedSession(req *http.Request) (*providers.SessionState, time.Duration, error) {
var age time.Duration var age time.Duration
c, err := req.Cookie(p.CookieName) c, err := req.Cookie(p.CookieName)
@ -336,6 +368,7 @@ func (p *OAuthProxy) LoadCookiedSession(req *http.Request) (*providers.SessionSt
return session, age, nil return session, age, nil
} }
// SaveSession creates a new session cookie value and sets this on the response
func (p *OAuthProxy) SaveSession(rw http.ResponseWriter, req *http.Request, s *providers.SessionState) error { func (p *OAuthProxy) SaveSession(rw http.ResponseWriter, req *http.Request, s *providers.SessionState) error {
value, err := p.provider.CookieForSession(s, p.CookieCipher) value, err := p.provider.CookieForSession(s, p.CookieCipher)
if err != nil { if err != nil {
@ -345,16 +378,19 @@ func (p *OAuthProxy) SaveSession(rw http.ResponseWriter, req *http.Request, s *p
return nil return nil
} }
// RobotsTxt disallows scraping pages from the OAuthProxy
func (p *OAuthProxy) RobotsTxt(rw http.ResponseWriter) { func (p *OAuthProxy) RobotsTxt(rw http.ResponseWriter) {
rw.WriteHeader(http.StatusOK) rw.WriteHeader(http.StatusOK)
fmt.Fprintf(rw, "User-agent: *\nDisallow: /") fmt.Fprintf(rw, "User-agent: *\nDisallow: /")
} }
// PingPage responds 200 OK to requests
func (p *OAuthProxy) PingPage(rw http.ResponseWriter) { func (p *OAuthProxy) PingPage(rw http.ResponseWriter) {
rw.WriteHeader(http.StatusOK) rw.WriteHeader(http.StatusOK)
fmt.Fprintf(rw, "OK") fmt.Fprintf(rw, "OK")
} }
// ErrorPage writes an error response
func (p *OAuthProxy) ErrorPage(rw http.ResponseWriter, code int, title string, message string) { func (p *OAuthProxy) ErrorPage(rw http.ResponseWriter, code int, title string, message string) {
log.Printf("ErrorPage %d %s %s", code, title, message) log.Printf("ErrorPage %d %s %s", code, title, message)
rw.WriteHeader(code) rw.WriteHeader(code)
@ -370,16 +406,17 @@ func (p *OAuthProxy) ErrorPage(rw http.ResponseWriter, code int, title string, m
p.templates.ExecuteTemplate(rw, "error.html", t) p.templates.ExecuteTemplate(rw, "error.html", t)
} }
// SignInPage writes the sing in template to the response
func (p *OAuthProxy) SignInPage(rw http.ResponseWriter, req *http.Request, code int) { func (p *OAuthProxy) SignInPage(rw http.ResponseWriter, req *http.Request, code int) {
p.ClearSessionCookie(rw, req) p.ClearSessionCookie(rw, req)
rw.WriteHeader(code) rw.WriteHeader(code)
redirect_url := req.URL.RequestURI() redirecURL := req.URL.RequestURI()
if req.Header.Get("X-Auth-Request-Redirect") != "" { if req.Header.Get("X-Auth-Request-Redirect") != "" {
redirect_url = req.Header.Get("X-Auth-Request-Redirect") redirecURL = req.Header.Get("X-Auth-Request-Redirect")
} }
if redirect_url == p.SignInPath { if redirecURL == p.SignInPath {
redirect_url = "/" redirecURL = "/"
} }
t := struct { t := struct {
@ -394,7 +431,7 @@ func (p *OAuthProxy) SignInPage(rw http.ResponseWriter, req *http.Request, code
ProviderName: p.provider.Data().ProviderName, ProviderName: p.provider.Data().ProviderName,
SignInMessage: p.SignInMessage, SignInMessage: p.SignInMessage,
CustomLogin: p.displayCustomLoginForm(), CustomLogin: p.displayCustomLoginForm(),
Redirect: redirect_url, Redirect: redirecURL,
Version: VERSION, Version: VERSION,
ProxyPrefix: p.ProxyPrefix, ProxyPrefix: p.ProxyPrefix,
Footer: template.HTML(p.Footer), Footer: template.HTML(p.Footer),
@ -402,6 +439,7 @@ func (p *OAuthProxy) SignInPage(rw http.ResponseWriter, req *http.Request, code
p.templates.ExecuteTemplate(rw, "sign_in.html", t) p.templates.ExecuteTemplate(rw, "sign_in.html", t)
} }
// ManualSignIn handles basic auth logins to the proxy
func (p *OAuthProxy) ManualSignIn(rw http.ResponseWriter, req *http.Request) (string, bool) { func (p *OAuthProxy) ManualSignIn(rw http.ResponseWriter, req *http.Request) (string, bool) {
if req.Method != "POST" || p.HtpasswdFile == nil { if req.Method != "POST" || p.HtpasswdFile == nil {
return "", false return "", false
@ -419,6 +457,8 @@ func (p *OAuthProxy) ManualSignIn(rw http.ResponseWriter, req *http.Request) (st
return "", false return "", false
} }
// GetRedirect reads the query parameter to get the URL to redirect clients to
// once authenticated with the OAuthProxy
func (p *OAuthProxy) GetRedirect(req *http.Request) (redirect string, err error) { func (p *OAuthProxy) GetRedirect(req *http.Request) (redirect string, err error) {
err = req.ParseForm() err = req.ParseForm()
if err != nil { if err != nil {
@ -433,11 +473,13 @@ func (p *OAuthProxy) GetRedirect(req *http.Request) (redirect string, err error)
return return
} }
// IsWhitelistedRequest is used to check if auth should be skipped for this request
func (p *OAuthProxy) IsWhitelistedRequest(req *http.Request) (ok bool) { func (p *OAuthProxy) IsWhitelistedRequest(req *http.Request) (ok bool) {
isPreflightRequestAllowed := p.skipAuthPreflight && req.Method == "OPTIONS" isPreflightRequestAllowed := p.skipAuthPreflight && req.Method == "OPTIONS"
return isPreflightRequestAllowed || p.IsWhitelistedPath(req.URL.Path) return isPreflightRequestAllowed || p.IsWhitelistedPath(req.URL.Path)
} }
// IsWhitelistedPath is used to check if the request path is allowed without auth
func (p *OAuthProxy) IsWhitelistedPath(path string) (ok bool) { func (p *OAuthProxy) IsWhitelistedPath(path string) (ok bool) {
for _, u := range p.compiledRegex { for _, u := range p.compiledRegex {
ok = u.MatchString(path) ok = u.MatchString(path)
@ -479,6 +521,7 @@ func (p *OAuthProxy) ServeHTTP(rw http.ResponseWriter, req *http.Request) {
} }
} }
// SignIn serves a page prompting users to sign in
func (p *OAuthProxy) SignIn(rw http.ResponseWriter, req *http.Request) { func (p *OAuthProxy) SignIn(rw http.ResponseWriter, req *http.Request) {
redirect, err := p.GetRedirect(req) redirect, err := p.GetRedirect(req)
if err != nil { if err != nil {
@ -500,11 +543,13 @@ func (p *OAuthProxy) SignIn(rw http.ResponseWriter, req *http.Request) {
} }
} }
// SignOut sends a response to clear the authentication cookie
func (p *OAuthProxy) SignOut(rw http.ResponseWriter, req *http.Request) { func (p *OAuthProxy) SignOut(rw http.ResponseWriter, req *http.Request) {
p.ClearSessionCookie(rw, req) p.ClearSessionCookie(rw, req)
http.Redirect(rw, req, "/", 302) http.Redirect(rw, req, "/", 302)
} }
// OAuthStart starts the OAuth2 authentication flow
func (p *OAuthProxy) OAuthStart(rw http.ResponseWriter, req *http.Request) { func (p *OAuthProxy) OAuthStart(rw http.ResponseWriter, req *http.Request) {
nonce, err := cookie.Nonce() nonce, err := cookie.Nonce()
if err != nil { if err != nil {
@ -521,6 +566,8 @@ func (p *OAuthProxy) OAuthStart(rw http.ResponseWriter, req *http.Request) {
http.Redirect(rw, req, p.provider.GetLoginURL(redirectURI, fmt.Sprintf("%v:%v", nonce, redirect)), 302) http.Redirect(rw, req, p.provider.GetLoginURL(redirectURI, fmt.Sprintf("%v:%v", nonce, redirect)), 302)
} }
// OAuthCallback is the OAuth2 authentication flow callback that finishes the
// OAuth2 authentication flow
func (p *OAuthProxy) OAuthCallback(rw http.ResponseWriter, req *http.Request) { func (p *OAuthProxy) OAuthCallback(rw http.ResponseWriter, req *http.Request) {
remoteAddr := getRemoteAddr(req) remoteAddr := getRemoteAddr(req)
@ -582,6 +629,7 @@ func (p *OAuthProxy) OAuthCallback(rw http.ResponseWriter, req *http.Request) {
} }
} }
// AuthenticateOnly checks whether the user is currently logged in
func (p *OAuthProxy) AuthenticateOnly(rw http.ResponseWriter, req *http.Request) { func (p *OAuthProxy) AuthenticateOnly(rw http.ResponseWriter, req *http.Request) {
status := p.Authenticate(rw, req) status := p.Authenticate(rw, req)
if status == http.StatusAccepted { if status == http.StatusAccepted {
@ -591,6 +639,8 @@ func (p *OAuthProxy) AuthenticateOnly(rw http.ResponseWriter, req *http.Request)
} }
} }
// Proxy proxies the user request if the user is authenticated else it prompts
// them to authenticate
func (p *OAuthProxy) Proxy(rw http.ResponseWriter, req *http.Request) { func (p *OAuthProxy) Proxy(rw http.ResponseWriter, req *http.Request) {
status := p.Authenticate(rw, req) status := p.Authenticate(rw, req)
if status == http.StatusInternalServerError { if status == http.StatusInternalServerError {
@ -607,6 +657,7 @@ func (p *OAuthProxy) Proxy(rw http.ResponseWriter, req *http.Request) {
} }
} }
// Authenticate checks whether a user is authenticated
func (p *OAuthProxy) Authenticate(rw http.ResponseWriter, req *http.Request) int { func (p *OAuthProxy) Authenticate(rw http.ResponseWriter, req *http.Request) int {
var saveSession, clearSession, revalidated bool var saveSession, clearSession, revalidated bool
remoteAddr := getRemoteAddr(req) remoteAddr := getRemoteAddr(req)
@ -620,7 +671,8 @@ func (p *OAuthProxy) Authenticate(rw http.ResponseWriter, req *http.Request) int
saveSession = true saveSession = true
} }
if ok, err := p.provider.RefreshSessionIfNeeded(session); err != nil { var ok bool
if ok, err = p.provider.RefreshSessionIfNeeded(session); err != nil {
log.Printf("%s removing session. error refreshing access token %s %s", remoteAddr, err, session) log.Printf("%s removing session. error refreshing access token %s %s", remoteAddr, err, session)
clearSession = true clearSession = true
session = nil session = nil
@ -653,7 +705,7 @@ func (p *OAuthProxy) Authenticate(rw http.ResponseWriter, req *http.Request) int
} }
if saveSession && session != nil { if saveSession && session != nil {
err := p.SaveSession(rw, req, session) err = p.SaveSession(rw, req, session)
if err != nil { if err != nil {
log.Printf("%s %s", remoteAddr, err) log.Printf("%s %s", remoteAddr, err)
return http.StatusInternalServerError return http.StatusInternalServerError
@ -706,6 +758,8 @@ func (p *OAuthProxy) Authenticate(rw http.ResponseWriter, req *http.Request) int
return http.StatusAccepted return http.StatusAccepted
} }
// CheckBasicAuth checks the requests Authorization header for basic auth
// credentials and authenticates these against the proxies HtpasswdFile
func (p *OAuthProxy) CheckBasicAuth(req *http.Request) (*providers.SessionState, error) { func (p *OAuthProxy) CheckBasicAuth(req *http.Request) (*providers.SessionState, error) {
if p.HtpasswdFile == nil { if p.HtpasswdFile == nil {
return nil, nil return nil, nil

View File

@ -15,8 +15,8 @@ import (
"testing" "testing"
"time" "time"
"github.com/bitly/oauth2_proxy/providers"
"github.com/mbland/hmacauth" "github.com/mbland/hmacauth"
"github.com/pusher/oauth2_proxy/providers"
"github.com/stretchr/testify/assert" "github.com/stretchr/testify/assert"
) )
@ -98,28 +98,28 @@ type TestProvider struct {
ValidToken bool ValidToken bool
} }
func NewTestProvider(provider_url *url.URL, email_address string) *TestProvider { func NewTestProvider(providerURL *url.URL, emailAddress string) *TestProvider {
return &TestProvider{ return &TestProvider{
ProviderData: &providers.ProviderData{ ProviderData: &providers.ProviderData{
ProviderName: "Test Provider", ProviderName: "Test Provider",
LoginURL: &url.URL{ LoginURL: &url.URL{
Scheme: "http", Scheme: "http",
Host: provider_url.Host, Host: providerURL.Host,
Path: "/oauth/authorize", Path: "/oauth/authorize",
}, },
RedeemURL: &url.URL{ RedeemURL: &url.URL{
Scheme: "http", Scheme: "http",
Host: provider_url.Host, Host: providerURL.Host,
Path: "/oauth/token", Path: "/oauth/token",
}, },
ProfileURL: &url.URL{ ProfileURL: &url.URL{
Scheme: "http", Scheme: "http",
Host: provider_url.Host, Host: providerURL.Host,
Path: "/api/v1/profile", Path: "/api/v1/profile",
}, },
Scope: "profile.email", Scope: "profile.email",
}, },
EmailAddress: email_address, EmailAddress: emailAddress,
} }
} }
@ -132,11 +132,10 @@ func (tp *TestProvider) ValidateSessionState(session *providers.SessionState) bo
} }
func TestBasicAuthPassword(t *testing.T) { func TestBasicAuthPassword(t *testing.T) {
provider_server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { providerServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
log.Printf("%#v", r) log.Printf("%#v", r)
url := r.URL var payload string
payload := "" switch r.URL.Path {
switch url.Path {
case "/oauth/token": case "/oauth/token":
payload = `{"access_token": "my_auth_token"}` payload = `{"access_token": "my_auth_token"}`
default: default:
@ -149,7 +148,7 @@ func TestBasicAuthPassword(t *testing.T) {
w.Write([]byte(payload)) w.Write([]byte(payload))
})) }))
opts := NewOptions() opts := NewOptions()
opts.Upstreams = append(opts.Upstreams, provider_server.URL) opts.Upstreams = append(opts.Upstreams, providerServer.URL)
// The CookieSecret must be 32 bytes in order to create the AES // The CookieSecret must be 32 bytes in order to create the AES
// cipher. // cipher.
opts.CookieSecret = "xyzzyplughxyzzyplughxyzzyplughxp" opts.CookieSecret = "xyzzyplughxyzzyplughxyzzyplughxp"
@ -161,13 +160,13 @@ func TestBasicAuthPassword(t *testing.T) {
opts.BasicAuthPassword = "This is a secure password" opts.BasicAuthPassword = "This is a secure password"
opts.Validate() opts.Validate()
provider_url, _ := url.Parse(provider_server.URL) providerURL, _ := url.Parse(providerServer.URL)
const email_address = "michael.bland@gsa.gov" const emailAddress = "michael.bland@gsa.gov"
const user_name = "michael.bland" const username = "michael.bland"
opts.provider = NewTestProvider(provider_url, email_address) opts.provider = NewTestProvider(providerURL, emailAddress)
proxy := NewOAuthProxy(opts, func(email string) bool { proxy := NewOAuthProxy(opts, func(email string) bool {
return email == email_address return email == emailAddress
}) })
rw := httptest.NewRecorder() rw := httptest.NewRecorder()
@ -182,10 +181,10 @@ func TestBasicAuthPassword(t *testing.T) {
cookieName := proxy.CookieName cookieName := proxy.CookieName
var value string var value string
key_prefix := cookieName + "=" keyPrefix := cookieName + "="
for _, field := range strings.Split(cookie, "; ") { for _, field := range strings.Split(cookie, "; ") {
value = strings.TrimPrefix(field, key_prefix) value = strings.TrimPrefix(field, keyPrefix)
if value != field { if value != field {
break break
} else { } else {
@ -206,13 +205,13 @@ func TestBasicAuthPassword(t *testing.T) {
rw = httptest.NewRecorder() rw = httptest.NewRecorder()
proxy.ServeHTTP(rw, req) proxy.ServeHTTP(rw, req)
expectedHeader := "Basic " + base64.StdEncoding.EncodeToString([]byte(user_name+":"+opts.BasicAuthPassword)) expectedHeader := "Basic " + base64.StdEncoding.EncodeToString([]byte(username+":"+opts.BasicAuthPassword))
assert.Equal(t, expectedHeader, rw.Body.String()) assert.Equal(t, expectedHeader, rw.Body.String())
provider_server.Close() providerServer.Close()
} }
type PassAccessTokenTest struct { type PassAccessTokenTest struct {
provider_server *httptest.Server providerServer *httptest.Server
proxy *OAuthProxy proxy *OAuthProxy
opts *Options opts *Options
} }
@ -224,12 +223,11 @@ type PassAccessTokenTestOptions struct {
func NewPassAccessTokenTest(opts PassAccessTokenTestOptions) *PassAccessTokenTest { func NewPassAccessTokenTest(opts PassAccessTokenTestOptions) *PassAccessTokenTest {
t := &PassAccessTokenTest{} t := &PassAccessTokenTest{}
t.provider_server = httptest.NewServer( t.providerServer = httptest.NewServer(
http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
log.Printf("%#v", r) log.Printf("%#v", r)
url := r.URL var payload string
payload := "" switch r.URL.Path {
switch url.Path {
case "/oauth/token": case "/oauth/token":
payload = `{"access_token": "my_auth_token"}` payload = `{"access_token": "my_auth_token"}`
default: default:
@ -243,7 +241,7 @@ func NewPassAccessTokenTest(opts PassAccessTokenTestOptions) *PassAccessTokenTes
})) }))
t.opts = NewOptions() t.opts = NewOptions()
t.opts.Upstreams = append(t.opts.Upstreams, t.provider_server.URL) t.opts.Upstreams = append(t.opts.Upstreams, t.providerServer.URL)
// The CookieSecret must be 32 bytes in order to create the AES // The CookieSecret must be 32 bytes in order to create the AES
// cipher. // cipher.
t.opts.CookieSecret = "xyzzyplughxyzzyplughxyzzyplughxp" t.opts.CookieSecret = "xyzzyplughxyzzyplughxyzzyplughxp"
@ -253,21 +251,21 @@ func NewPassAccessTokenTest(opts PassAccessTokenTestOptions) *PassAccessTokenTes
t.opts.PassAccessToken = opts.PassAccessToken t.opts.PassAccessToken = opts.PassAccessToken
t.opts.Validate() t.opts.Validate()
provider_url, _ := url.Parse(t.provider_server.URL) providerURL, _ := url.Parse(t.providerServer.URL)
const email_address = "michael.bland@gsa.gov" const emailAddress = "michael.bland@gsa.gov"
t.opts.provider = NewTestProvider(provider_url, email_address) t.opts.provider = NewTestProvider(providerURL, emailAddress)
t.proxy = NewOAuthProxy(t.opts, func(email string) bool { t.proxy = NewOAuthProxy(t.opts, func(email string) bool {
return email == email_address return email == emailAddress
}) })
return t return t
} }
func (pat_test *PassAccessTokenTest) Close() { func (patTest *PassAccessTokenTest) Close() {
pat_test.provider_server.Close() patTest.providerServer.Close()
} }
func (pat_test *PassAccessTokenTest) getCallbackEndpoint() (http_code int, func (patTest *PassAccessTokenTest) getCallbackEndpoint() (httpCode int,
cookie string) { cookie string) {
rw := httptest.NewRecorder() rw := httptest.NewRecorder()
req, err := http.NewRequest("GET", "/oauth2/callback?code=callback_code&state=nonce:", req, err := http.NewRequest("GET", "/oauth2/callback?code=callback_code&state=nonce:",
@ -275,18 +273,18 @@ func (pat_test *PassAccessTokenTest) getCallbackEndpoint() (http_code int,
if err != nil { if err != nil {
return 0, "" return 0, ""
} }
req.AddCookie(pat_test.proxy.MakeCSRFCookie(req, "nonce", time.Hour, time.Now())) req.AddCookie(patTest.proxy.MakeCSRFCookie(req, "nonce", time.Hour, time.Now()))
pat_test.proxy.ServeHTTP(rw, req) patTest.proxy.ServeHTTP(rw, req)
return rw.Code, rw.HeaderMap["Set-Cookie"][1] return rw.Code, rw.HeaderMap["Set-Cookie"][1]
} }
func (pat_test *PassAccessTokenTest) getRootEndpoint(cookie string) (http_code int, access_token string) { func (patTest *PassAccessTokenTest) getRootEndpoint(cookie string) (httpCode int, accessToken string) {
cookieName := pat_test.proxy.CookieName cookieName := patTest.proxy.CookieName
var value string var value string
key_prefix := cookieName + "=" keyPrefix := cookieName + "="
for _, field := range strings.Split(cookie, "; ") { for _, field := range strings.Split(cookie, "; ") {
value = strings.TrimPrefix(field, key_prefix) value = strings.TrimPrefix(field, keyPrefix)
if value != field { if value != field {
break break
} else { } else {
@ -310,18 +308,18 @@ func (pat_test *PassAccessTokenTest) getRootEndpoint(cookie string) (http_code i
}) })
rw := httptest.NewRecorder() rw := httptest.NewRecorder()
pat_test.proxy.ServeHTTP(rw, req) patTest.proxy.ServeHTTP(rw, req)
return rw.Code, rw.Body.String() return rw.Code, rw.Body.String()
} }
func TestForwardAccessTokenUpstream(t *testing.T) { func TestForwardAccessTokenUpstream(t *testing.T) {
pat_test := NewPassAccessTokenTest(PassAccessTokenTestOptions{ patTest := NewPassAccessTokenTest(PassAccessTokenTestOptions{
PassAccessToken: true, PassAccessToken: true,
}) })
defer pat_test.Close() defer patTest.Close()
// A successful validation will redirect and set the auth cookie. // A successful validation will redirect and set the auth cookie.
code, cookie := pat_test.getCallbackEndpoint() code, cookie := patTest.getCallbackEndpoint()
if code != 302 { if code != 302 {
t.Fatalf("expected 302; got %d", code) t.Fatalf("expected 302; got %d", code)
} }
@ -330,7 +328,7 @@ func TestForwardAccessTokenUpstream(t *testing.T) {
// Now we make a regular request; the access_token from the cookie is // Now we make a regular request; the access_token from the cookie is
// forwarded as the "X-Forwarded-Access-Token" header. The token is // forwarded as the "X-Forwarded-Access-Token" header. The token is
// read by the test provider server and written in the response body. // read by the test provider server and written in the response body.
code, payload := pat_test.getRootEndpoint(cookie) code, payload := patTest.getRootEndpoint(cookie)
if code != 200 { if code != 200 {
t.Fatalf("expected 200; got %d", code) t.Fatalf("expected 200; got %d", code)
} }
@ -338,13 +336,13 @@ func TestForwardAccessTokenUpstream(t *testing.T) {
} }
func TestDoNotForwardAccessTokenUpstream(t *testing.T) { func TestDoNotForwardAccessTokenUpstream(t *testing.T) {
pat_test := NewPassAccessTokenTest(PassAccessTokenTestOptions{ patTest := NewPassAccessTokenTest(PassAccessTokenTestOptions{
PassAccessToken: false, PassAccessToken: false,
}) })
defer pat_test.Close() defer patTest.Close()
// A successful validation will redirect and set the auth cookie. // A successful validation will redirect and set the auth cookie.
code, cookie := pat_test.getCallbackEndpoint() code, cookie := patTest.getCallbackEndpoint()
if code != 302 { if code != 302 {
t.Fatalf("expected 302; got %d", code) t.Fatalf("expected 302; got %d", code)
} }
@ -352,7 +350,7 @@ func TestDoNotForwardAccessTokenUpstream(t *testing.T) {
// Now we make a regular request, but the access token header should // Now we make a regular request, but the access token header should
// not be present. // not be present.
code, payload := pat_test.getRootEndpoint(cookie) code, payload := patTest.getRootEndpoint(cookie)
if code != 200 { if code != 200 {
t.Fatalf("expected 200; got %d", code) t.Fatalf("expected 200; got %d", code)
} }
@ -362,47 +360,47 @@ func TestDoNotForwardAccessTokenUpstream(t *testing.T) {
type SignInPageTest struct { type SignInPageTest struct {
opts *Options opts *Options
proxy *OAuthProxy proxy *OAuthProxy
sign_in_regexp *regexp.Regexp signInRegexp *regexp.Regexp
sign_in_provider_regexp *regexp.Regexp signInProviderRegexp *regexp.Regexp
} }
const signInRedirectPattern = `<input type="hidden" name="rd" value="(.*)">` const signInRedirectPattern = `<input type="hidden" name="rd" value="(.*)">`
const signInSkipProvider = `>Found<` const signInSkipProvider = `>Found<`
func NewSignInPageTest(skipProvider bool) *SignInPageTest { func NewSignInPageTest(skipProvider bool) *SignInPageTest {
var sip_test SignInPageTest var sipTest SignInPageTest
sip_test.opts = NewOptions() sipTest.opts = NewOptions()
sip_test.opts.CookieSecret = "foobar" sipTest.opts.CookieSecret = "foobar"
sip_test.opts.ClientID = "bazquux" sipTest.opts.ClientID = "bazquux"
sip_test.opts.ClientSecret = "xyzzyplugh" sipTest.opts.ClientSecret = "xyzzyplugh"
sip_test.opts.SkipProviderButton = skipProvider sipTest.opts.SkipProviderButton = skipProvider
sip_test.opts.Validate() sipTest.opts.Validate()
sip_test.proxy = NewOAuthProxy(sip_test.opts, func(email string) bool { sipTest.proxy = NewOAuthProxy(sipTest.opts, func(email string) bool {
return true return true
}) })
sip_test.sign_in_regexp = regexp.MustCompile(signInRedirectPattern) sipTest.signInRegexp = regexp.MustCompile(signInRedirectPattern)
sip_test.sign_in_provider_regexp = regexp.MustCompile(signInSkipProvider) sipTest.signInProviderRegexp = regexp.MustCompile(signInSkipProvider)
return &sip_test return &sipTest
} }
func (sip_test *SignInPageTest) GetEndpoint(endpoint string) (int, string) { func (sipTest *SignInPageTest) GetEndpoint(endpoint string) (int, string) {
rw := httptest.NewRecorder() rw := httptest.NewRecorder()
req, _ := http.NewRequest("GET", endpoint, strings.NewReader("")) req, _ := http.NewRequest("GET", endpoint, strings.NewReader(""))
sip_test.proxy.ServeHTTP(rw, req) sipTest.proxy.ServeHTTP(rw, req)
return rw.Code, rw.Body.String() return rw.Code, rw.Body.String()
} }
func TestSignInPageIncludesTargetRedirect(t *testing.T) { func TestSignInPageIncludesTargetRedirect(t *testing.T) {
sip_test := NewSignInPageTest(false) sipTest := NewSignInPageTest(false)
const endpoint = "/some/random/endpoint" const endpoint = "/some/random/endpoint"
code, body := sip_test.GetEndpoint(endpoint) code, body := sipTest.GetEndpoint(endpoint)
assert.Equal(t, 403, code) assert.Equal(t, 403, code)
match := sip_test.sign_in_regexp.FindStringSubmatch(body) match := sipTest.signInRegexp.FindStringSubmatch(body)
if match == nil { if match == nil {
t.Fatal("Did not find pattern in body: " + t.Fatal("Did not find pattern in body: " +
signInRedirectPattern + "\nBody:\n" + body) signInRedirectPattern + "\nBody:\n" + body)
@ -414,11 +412,11 @@ func TestSignInPageIncludesTargetRedirect(t *testing.T) {
} }
func TestSignInPageDirectAccessRedirectsToRoot(t *testing.T) { func TestSignInPageDirectAccessRedirectsToRoot(t *testing.T) {
sip_test := NewSignInPageTest(false) sipTest := NewSignInPageTest(false)
code, body := sip_test.GetEndpoint("/oauth2/sign_in") code, body := sipTest.GetEndpoint("/oauth2/sign_in")
assert.Equal(t, 200, code) assert.Equal(t, 200, code)
match := sip_test.sign_in_regexp.FindStringSubmatch(body) match := sipTest.signInRegexp.FindStringSubmatch(body)
if match == nil { if match == nil {
t.Fatal("Did not find pattern in body: " + t.Fatal("Did not find pattern in body: " +
signInRedirectPattern + "\nBody:\n" + body) signInRedirectPattern + "\nBody:\n" + body)
@ -429,13 +427,13 @@ func TestSignInPageDirectAccessRedirectsToRoot(t *testing.T) {
} }
func TestSignInPageSkipProvider(t *testing.T) { func TestSignInPageSkipProvider(t *testing.T) {
sip_test := NewSignInPageTest(true) sipTest := NewSignInPageTest(true)
const endpoint = "/some/random/endpoint" const endpoint = "/some/random/endpoint"
code, body := sip_test.GetEndpoint(endpoint) code, body := sipTest.GetEndpoint(endpoint)
assert.Equal(t, 302, code) assert.Equal(t, 302, code)
match := sip_test.sign_in_provider_regexp.FindStringSubmatch(body) match := sipTest.signInProviderRegexp.FindStringSubmatch(body)
if match == nil { if match == nil {
t.Fatal("Did not find pattern in body: " + t.Fatal("Did not find pattern in body: " +
signInSkipProvider + "\nBody:\n" + body) signInSkipProvider + "\nBody:\n" + body)
@ -443,13 +441,13 @@ func TestSignInPageSkipProvider(t *testing.T) {
} }
func TestSignInPageSkipProviderDirect(t *testing.T) { func TestSignInPageSkipProviderDirect(t *testing.T) {
sip_test := NewSignInPageTest(true) sipTest := NewSignInPageTest(true)
const endpoint = "/sign_in" const endpoint = "/sign_in"
code, body := sip_test.GetEndpoint(endpoint) code, body := sipTest.GetEndpoint(endpoint)
assert.Equal(t, 302, code) assert.Equal(t, 302, code)
match := sip_test.sign_in_provider_regexp.FindStringSubmatch(body) match := sipTest.signInProviderRegexp.FindStringSubmatch(body)
if match == nil { if match == nil {
t.Fatal("Did not find pattern in body: " + t.Fatal("Did not find pattern in body: " +
signInSkipProvider + "\nBody:\n" + body) signInSkipProvider + "\nBody:\n" + body)
@ -462,45 +460,45 @@ type ProcessCookieTest struct {
rw *httptest.ResponseRecorder rw *httptest.ResponseRecorder
req *http.Request req *http.Request
provider TestProvider provider TestProvider
response_code int responseCode int
validate_user bool validateUser bool
} }
type ProcessCookieTestOpts struct { type ProcessCookieTestOpts struct {
provider_validate_cookie_response bool providerValidateCookieResponse bool
} }
func NewProcessCookieTest(opts ProcessCookieTestOpts) *ProcessCookieTest { func NewProcessCookieTest(opts ProcessCookieTestOpts) *ProcessCookieTest {
var pc_test ProcessCookieTest var pcTest ProcessCookieTest
pc_test.opts = NewOptions() pcTest.opts = NewOptions()
pc_test.opts.ClientID = "bazquux" pcTest.opts.ClientID = "bazquux"
pc_test.opts.ClientSecret = "xyzzyplugh" pcTest.opts.ClientSecret = "xyzzyplugh"
pc_test.opts.CookieSecret = "0123456789abcdefabcd" pcTest.opts.CookieSecret = "0123456789abcdefabcd"
// First, set the CookieRefresh option so proxy.AesCipher is created, // First, set the CookieRefresh option so proxy.AesCipher is created,
// needed to encrypt the access_token. // needed to encrypt the access_token.
pc_test.opts.CookieRefresh = time.Hour pcTest.opts.CookieRefresh = time.Hour
pc_test.opts.Validate() pcTest.opts.Validate()
pc_test.proxy = NewOAuthProxy(pc_test.opts, func(email string) bool { pcTest.proxy = NewOAuthProxy(pcTest.opts, func(email string) bool {
return pc_test.validate_user return pcTest.validateUser
}) })
pc_test.proxy.provider = &TestProvider{ pcTest.proxy.provider = &TestProvider{
ValidToken: opts.provider_validate_cookie_response, ValidToken: opts.providerValidateCookieResponse,
} }
// Now, zero-out proxy.CookieRefresh for the cases that don't involve // Now, zero-out proxy.CookieRefresh for the cases that don't involve
// access_token validation. // access_token validation.
pc_test.proxy.CookieRefresh = time.Duration(0) pcTest.proxy.CookieRefresh = time.Duration(0)
pc_test.rw = httptest.NewRecorder() pcTest.rw = httptest.NewRecorder()
pc_test.req, _ = http.NewRequest("GET", "/", strings.NewReader("")) pcTest.req, _ = http.NewRequest("GET", "/", strings.NewReader(""))
pc_test.validate_user = true pcTest.validateUser = true
return &pc_test return &pcTest
} }
func NewProcessCookieTestWithDefaults() *ProcessCookieTest { func NewProcessCookieTestWithDefaults() *ProcessCookieTest {
return NewProcessCookieTest(ProcessCookieTestOpts{ return NewProcessCookieTest(ProcessCookieTestOpts{
provider_validate_cookie_response: true, providerValidateCookieResponse: true,
}) })
} }
@ -522,12 +520,12 @@ func (p *ProcessCookieTest) LoadCookiedSession() (*providers.SessionState, time.
} }
func TestLoadCookiedSession(t *testing.T) { func TestLoadCookiedSession(t *testing.T) {
pc_test := NewProcessCookieTestWithDefaults() pcTest := NewProcessCookieTestWithDefaults()
startSession := &providers.SessionState{Email: "michael.bland@gsa.gov", AccessToken: "my_access_token"} startSession := &providers.SessionState{Email: "michael.bland@gsa.gov", AccessToken: "my_access_token"}
pc_test.SaveSession(startSession, time.Now()) pcTest.SaveSession(startSession, time.Now())
session, _, err := pc_test.LoadCookiedSession() session, _, err := pcTest.LoadCookiedSession()
assert.Equal(t, nil, err) assert.Equal(t, nil, err)
assert.Equal(t, startSession.Email, session.Email) assert.Equal(t, startSession.Email, session.Email)
assert.Equal(t, "michael.bland", session.User) assert.Equal(t, "michael.bland", session.User)
@ -535,9 +533,9 @@ func TestLoadCookiedSession(t *testing.T) {
} }
func TestProcessCookieNoCookieError(t *testing.T) { func TestProcessCookieNoCookieError(t *testing.T) {
pc_test := NewProcessCookieTestWithDefaults() pcTest := NewProcessCookieTestWithDefaults()
session, _, err := pc_test.LoadCookiedSession() session, _, err := pcTest.LoadCookiedSession()
assert.Equal(t, "Cookie \"_oauth2_proxy\" not present", err.Error()) assert.Equal(t, "Cookie \"_oauth2_proxy\" not present", err.Error())
if session != nil { if session != nil {
t.Errorf("expected nil session. got %#v", session) t.Errorf("expected nil session. got %#v", session)
@ -545,14 +543,14 @@ func TestProcessCookieNoCookieError(t *testing.T) {
} }
func TestProcessCookieRefreshNotSet(t *testing.T) { func TestProcessCookieRefreshNotSet(t *testing.T) {
pc_test := NewProcessCookieTestWithDefaults() pcTest := NewProcessCookieTestWithDefaults()
pc_test.proxy.CookieExpire = time.Duration(23) * time.Hour pcTest.proxy.CookieExpire = time.Duration(23) * time.Hour
reference := time.Now().Add(time.Duration(-2) * time.Hour) reference := time.Now().Add(time.Duration(-2) * time.Hour)
startSession := &providers.SessionState{Email: "michael.bland@gsa.gov", AccessToken: "my_access_token"} startSession := &providers.SessionState{Email: "michael.bland@gsa.gov", AccessToken: "my_access_token"}
pc_test.SaveSession(startSession, reference) pcTest.SaveSession(startSession, reference)
session, age, err := pc_test.LoadCookiedSession() session, age, err := pcTest.LoadCookiedSession()
assert.Equal(t, nil, err) assert.Equal(t, nil, err)
if age < time.Duration(-2)*time.Hour { if age < time.Duration(-2)*time.Hour {
t.Errorf("cookie too young %v", age) t.Errorf("cookie too young %v", age)
@ -561,13 +559,13 @@ func TestProcessCookieRefreshNotSet(t *testing.T) {
} }
func TestProcessCookieFailIfCookieExpired(t *testing.T) { func TestProcessCookieFailIfCookieExpired(t *testing.T) {
pc_test := NewProcessCookieTestWithDefaults() pcTest := NewProcessCookieTestWithDefaults()
pc_test.proxy.CookieExpire = time.Duration(24) * time.Hour pcTest.proxy.CookieExpire = time.Duration(24) * time.Hour
reference := time.Now().Add(time.Duration(25) * time.Hour * -1) reference := time.Now().Add(time.Duration(25) * time.Hour * -1)
startSession := &providers.SessionState{Email: "michael.bland@gsa.gov", AccessToken: "my_access_token"} startSession := &providers.SessionState{Email: "michael.bland@gsa.gov", AccessToken: "my_access_token"}
pc_test.SaveSession(startSession, reference) pcTest.SaveSession(startSession, reference)
session, _, err := pc_test.LoadCookiedSession() session, _, err := pcTest.LoadCookiedSession()
assert.NotEqual(t, nil, err) assert.NotEqual(t, nil, err)
if session != nil { if session != nil {
t.Errorf("expected nil session %#v", session) t.Errorf("expected nil session %#v", session)
@ -575,14 +573,14 @@ func TestProcessCookieFailIfCookieExpired(t *testing.T) {
} }
func TestProcessCookieFailIfRefreshSetAndCookieExpired(t *testing.T) { func TestProcessCookieFailIfRefreshSetAndCookieExpired(t *testing.T) {
pc_test := NewProcessCookieTestWithDefaults() pcTest := NewProcessCookieTestWithDefaults()
pc_test.proxy.CookieExpire = time.Duration(24) * time.Hour pcTest.proxy.CookieExpire = time.Duration(24) * time.Hour
reference := time.Now().Add(time.Duration(25) * time.Hour * -1) reference := time.Now().Add(time.Duration(25) * time.Hour * -1)
startSession := &providers.SessionState{Email: "michael.bland@gsa.gov", AccessToken: "my_access_token"} startSession := &providers.SessionState{Email: "michael.bland@gsa.gov", AccessToken: "my_access_token"}
pc_test.SaveSession(startSession, reference) pcTest.SaveSession(startSession, reference)
pc_test.proxy.CookieRefresh = time.Hour pcTest.proxy.CookieRefresh = time.Hour
session, _, err := pc_test.LoadCookiedSession() session, _, err := pcTest.LoadCookiedSession()
assert.NotEqual(t, nil, err) assert.NotEqual(t, nil, err)
if session != nil { if session != nil {
t.Errorf("expected nil session %#v", session) t.Errorf("expected nil session %#v", session)
@ -590,10 +588,10 @@ func TestProcessCookieFailIfRefreshSetAndCookieExpired(t *testing.T) {
} }
func NewAuthOnlyEndpointTest() *ProcessCookieTest { func NewAuthOnlyEndpointTest() *ProcessCookieTest {
pc_test := NewProcessCookieTestWithDefaults() pcTest := NewProcessCookieTestWithDefaults()
pc_test.req, _ = http.NewRequest("GET", pcTest.req, _ = http.NewRequest("GET",
pc_test.opts.ProxyPrefix+"/auth", nil) pcTest.opts.ProxyPrefix+"/auth", nil)
return pc_test return pcTest
} }
func TestAuthOnlyEndpointAccepted(t *testing.T) { func TestAuthOnlyEndpointAccepted(t *testing.T) {
@ -636,7 +634,7 @@ func TestAuthOnlyEndpointUnauthorizedOnEmailValidationFailure(t *testing.T) {
startSession := &providers.SessionState{ startSession := &providers.SessionState{
Email: "michael.bland@gsa.gov", AccessToken: "my_access_token"} Email: "michael.bland@gsa.gov", AccessToken: "my_access_token"}
test.SaveSession(startSession, time.Now()) test.SaveSession(startSession, time.Now())
test.validate_user = false test.validateUser = false
test.proxy.ServeHTTP(test.rw, test.req) test.proxy.ServeHTTP(test.rw, test.req)
assert.Equal(t, http.StatusUnauthorized, test.rw.Code) assert.Equal(t, http.StatusUnauthorized, test.rw.Code)
@ -645,33 +643,33 @@ func TestAuthOnlyEndpointUnauthorizedOnEmailValidationFailure(t *testing.T) {
} }
func TestAuthOnlyEndpointSetXAuthRequestHeaders(t *testing.T) { func TestAuthOnlyEndpointSetXAuthRequestHeaders(t *testing.T) {
var pc_test ProcessCookieTest var pcTest ProcessCookieTest
pc_test.opts = NewOptions() pcTest.opts = NewOptions()
pc_test.opts.SetXAuthRequest = true pcTest.opts.SetXAuthRequest = true
pc_test.opts.Validate() pcTest.opts.Validate()
pc_test.proxy = NewOAuthProxy(pc_test.opts, func(email string) bool { pcTest.proxy = NewOAuthProxy(pcTest.opts, func(email string) bool {
return pc_test.validate_user return pcTest.validateUser
}) })
pc_test.proxy.provider = &TestProvider{ pcTest.proxy.provider = &TestProvider{
ValidToken: true, ValidToken: true,
} }
pc_test.validate_user = true pcTest.validateUser = true
pc_test.rw = httptest.NewRecorder() pcTest.rw = httptest.NewRecorder()
pc_test.req, _ = http.NewRequest("GET", pcTest.req, _ = http.NewRequest("GET",
pc_test.opts.ProxyPrefix+"/auth", nil) pcTest.opts.ProxyPrefix+"/auth", nil)
startSession := &providers.SessionState{ startSession := &providers.SessionState{
User: "oauth_user", Email: "oauth_user@example.com", AccessToken: "oauth_token"} User: "oauth_user", Email: "oauth_user@example.com", AccessToken: "oauth_token"}
pc_test.SaveSession(startSession, time.Now()) pcTest.SaveSession(startSession, time.Now())
pc_test.proxy.ServeHTTP(pc_test.rw, pc_test.req) pcTest.proxy.ServeHTTP(pcTest.rw, pcTest.req)
assert.Equal(t, http.StatusAccepted, pc_test.rw.Code) assert.Equal(t, http.StatusAccepted, pcTest.rw.Code)
assert.Equal(t, "oauth_user", pc_test.rw.HeaderMap["X-Auth-Request-User"][0]) assert.Equal(t, "oauth_user", pcTest.rw.HeaderMap["X-Auth-Request-User"][0])
assert.Equal(t, "oauth_user@example.com", pc_test.rw.HeaderMap["X-Auth-Request-Email"][0]) assert.Equal(t, "oauth_user@example.com", pcTest.rw.HeaderMap["X-Auth-Request-Email"][0])
} }
func TestAuthSkippedForPreflightRequests(t *testing.T) { func TestAuthSkippedForPreflightRequests(t *testing.T) {
@ -689,8 +687,8 @@ func TestAuthSkippedForPreflightRequests(t *testing.T) {
opts.SkipAuthPreflight = true opts.SkipAuthPreflight = true
opts.Validate() opts.Validate()
upstream_url, _ := url.Parse(upstream.URL) upstreamURL, _ := url.Parse(upstream.URL)
opts.provider = NewTestProvider(upstream_url, "") opts.provider = NewTestProvider(upstreamURL, "")
proxy := NewOAuthProxy(opts, func(string) bool { return false }) proxy := NewOAuthProxy(opts, func(string) bool { return false })
rw := httptest.NewRecorder() rw := httptest.NewRecorder()
@ -723,7 +721,7 @@ func (v *SignatureAuthenticator) Authenticate(w http.ResponseWriter, r *http.Req
type SignatureTest struct { type SignatureTest struct {
opts *Options opts *Options
upstream *httptest.Server upstream *httptest.Server
upstream_host string upstreamHost string
provider *httptest.Server provider *httptest.Server
header http.Header header http.Header
rw *httptest.ResponseRecorder rw *httptest.ResponseRecorder
@ -740,20 +738,20 @@ func NewSignatureTest() *SignatureTest {
authenticator := &SignatureAuthenticator{} authenticator := &SignatureAuthenticator{}
upstream := httptest.NewServer( upstream := httptest.NewServer(
http.HandlerFunc(authenticator.Authenticate)) http.HandlerFunc(authenticator.Authenticate))
upstream_url, _ := url.Parse(upstream.URL) upstreamURL, _ := url.Parse(upstream.URL)
opts.Upstreams = append(opts.Upstreams, upstream.URL) opts.Upstreams = append(opts.Upstreams, upstream.URL)
providerHandler := func(w http.ResponseWriter, r *http.Request) { providerHandler := func(w http.ResponseWriter, r *http.Request) {
w.Write([]byte(`{"access_token": "my_auth_token"}`)) w.Write([]byte(`{"access_token": "my_auth_token"}`))
} }
provider := httptest.NewServer(http.HandlerFunc(providerHandler)) provider := httptest.NewServer(http.HandlerFunc(providerHandler))
provider_url, _ := url.Parse(provider.URL) providerURL, _ := url.Parse(provider.URL)
opts.provider = NewTestProvider(provider_url, "mbland@acm.org") opts.provider = NewTestProvider(providerURL, "mbland@acm.org")
return &SignatureTest{ return &SignatureTest{
opts, opts,
upstream, upstream,
upstream_url.Host, upstreamURL.Host,
provider, provider,
make(http.Header), make(http.Header),
httptest.NewRecorder(), httptest.NewRecorder(),

View File

@ -13,16 +13,17 @@ import (
"strings" "strings"
"time" "time"
"github.com/bitly/oauth2_proxy/providers"
oidc "github.com/coreos/go-oidc" oidc "github.com/coreos/go-oidc"
"github.com/mbland/hmacauth" "github.com/mbland/hmacauth"
"github.com/pusher/oauth2_proxy/providers"
) )
// Configuration Options that can be set by Command Line Flag, or Config File // Options holds Configuration Options that can be set by Command Line Flag,
// or Config File
type Options struct { type Options struct {
ProxyPrefix string `flag:"proxy-prefix" cfg:"proxy-prefix"` ProxyPrefix string `flag:"proxy-prefix" cfg:"proxy-prefix"`
HttpAddress string `flag:"http-address" cfg:"http_address"` HTTPAddress string `flag:"http-address" cfg:"http_address"`
HttpsAddress string `flag:"https-address" cfg:"https_address"` HTTPSAddress string `flag:"https-address" cfg:"https_address"`
RedirectURL string `flag:"redirect-url" cfg:"redirect_url"` RedirectURL string `flag:"redirect-url" cfg:"redirect_url"`
ClientID string `flag:"client-id" cfg:"client_id" env:"OAUTH2_PROXY_CLIENT_ID"` ClientID string `flag:"client-id" cfg:"client_id" env:"OAUTH2_PROXY_CLIENT_ID"`
ClientSecret string `flag:"client-secret" cfg:"client_secret" env:"OAUTH2_PROXY_CLIENT_SECRET"` ClientSecret string `flag:"client-secret" cfg:"client_secret" env:"OAUTH2_PROXY_CLIENT_SECRET"`
@ -48,7 +49,7 @@ type Options struct {
CookieExpire time.Duration `flag:"cookie-expire" cfg:"cookie_expire" env:"OAUTH2_PROXY_COOKIE_EXPIRE"` CookieExpire time.Duration `flag:"cookie-expire" cfg:"cookie_expire" env:"OAUTH2_PROXY_COOKIE_EXPIRE"`
CookieRefresh time.Duration `flag:"cookie-refresh" cfg:"cookie_refresh" env:"OAUTH2_PROXY_COOKIE_REFRESH"` CookieRefresh time.Duration `flag:"cookie-refresh" cfg:"cookie_refresh" env:"OAUTH2_PROXY_COOKIE_REFRESH"`
CookieSecure bool `flag:"cookie-secure" cfg:"cookie_secure"` CookieSecure bool `flag:"cookie-secure" cfg:"cookie_secure"`
CookieHttpOnly bool `flag:"cookie-httponly" cfg:"cookie_httponly"` CookieHTTPOnly bool `flag:"cookie-httponly" cfg:"cookie_httponly"`
Upstreams []string `flag:"upstream" cfg:"upstreams"` Upstreams []string `flag:"upstream" cfg:"upstreams"`
SkipAuthRegex []string `flag:"skip-auth-regex" cfg:"skip_auth_regex"` SkipAuthRegex []string `flag:"skip-auth-regex" cfg:"skip_auth_regex"`
@ -88,20 +89,22 @@ type Options struct {
oidcVerifier *oidc.IDTokenVerifier oidcVerifier *oidc.IDTokenVerifier
} }
// SignatureData holds hmacauth signature hash and key
type SignatureData struct { type SignatureData struct {
hash crypto.Hash hash crypto.Hash
key string key string
} }
// NewOptions constructs a new Options with defaulted values
func NewOptions() *Options { func NewOptions() *Options {
return &Options{ return &Options{
ProxyPrefix: "/oauth2", ProxyPrefix: "/oauth2",
HttpAddress: "127.0.0.1:4180", HTTPAddress: "127.0.0.1:4180",
HttpsAddress: ":443", HTTPSAddress: ":443",
DisplayHtpasswdForm: true, DisplayHtpasswdForm: true,
CookieName: "_oauth2_proxy", CookieName: "_oauth2_proxy",
CookieSecure: true, CookieSecure: true,
CookieHttpOnly: true, CookieHTTPOnly: true,
CookieExpire: time.Duration(168) * time.Hour, CookieExpire: time.Duration(168) * time.Hour,
CookieRefresh: time.Duration(0), CookieRefresh: time.Duration(0),
SetXAuthRequest: false, SetXAuthRequest: false,
@ -116,15 +119,17 @@ func NewOptions() *Options {
} }
} }
func parseURL(to_parse string, urltype string, msgs []string) (*url.URL, []string) { func parseURL(toParse string, urltype string, msgs []string) (*url.URL, []string) {
parsed, err := url.Parse(to_parse) parsed, err := url.Parse(toParse)
if err != nil { if err != nil {
return nil, append(msgs, fmt.Sprintf( return nil, append(msgs, fmt.Sprintf(
"error parsing %s-url=%q %s", urltype, to_parse, err)) "error parsing %s-url=%q %s", urltype, toParse, err))
} }
return parsed, msgs return parsed, msgs
} }
// Validate checks that required options are set and validates those that they
// are of the correct format
func (o *Options) Validate() error { func (o *Options) Validate() error {
if o.SSLInsecureSkipVerify { if o.SSLInsecureSkipVerify {
// TODO: Accept a certificate bundle. // TODO: Accept a certificate bundle.
@ -190,17 +195,17 @@ func (o *Options) Validate() error {
msgs = parseProviderInfo(o, msgs) msgs = parseProviderInfo(o, msgs)
if o.PassAccessToken || (o.CookieRefresh != time.Duration(0)) { if o.PassAccessToken || (o.CookieRefresh != time.Duration(0)) {
valid_cookie_secret_size := false validCookieSecretSize := false
for _, i := range []int{16, 24, 32} { for _, i := range []int{16, 24, 32} {
if len(secretBytes(o.CookieSecret)) == i { if len(secretBytes(o.CookieSecret)) == i {
valid_cookie_secret_size = true validCookieSecretSize = true
} }
} }
var decoded bool var decoded bool
if string(secretBytes(o.CookieSecret)) != o.CookieSecret { if string(secretBytes(o.CookieSecret)) != o.CookieSecret {
decoded = true decoded = true
} }
if valid_cookie_secret_size == false { if validCookieSecretSize == false {
var suffix string var suffix string
if decoded { if decoded {
suffix = fmt.Sprintf(" note: cookie secret was base64 decoded from %q", o.CookieSecret) suffix = fmt.Sprintf(" note: cookie secret was base64 decoded from %q", o.CookieSecret)
@ -294,12 +299,13 @@ func parseSignatureKey(o *Options, msgs []string) []string {
} }
algorithm, secretKey := components[0], components[1] algorithm, secretKey := components[0], components[1]
if hash, err := hmacauth.DigestNameToCryptoHash(algorithm); err != nil { var hash crypto.Hash
var err error
if hash, err = hmacauth.DigestNameToCryptoHash(algorithm); err != nil {
return append(msgs, "unsupported signature hash algorithm: "+ return append(msgs, "unsupported signature hash algorithm: "+
o.SignatureKey) o.SignatureKey)
} else {
o.signatureData = &SignatureData{hash, secretKey}
} }
o.signatureData = &SignatureData{hash, secretKey}
return msgs return msgs
} }

View File

@ -88,9 +88,9 @@ func TestProxyURLs(t *testing.T) {
o.Upstreams = append(o.Upstreams, "http://127.0.0.1:8081") o.Upstreams = append(o.Upstreams, "http://127.0.0.1:8081")
assert.Equal(t, nil, o.Validate()) assert.Equal(t, nil, o.Validate())
expected := []*url.URL{ expected := []*url.URL{
&url.URL{Scheme: "http", Host: "127.0.0.1:8080", Path: "/"}, {Scheme: "http", Host: "127.0.0.1:8080", Path: "/"},
// note the '/' was added // note the '/' was added
&url.URL{Scheme: "http", Host: "127.0.0.1:8081", Path: "/"}, {Scheme: "http", Host: "127.0.0.1:8081", Path: "/"},
} }
assert.Equal(t, expected, o.proxyURLs) assert.Equal(t, expected, o.proxyURLs)
} }

View File

@ -3,18 +3,21 @@ package providers
import ( import (
"errors" "errors"
"fmt" "fmt"
"github.com/bitly/go-simplejson"
"github.com/bitly/oauth2_proxy/api"
"log" "log"
"net/http" "net/http"
"net/url" "net/url"
"github.com/bitly/go-simplejson"
"github.com/pusher/oauth2_proxy/api"
) )
// AzureProvider represents an Azure based Identity Provider
type AzureProvider struct { type AzureProvider struct {
*ProviderData *ProviderData
Tenant string Tenant string
} }
// NewAzureProvider initiates a new AzureProvider
func NewAzureProvider(p *ProviderData) *AzureProvider { func NewAzureProvider(p *ProviderData) *AzureProvider {
p.ProviderName = "Azure" p.ProviderName = "Azure"
@ -39,6 +42,7 @@ func NewAzureProvider(p *ProviderData) *AzureProvider {
return &AzureProvider{ProviderData: p} return &AzureProvider{ProviderData: p}
} }
// Configure defaults the AzureProvider configuration options
func (p *AzureProvider) Configure(tenant string) { func (p *AzureProvider) Configure(tenant string) {
p.Tenant = tenant p.Tenant = tenant
if tenant == "" { if tenant == "" {
@ -60,9 +64,9 @@ func (p *AzureProvider) Configure(tenant string) {
} }
} }
func getAzureHeader(access_token string) http.Header { func getAzureHeader(accessToken string) http.Header {
header := make(http.Header) header := make(http.Header)
header.Set("Authorization", fmt.Sprintf("Bearer %s", access_token)) header.Set("Authorization", fmt.Sprintf("Bearer %s", accessToken))
return header return header
} }
@ -83,6 +87,7 @@ func getEmailFromJSON(json *simplejson.Json) (string, error) {
return email, err return email, err
} }
// GetEmailAddress returns the Account email address
func (p *AzureProvider) GetEmailAddress(s *SessionState) (string, error) { func (p *AzureProvider) GetEmailAddress(s *SessionState) (string, error) {
var email string var email string
var err error var err error

View File

@ -110,8 +110,7 @@ func testAzureBackend(payload string) *httptest.Server {
return httptest.NewServer(http.HandlerFunc( return httptest.NewServer(http.HandlerFunc(
func(w http.ResponseWriter, r *http.Request) { func(w http.ResponseWriter, r *http.Request) {
url := r.URL if r.URL.Path != path || r.URL.RawQuery != query {
if url.Path != path || url.RawQuery != query {
w.WriteHeader(404) w.WriteHeader(404)
} else if r.Header.Get("Authorization") != "Bearer imaginary_access_token" { } else if r.Header.Get("Authorization") != "Bearer imaginary_access_token" {
w.WriteHeader(403) w.WriteHeader(403)

View File

@ -6,13 +6,15 @@ import (
"net/http" "net/http"
"net/url" "net/url"
"github.com/bitly/oauth2_proxy/api" "github.com/pusher/oauth2_proxy/api"
) )
// FacebookProvider represents an Facebook based Identity Provider
type FacebookProvider struct { type FacebookProvider struct {
*ProviderData *ProviderData
} }
// NewFacebookProvider initiates a new FacebookProvider
func NewFacebookProvider(p *ProviderData) *FacebookProvider { func NewFacebookProvider(p *ProviderData) *FacebookProvider {
p.ProviderName = "Facebook" p.ProviderName = "Facebook"
if p.LoginURL.String() == "" { if p.LoginURL.String() == "" {
@ -43,14 +45,15 @@ func NewFacebookProvider(p *ProviderData) *FacebookProvider {
return &FacebookProvider{ProviderData: p} return &FacebookProvider{ProviderData: p}
} }
func getFacebookHeader(access_token string) http.Header { func getFacebookHeader(accessToken string) http.Header {
header := make(http.Header) header := make(http.Header)
header.Set("Accept", "application/json") header.Set("Accept", "application/json")
header.Set("x-li-format", "json") header.Set("x-li-format", "json")
header.Set("Authorization", fmt.Sprintf("Bearer %s", access_token)) header.Set("Authorization", fmt.Sprintf("Bearer %s", accessToken))
return header return header
} }
// GetEmailAddress returns the Account email address
func (p *FacebookProvider) GetEmailAddress(s *SessionState) (string, error) { func (p *FacebookProvider) GetEmailAddress(s *SessionState) (string, error) {
if s.AccessToken == "" { if s.AccessToken == "" {
return "", errors.New("missing access token") return "", errors.New("missing access token")
@ -65,7 +68,7 @@ func (p *FacebookProvider) GetEmailAddress(s *SessionState) (string, error) {
Email string Email string
} }
var r result var r result
err = api.RequestJson(req, &r) err = api.RequestJSON(req, &r)
if err != nil { if err != nil {
return "", err return "", err
} }
@ -75,6 +78,7 @@ func (p *FacebookProvider) GetEmailAddress(s *SessionState) (string, error) {
return r.Email, nil return r.Email, nil
} }
// ValidateSessionState validates the AccessToken
func (p *FacebookProvider) ValidateSessionState(s *SessionState) bool { func (p *FacebookProvider) ValidateSessionState(s *SessionState) bool {
return validateToken(p, s.AccessToken, getFacebookHeader(s.AccessToken)) return validateToken(p, s.AccessToken, getFacebookHeader(s.AccessToken))
} }

View File

@ -12,12 +12,14 @@ import (
"strings" "strings"
) )
// GitHubProvider represents an GitHub based Identity Provider
type GitHubProvider struct { type GitHubProvider struct {
*ProviderData *ProviderData
Org string Org string
Team string Team string
} }
// NewGitHubProvider initiates a new GitHubProvider
func NewGitHubProvider(p *ProviderData) *GitHubProvider { func NewGitHubProvider(p *ProviderData) *GitHubProvider {
p.ProviderName = "GitHub" p.ProviderName = "GitHub"
if p.LoginURL == nil || p.LoginURL.String() == "" { if p.LoginURL == nil || p.LoginURL.String() == "" {
@ -47,6 +49,8 @@ func NewGitHubProvider(p *ProviderData) *GitHubProvider {
} }
return &GitHubProvider{ProviderData: p} return &GitHubProvider{ProviderData: p}
} }
// SetOrgTeam adds GitHub org reading parameters to the OAuth2 scope
func (p *GitHubProvider) SetOrgTeam(org, team string) { func (p *GitHubProvider) SetOrgTeam(org, team string) {
p.Org = org p.Org = org
p.Team = team p.Team = team
@ -106,7 +110,7 @@ func (p *GitHubProvider) hasOrg(accessToken string) (bool, error) {
} }
orgs = append(orgs, op...) orgs = append(orgs, op...)
pn += 1 pn++
} }
var presentOrgs []string var presentOrgs []string
@ -186,7 +190,7 @@ func (p *GitHubProvider) hasOrgAndTeam(accessToken string) (bool, error) {
log.Printf("Missing Team:%q from Org:%q in teams: %v", p.Team, p.Org, presentTeams) log.Printf("Missing Team:%q from Org:%q in teams: %v", p.Team, p.Org, presentTeams)
} else { } else {
var allOrgs []string var allOrgs []string
for org, _ := range presentOrgs { for org := range presentOrgs {
allOrgs = append(allOrgs, org) allOrgs = append(allOrgs, org)
} }
log.Printf("Missing Organization:%q in %#v", p.Org, allOrgs) log.Printf("Missing Organization:%q in %#v", p.Org, allOrgs)
@ -194,6 +198,7 @@ func (p *GitHubProvider) hasOrgAndTeam(accessToken string) (bool, error) {
return false, nil return false, nil
} }
// GetEmailAddress returns the Account email address
func (p *GitHubProvider) GetEmailAddress(s *SessionState) (string, error) { func (p *GitHubProvider) GetEmailAddress(s *SessionState) (string, error) {
var emails []struct { var emails []struct {
@ -251,6 +256,7 @@ func (p *GitHubProvider) GetEmailAddress(s *SessionState) (string, error) {
return "", nil return "", nil
} }
// GetUserName returns the Account user name
func (p *GitHubProvider) GetUserName(s *SessionState) (string, error) { func (p *GitHubProvider) GetUserName(s *SessionState) (string, error) {
var user struct { var user struct {
Login string `json:"login"` Login string `json:"login"`

View File

@ -29,19 +29,18 @@ func testGitHubProvider(hostname string) *GitHubProvider {
func testGitHubBackend(payload []string) *httptest.Server { func testGitHubBackend(payload []string) *httptest.Server {
pathToQueryMap := map[string][]string{ pathToQueryMap := map[string][]string{
"/user": []string{""}, "/user": {""},
"/user/emails": []string{""}, "/user/emails": {""},
"/user/orgs": []string{"limit=200&page=1", "limit=200&page=2", "limit=200&page=3"}, "/user/orgs": {"limit=200&page=1", "limit=200&page=2", "limit=200&page=3"},
} }
return httptest.NewServer(http.HandlerFunc( return httptest.NewServer(http.HandlerFunc(
func(w http.ResponseWriter, r *http.Request) { func(w http.ResponseWriter, r *http.Request) {
url := r.URL query, ok := pathToQueryMap[r.URL.Path]
query, ok := pathToQueryMap[url.Path]
validQuery := false validQuery := false
index := 0 index := 0
for i, q := range query { for i, q := range query {
if q == url.RawQuery { if q == r.URL.RawQuery {
validQuery = true validQuery = true
index = i index = i
} }

View File

@ -5,13 +5,15 @@ import (
"net/http" "net/http"
"net/url" "net/url"
"github.com/bitly/oauth2_proxy/api" "github.com/pusher/oauth2_proxy/api"
) )
// GitLabProvider represents an GitLab based Identity Provider
type GitLabProvider struct { type GitLabProvider struct {
*ProviderData *ProviderData
} }
// NewGitLabProvider initiates a new GitLabProvider
func NewGitLabProvider(p *ProviderData) *GitLabProvider { func NewGitLabProvider(p *ProviderData) *GitLabProvider {
p.ProviderName = "GitLab" p.ProviderName = "GitLab"
if p.LoginURL == nil || p.LoginURL.String() == "" { if p.LoginURL == nil || p.LoginURL.String() == "" {
@ -41,6 +43,7 @@ func NewGitLabProvider(p *ProviderData) *GitLabProvider {
return &GitLabProvider{ProviderData: p} return &GitLabProvider{ProviderData: p}
} }
// GetEmailAddress returns the Account email address
func (p *GitLabProvider) GetEmailAddress(s *SessionState) (string, error) { func (p *GitLabProvider) GetEmailAddress(s *SessionState) (string, error) {
req, err := http.NewRequest("GET", req, err := http.NewRequest("GET",

View File

@ -33,8 +33,7 @@ func testGitLabBackend(payload string) *httptest.Server {
return httptest.NewServer(http.HandlerFunc( return httptest.NewServer(http.HandlerFunc(
func(w http.ResponseWriter, r *http.Request) { func(w http.ResponseWriter, r *http.Request) {
url := r.URL if r.URL.Path != path || r.URL.RawQuery != query {
if url.Path != path || url.RawQuery != query {
w.WriteHeader(404) w.WriteHeader(404)
} else { } else {
w.WriteHeader(200) w.WriteHeader(200)
@ -87,8 +86,8 @@ func TestGitLabProviderGetEmailAddress(t *testing.T) {
b := testGitLabBackend("{\"email\": \"michael.bland@gsa.gov\"}") b := testGitLabBackend("{\"email\": \"michael.bland@gsa.gov\"}")
defer b.Close() defer b.Close()
b_url, _ := url.Parse(b.URL) bURL, _ := url.Parse(b.URL)
p := testGitLabProvider(b_url.Host) p := testGitLabProvider(bURL.Host)
session := &SessionState{AccessToken: "imaginary_access_token"} session := &SessionState{AccessToken: "imaginary_access_token"}
email, err := p.GetEmailAddress(session) email, err := p.GetEmailAddress(session)
@ -102,8 +101,8 @@ func TestGitLabProviderGetEmailAddressFailedRequest(t *testing.T) {
b := testGitLabBackend("unused payload") b := testGitLabBackend("unused payload")
defer b.Close() defer b.Close()
b_url, _ := url.Parse(b.URL) bURL, _ := url.Parse(b.URL)
p := testGitLabProvider(b_url.Host) p := testGitLabProvider(bURL.Host)
// We'll trigger a request failure by using an unexpected access // We'll trigger a request failure by using an unexpected access
// token. Alternatively, we could allow the parsing of the payload as // token. Alternatively, we could allow the parsing of the payload as
@ -118,8 +117,8 @@ func TestGitLabProviderGetEmailAddressEmailNotPresentInPayload(t *testing.T) {
b := testGitLabBackend("{\"foo\": \"bar\"}") b := testGitLabBackend("{\"foo\": \"bar\"}")
defer b.Close() defer b.Close()
b_url, _ := url.Parse(b.URL) bURL, _ := url.Parse(b.URL)
p := testGitLabProvider(b_url.Host) p := testGitLabProvider(bURL.Host)
session := &SessionState{AccessToken: "imaginary_access_token"} session := &SessionState{AccessToken: "imaginary_access_token"}
email, err := p.GetEmailAddress(session) email, err := p.GetEmailAddress(session)

View File

@ -20,6 +20,7 @@ import (
"google.golang.org/api/googleapi" "google.golang.org/api/googleapi"
) )
// GoogleProvider represents an Google based Identity Provider
type GoogleProvider struct { type GoogleProvider struct {
*ProviderData *ProviderData
RedeemRefreshURL *url.URL RedeemRefreshURL *url.URL
@ -28,6 +29,7 @@ type GoogleProvider struct {
GroupValidator func(string) bool GroupValidator func(string) bool
} }
// NewGoogleProvider initiates a new GoogleProvider
func NewGoogleProvider(p *ProviderData) *GoogleProvider { func NewGoogleProvider(p *ProviderData) *GoogleProvider {
p.ProviderName = "Google" p.ProviderName = "Google"
if p.LoginURL.String() == "" { if p.LoginURL.String() == "" {
@ -62,7 +64,7 @@ func NewGoogleProvider(p *ProviderData) *GoogleProvider {
} }
} }
func emailFromIdToken(idToken string) (string, error) { func emailFromIDToken(idToken string) (string, error) {
// id_token is a base64 encode ID token payload // id_token is a base64 encode ID token payload
// https://developers.google.com/accounts/docs/OAuth2Login#obtainuserinfo // https://developers.google.com/accounts/docs/OAuth2Login#obtainuserinfo
@ -90,6 +92,7 @@ func emailFromIdToken(idToken string) (string, error) {
return email.Email, nil return email.Email, nil
} }
// Redeem exchanges the OAuth2 authentication token for an ID token
func (p *GoogleProvider) Redeem(redirectURL, code string) (s *SessionState, err error) { func (p *GoogleProvider) Redeem(redirectURL, code string) (s *SessionState, err error) {
if code == "" { if code == "" {
err = errors.New("missing code") err = errors.New("missing code")
@ -129,14 +132,14 @@ func (p *GoogleProvider) Redeem(redirectURL, code string) (s *SessionState, err
AccessToken string `json:"access_token"` AccessToken string `json:"access_token"`
RefreshToken string `json:"refresh_token"` RefreshToken string `json:"refresh_token"`
ExpiresIn int64 `json:"expires_in"` ExpiresIn int64 `json:"expires_in"`
IdToken string `json:"id_token"` IDToken string `json:"id_token"`
} }
err = json.Unmarshal(body, &jsonResponse) err = json.Unmarshal(body, &jsonResponse)
if err != nil { if err != nil {
return return
} }
var email string var email string
email, err = emailFromIdToken(jsonResponse.IdToken) email, err = emailFromIDToken(jsonResponse.IDToken)
if err != nil { if err != nil {
return return
} }
@ -249,6 +252,8 @@ func (p *GoogleProvider) ValidateGroup(email string) bool {
return p.GroupValidator(email) return p.GroupValidator(email)
} }
// RefreshSessionIfNeeded checks if the session has expired and uses the
// RefreshToken to fetch a new ID token if required
func (p *GoogleProvider) RefreshSessionIfNeeded(s *SessionState) (bool, error) { func (p *GoogleProvider) RefreshSessionIfNeeded(s *SessionState) (bool, error) {
if s == nil || s.ExpiresOn.After(time.Now()) || s.RefreshToken == "" { if s == nil || s.ExpiresOn.After(time.Now()) || s.RefreshToken == "" {
return false, nil return false, nil

View File

@ -81,7 +81,7 @@ type redeemResponse struct {
AccessToken string `json:"access_token"` AccessToken string `json:"access_token"`
RefreshToken string `json:"refresh_token"` RefreshToken string `json:"refresh_token"`
ExpiresIn int64 `json:"expires_in"` ExpiresIn int64 `json:"expires_in"`
IdToken string `json:"id_token"` IDToken string `json:"id_token"`
} }
func TestGoogleProviderGetEmailAddress(t *testing.T) { func TestGoogleProviderGetEmailAddress(t *testing.T) {
@ -90,7 +90,7 @@ func TestGoogleProviderGetEmailAddress(t *testing.T) {
AccessToken: "a1234", AccessToken: "a1234",
ExpiresIn: 10, ExpiresIn: 10,
RefreshToken: "refresh12345", RefreshToken: "refresh12345",
IdToken: "ignored prefix." + base64.URLEncoding.EncodeToString([]byte(`{"email": "michael.bland@gsa.gov", "email_verified":true}`)), IDToken: "ignored prefix." + base64.URLEncoding.EncodeToString([]byte(`{"email": "michael.bland@gsa.gov", "email_verified":true}`)),
}) })
assert.Equal(t, nil, err) assert.Equal(t, nil, err)
var server *httptest.Server var server *httptest.Server
@ -127,7 +127,7 @@ func TestGoogleProviderGetEmailAddressInvalidEncoding(t *testing.T) {
p := newGoogleProvider() p := newGoogleProvider()
body, err := json.Marshal(redeemResponse{ body, err := json.Marshal(redeemResponse{
AccessToken: "a1234", AccessToken: "a1234",
IdToken: "ignored prefix." + `{"email": "michael.bland@gsa.gov"}`, IDToken: "ignored prefix." + `{"email": "michael.bland@gsa.gov"}`,
}) })
assert.Equal(t, nil, err) assert.Equal(t, nil, err)
var server *httptest.Server var server *httptest.Server
@ -146,7 +146,7 @@ func TestGoogleProviderGetEmailAddressInvalidJson(t *testing.T) {
body, err := json.Marshal(redeemResponse{ body, err := json.Marshal(redeemResponse{
AccessToken: "a1234", AccessToken: "a1234",
IdToken: "ignored prefix." + base64.URLEncoding.EncodeToString([]byte(`{"email": michael.bland@gsa.gov}`)), IDToken: "ignored prefix." + base64.URLEncoding.EncodeToString([]byte(`{"email": michael.bland@gsa.gov}`)),
}) })
assert.Equal(t, nil, err) assert.Equal(t, nil, err)
var server *httptest.Server var server *httptest.Server
@ -165,7 +165,7 @@ func TestGoogleProviderGetEmailAddressEmailMissing(t *testing.T) {
p := newGoogleProvider() p := newGoogleProvider()
body, err := json.Marshal(redeemResponse{ body, err := json.Marshal(redeemResponse{
AccessToken: "a1234", AccessToken: "a1234",
IdToken: "ignored prefix." + base64.URLEncoding.EncodeToString([]byte(`{"not_email": "missing"}`)), IDToken: "ignored prefix." + base64.URLEncoding.EncodeToString([]byte(`{"not_email": "missing"}`)),
}) })
assert.Equal(t, nil, err) assert.Equal(t, nil, err)
var server *httptest.Server var server *httptest.Server

View File

@ -6,7 +6,7 @@ import (
"net/http" "net/http"
"net/url" "net/url"
"github.com/bitly/oauth2_proxy/api" "github.com/pusher/oauth2_proxy/api"
) )
// stripToken is a helper function to obfuscate "access_token" // stripToken is a helper function to obfuscate "access_token"
@ -46,13 +46,13 @@ func stripParam(param, endpoint string) string {
} }
// validateToken returns true if token is valid // validateToken returns true if token is valid
func validateToken(p Provider, access_token string, header http.Header) bool { func validateToken(p Provider, accessToken string, header http.Header) bool {
if access_token == "" || p.Data().ValidateURL == nil { if accessToken == "" || p.Data().ValidateURL == nil {
return false return false
} }
endpoint := p.Data().ValidateURL.String() endpoint := p.Data().ValidateURL.String()
if len(header) == 0 { if len(header) == 0 {
params := url.Values{"access_token": {access_token}} params := url.Values{"access_token": {accessToken}}
endpoint = endpoint + "?" + params.Encode() endpoint = endpoint + "?" + params.Encode()
} }
resp, err := api.RequestUnparsedResponse(endpoint, header) resp, err := api.RequestUnparsedResponse(endpoint, header)
@ -72,8 +72,3 @@ func validateToken(p Provider, access_token string, header http.Header) bool {
log.Printf("token validation request failed: status %d - %s", resp.StatusCode, body) log.Printf("token validation request failed: status %d - %s", resp.StatusCode, body)
return false return false
} }
func updateURL(url *url.URL, hostname string) {
url.Scheme = "http"
url.Host = hostname
}

View File

@ -10,6 +10,11 @@ import (
"github.com/stretchr/testify/assert" "github.com/stretchr/testify/assert"
) )
func updateURL(url *url.URL, hostname string) {
url.Scheme = "http"
url.Host = hostname
}
type ValidateSessionStateTestProvider struct { type ValidateSessionStateTestProvider struct {
*ProviderData *ProviderData
} }
@ -26,27 +31,27 @@ func (tp *ValidateSessionStateTestProvider) ValidateSessionState(s *SessionState
type ValidateSessionStateTest struct { type ValidateSessionStateTest struct {
backend *httptest.Server backend *httptest.Server
response_code int responseCode int
provider *ValidateSessionStateTestProvider provider *ValidateSessionStateTestProvider
header http.Header header http.Header
} }
func NewValidateSessionStateTest() *ValidateSessionStateTest { func NewValidateSessionStateTest() *ValidateSessionStateTest {
var vt_test ValidateSessionStateTest var vtTest ValidateSessionStateTest
vt_test.backend = httptest.NewServer( vtTest.backend = httptest.NewServer(
http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
if r.URL.Path != "/oauth/tokeninfo" { if r.URL.Path != "/oauth/tokeninfo" {
w.WriteHeader(500) w.WriteHeader(500)
w.Write([]byte("unknown URL")) w.Write([]byte("unknown URL"))
} }
token_param := r.FormValue("access_token") tokenParam := r.FormValue("access_token")
if token_param == "" { if tokenParam == "" {
missing := false missing := false
received_headers := r.Header receivedHeaders := r.Header
for k, _ := range vt_test.header { for k := range vtTest.header {
received := received_headers.Get(k) received := receivedHeaders.Get(k)
expected := vt_test.header.Get(k) expected := vtTest.header.Get(k)
if received == "" || received != expected { if received == "" || received != expected {
missing = true missing = true
} }
@ -56,68 +61,68 @@ func NewValidateSessionStateTest() *ValidateSessionStateTest {
w.Write([]byte("no token param and missing or incorrect headers")) w.Write([]byte("no token param and missing or incorrect headers"))
} }
} }
w.WriteHeader(vt_test.response_code) w.WriteHeader(vtTest.responseCode)
w.Write([]byte("only code matters; contents disregarded")) w.Write([]byte("only code matters; contents disregarded"))
})) }))
backend_url, _ := url.Parse(vt_test.backend.URL) backendURL, _ := url.Parse(vtTest.backend.URL)
vt_test.provider = &ValidateSessionStateTestProvider{ vtTest.provider = &ValidateSessionStateTestProvider{
ProviderData: &ProviderData{ ProviderData: &ProviderData{
ValidateURL: &url.URL{ ValidateURL: &url.URL{
Scheme: "http", Scheme: "http",
Host: backend_url.Host, Host: backendURL.Host,
Path: "/oauth/tokeninfo", Path: "/oauth/tokeninfo",
}, },
}, },
} }
vt_test.response_code = 200 vtTest.responseCode = 200
return &vt_test return &vtTest
} }
func (vt_test *ValidateSessionStateTest) Close() { func (vtTest *ValidateSessionStateTest) Close() {
vt_test.backend.Close() vtTest.backend.Close()
} }
func TestValidateSessionStateValidToken(t *testing.T) { func TestValidateSessionStateValidToken(t *testing.T) {
vt_test := NewValidateSessionStateTest() vtTest := NewValidateSessionStateTest()
defer vt_test.Close() defer vtTest.Close()
assert.Equal(t, true, validateToken(vt_test.provider, "foobar", nil)) assert.Equal(t, true, validateToken(vtTest.provider, "foobar", nil))
} }
func TestValidateSessionStateValidTokenWithHeaders(t *testing.T) { func TestValidateSessionStateValidTokenWithHeaders(t *testing.T) {
vt_test := NewValidateSessionStateTest() vtTest := NewValidateSessionStateTest()
defer vt_test.Close() defer vtTest.Close()
vt_test.header = make(http.Header) vtTest.header = make(http.Header)
vt_test.header.Set("Authorization", "Bearer foobar") vtTest.header.Set("Authorization", "Bearer foobar")
assert.Equal(t, true, assert.Equal(t, true,
validateToken(vt_test.provider, "foobar", vt_test.header)) validateToken(vtTest.provider, "foobar", vtTest.header))
} }
func TestValidateSessionStateEmptyToken(t *testing.T) { func TestValidateSessionStateEmptyToken(t *testing.T) {
vt_test := NewValidateSessionStateTest() vtTest := NewValidateSessionStateTest()
defer vt_test.Close() defer vtTest.Close()
assert.Equal(t, false, validateToken(vt_test.provider, "", nil)) assert.Equal(t, false, validateToken(vtTest.provider, "", nil))
} }
func TestValidateSessionStateEmptyValidateURL(t *testing.T) { func TestValidateSessionStateEmptyValidateURL(t *testing.T) {
vt_test := NewValidateSessionStateTest() vtTest := NewValidateSessionStateTest()
defer vt_test.Close() defer vtTest.Close()
vt_test.provider.Data().ValidateURL = nil vtTest.provider.Data().ValidateURL = nil
assert.Equal(t, false, validateToken(vt_test.provider, "foobar", nil)) assert.Equal(t, false, validateToken(vtTest.provider, "foobar", nil))
} }
func TestValidateSessionStateRequestNetworkFailure(t *testing.T) { func TestValidateSessionStateRequestNetworkFailure(t *testing.T) {
vt_test := NewValidateSessionStateTest() vtTest := NewValidateSessionStateTest()
// Close immediately to simulate a network failure // Close immediately to simulate a network failure
vt_test.Close() vtTest.Close()
assert.Equal(t, false, validateToken(vt_test.provider, "foobar", nil)) assert.Equal(t, false, validateToken(vtTest.provider, "foobar", nil))
} }
func TestValidateSessionStateExpiredToken(t *testing.T) { func TestValidateSessionStateExpiredToken(t *testing.T) {
vt_test := NewValidateSessionStateTest() vtTest := NewValidateSessionStateTest()
defer vt_test.Close() defer vtTest.Close()
vt_test.response_code = 401 vtTest.responseCode = 401
assert.Equal(t, false, validateToken(vt_test.provider, "foobar", nil)) assert.Equal(t, false, validateToken(vtTest.provider, "foobar", nil))
} }
func TestStripTokenNotPresent(t *testing.T) { func TestStripTokenNotPresent(t *testing.T) {

View File

@ -6,13 +6,15 @@ import (
"net/http" "net/http"
"net/url" "net/url"
"github.com/bitly/oauth2_proxy/api" "github.com/pusher/oauth2_proxy/api"
) )
// LinkedInProvider represents an LinkedIn based Identity Provider
type LinkedInProvider struct { type LinkedInProvider struct {
*ProviderData *ProviderData
} }
// NewLinkedInProvider initiates a new LinkedInProvider
func NewLinkedInProvider(p *ProviderData) *LinkedInProvider { func NewLinkedInProvider(p *ProviderData) *LinkedInProvider {
p.ProviderName = "LinkedIn" p.ProviderName = "LinkedIn"
if p.LoginURL.String() == "" { if p.LoginURL.String() == "" {
@ -39,14 +41,15 @@ func NewLinkedInProvider(p *ProviderData) *LinkedInProvider {
return &LinkedInProvider{ProviderData: p} return &LinkedInProvider{ProviderData: p}
} }
func getLinkedInHeader(access_token string) http.Header { func getLinkedInHeader(accessToken string) http.Header {
header := make(http.Header) header := make(http.Header)
header.Set("Accept", "application/json") header.Set("Accept", "application/json")
header.Set("x-li-format", "json") header.Set("x-li-format", "json")
header.Set("Authorization", fmt.Sprintf("Bearer %s", access_token)) header.Set("Authorization", fmt.Sprintf("Bearer %s", accessToken))
return header return header
} }
// GetEmailAddress returns the Account email address
func (p *LinkedInProvider) GetEmailAddress(s *SessionState) (string, error) { func (p *LinkedInProvider) GetEmailAddress(s *SessionState) (string, error) {
if s.AccessToken == "" { if s.AccessToken == "" {
return "", errors.New("missing access token") return "", errors.New("missing access token")
@ -69,6 +72,7 @@ func (p *LinkedInProvider) GetEmailAddress(s *SessionState) (string, error) {
return email, nil return email, nil
} }
// ValidateSessionState validates the AccessToken
func (p *LinkedInProvider) ValidateSessionState(s *SessionState) bool { func (p *LinkedInProvider) ValidateSessionState(s *SessionState) bool {
return validateToken(p, s.AccessToken, getLinkedInHeader(s.AccessToken)) return validateToken(p, s.AccessToken, getLinkedInHeader(s.AccessToken))
} }

View File

@ -31,8 +31,7 @@ func testLinkedInBackend(payload string) *httptest.Server {
return httptest.NewServer(http.HandlerFunc( return httptest.NewServer(http.HandlerFunc(
func(w http.ResponseWriter, r *http.Request) { func(w http.ResponseWriter, r *http.Request) {
url := r.URL if r.URL.Path != path {
if url.Path != path {
w.WriteHeader(404) w.WriteHeader(404)
} else if r.Header.Get("Authorization") != "Bearer imaginary_access_token" { } else if r.Header.Get("Authorization") != "Bearer imaginary_access_token" {
w.WriteHeader(403) w.WriteHeader(403)
@ -95,8 +94,8 @@ func TestLinkedInProviderGetEmailAddress(t *testing.T) {
b := testLinkedInBackend(`"user@linkedin.com"`) b := testLinkedInBackend(`"user@linkedin.com"`)
defer b.Close() defer b.Close()
b_url, _ := url.Parse(b.URL) bURL, _ := url.Parse(b.URL)
p := testLinkedInProvider(b_url.Host) p := testLinkedInProvider(bURL.Host)
session := &SessionState{AccessToken: "imaginary_access_token"} session := &SessionState{AccessToken: "imaginary_access_token"}
email, err := p.GetEmailAddress(session) email, err := p.GetEmailAddress(session)
@ -108,8 +107,8 @@ func TestLinkedInProviderGetEmailAddressFailedRequest(t *testing.T) {
b := testLinkedInBackend("unused payload") b := testLinkedInBackend("unused payload")
defer b.Close() defer b.Close()
b_url, _ := url.Parse(b.URL) bURL, _ := url.Parse(b.URL)
p := testLinkedInProvider(b_url.Host) p := testLinkedInProvider(bURL.Host)
// We'll trigger a request failure by using an unexpected access // We'll trigger a request failure by using an unexpected access
// token. Alternatively, we could allow the parsing of the payload as // token. Alternatively, we could allow the parsing of the payload as
@ -124,8 +123,8 @@ func TestLinkedInProviderGetEmailAddressEmailNotPresentInPayload(t *testing.T) {
b := testLinkedInBackend("{\"foo\": \"bar\"}") b := testLinkedInBackend("{\"foo\": \"bar\"}")
defer b.Close() defer b.Close()
b_url, _ := url.Parse(b.URL) bURL, _ := url.Parse(b.URL)
p := testLinkedInProvider(b_url.Host) p := testLinkedInProvider(bURL.Host)
session := &SessionState{AccessToken: "imaginary_access_token"} session := &SessionState{AccessToken: "imaginary_access_token"}
email, err := p.GetEmailAddress(session) email, err := p.GetEmailAddress(session)

View File

@ -10,17 +10,20 @@ import (
oidc "github.com/coreos/go-oidc" oidc "github.com/coreos/go-oidc"
) )
// OIDCProvider represents an OIDC based Identity Provider
type OIDCProvider struct { type OIDCProvider struct {
*ProviderData *ProviderData
Verifier *oidc.IDTokenVerifier Verifier *oidc.IDTokenVerifier
} }
// NewOIDCProvider initiates a new OIDCProvider
func NewOIDCProvider(p *ProviderData) *OIDCProvider { func NewOIDCProvider(p *ProviderData) *OIDCProvider {
p.ProviderName = "OpenID Connect" p.ProviderName = "OpenID Connect"
return &OIDCProvider{ProviderData: p} return &OIDCProvider{ProviderData: p}
} }
// Redeem exchanges the OAuth2 authentication token for an ID token
func (p *OIDCProvider) Redeem(redirectURL, code string) (s *SessionState, err error) { func (p *OIDCProvider) Redeem(redirectURL, code string) (s *SessionState, err error) {
ctx := context.Background() ctx := context.Background()
c := oauth2.Config{ c := oauth2.Config{
@ -73,6 +76,11 @@ func (p *OIDCProvider) Redeem(redirectURL, code string) (s *SessionState, err er
return return
} }
// RefreshSessionIfNeeded checks if the session has expired and uses the
// RefreshToken to fetch a new ID token if required
//
// WARNGING: This implementation is broken and does not check with the upstream
// OIDC provider before refreshing the session
func (p *OIDCProvider) RefreshSessionIfNeeded(s *SessionState) (bool, error) { func (p *OIDCProvider) RefreshSessionIfNeeded(s *SessionState) (bool, error) {
if s == nil || s.ExpiresOn.After(time.Now()) || s.RefreshToken == "" { if s == nil || s.ExpiresOn.After(time.Now()) || s.RefreshToken == "" {
return false, nil return false, nil

View File

@ -4,6 +4,8 @@ import (
"net/url" "net/url"
) )
// ProviderData contains information required to configure all implementations
// of OAuth2 providers
type ProviderData struct { type ProviderData struct {
ProviderName string ProviderName string
ClientID string ClientID string
@ -17,4 +19,5 @@ type ProviderData struct {
ApprovalPrompt string ApprovalPrompt string
} }
// Data returns the ProviderData
func (p *ProviderData) Data() *ProviderData { return p } func (p *ProviderData) Data() *ProviderData { return p }

View File

@ -9,9 +9,10 @@ import (
"net/http" "net/http"
"net/url" "net/url"
"github.com/bitly/oauth2_proxy/cookie" "github.com/pusher/oauth2_proxy/cookie"
) )
// Redeem provides a default implementation of the OAuth2 token redemption process
func (p *ProviderData) Redeem(redirectURL, code string) (s *SessionState, err error) { func (p *ProviderData) Redeem(redirectURL, code string) (s *SessionState, err error) {
if code == "" { if code == "" {
err = errors.New("missing code") err = errors.New("missing code")
@ -102,6 +103,7 @@ func (p *ProviderData) SessionFromCookie(v string, c *cookie.Cipher) (s *Session
return DecodeSessionState(v, c) return DecodeSessionState(v, c)
} }
// GetEmailAddress returns the Account email address
func (p *ProviderData) GetEmailAddress(s *SessionState) (string, error) { func (p *ProviderData) GetEmailAddress(s *SessionState) (string, error) {
return "", errors.New("not implemented") return "", errors.New("not implemented")
} }
@ -117,11 +119,13 @@ func (p *ProviderData) ValidateGroup(email string) bool {
return true return true
} }
// ValidateSessionState validates the AccessToken
func (p *ProviderData) ValidateSessionState(s *SessionState) bool { func (p *ProviderData) ValidateSessionState(s *SessionState) bool {
return validateToken(p, s.AccessToken, nil) return validateToken(p, s.AccessToken, nil)
} }
// RefreshSessionIfNeeded // RefreshSessionIfNeeded should refresh the user's session if required and
// do nothing if a refresh is not required
func (p *ProviderData) RefreshSessionIfNeeded(s *SessionState) (bool, error) { func (p *ProviderData) RefreshSessionIfNeeded(s *SessionState) (bool, error) {
return false, nil return false, nil
} }

View File

@ -1,9 +1,10 @@
package providers package providers
import ( import (
"github.com/bitly/oauth2_proxy/cookie" "github.com/pusher/oauth2_proxy/cookie"
) )
// Provider represents an upstream identity provider implementation
type Provider interface { type Provider interface {
Data() *ProviderData Data() *ProviderData
GetEmailAddress(*SessionState) (string, error) GetEmailAddress(*SessionState) (string, error)
@ -17,6 +18,7 @@ type Provider interface {
CookieForSession(*SessionState, *cookie.Cipher) (string, error) CookieForSession(*SessionState, *cookie.Cipher) (string, error)
} }
// New provides a new Provider based on the configured provider string
func New(provider string, p *ProviderData) Provider { func New(provider string, p *ProviderData) Provider {
switch provider { switch provider {
case "linkedin": case "linkedin":

View File

@ -6,9 +6,10 @@ import (
"strings" "strings"
"time" "time"
"github.com/bitly/oauth2_proxy/cookie" "github.com/pusher/oauth2_proxy/cookie"
) )
// SessionState is used to store information about the currently authenticated user session
type SessionState struct { type SessionState struct {
AccessToken string AccessToken string
ExpiresOn time.Time ExpiresOn time.Time
@ -17,6 +18,7 @@ type SessionState struct {
User string User string
} }
// IsExpired checks whether the session has expired
func (s *SessionState) IsExpired() bool { func (s *SessionState) IsExpired() bool {
if !s.ExpiresOn.IsZero() && s.ExpiresOn.Before(time.Now()) { if !s.ExpiresOn.IsZero() && s.ExpiresOn.Before(time.Now()) {
return true return true
@ -24,6 +26,7 @@ func (s *SessionState) IsExpired() bool {
return false return false
} }
// String constructs a summary of the session state
func (s *SessionState) String() string { func (s *SessionState) String() string {
o := fmt.Sprintf("Session{%s", s.accountInfo()) o := fmt.Sprintf("Session{%s", s.accountInfo())
if s.AccessToken != "" { if s.AccessToken != "" {
@ -38,6 +41,7 @@ func (s *SessionState) String() string {
return o + "}" return o + "}"
} }
// EncodeSessionState returns string representation of the current session
func (s *SessionState) EncodeSessionState(c *cookie.Cipher) (string, error) { func (s *SessionState) EncodeSessionState(c *cookie.Cipher) (string, error) {
if c == nil || s.AccessToken == "" { if c == nil || s.AccessToken == "" {
return s.accountInfo(), nil return s.accountInfo(), nil
@ -49,6 +53,7 @@ func (s *SessionState) accountInfo() string {
return fmt.Sprintf("email:%s user:%s", s.Email, s.User) return fmt.Sprintf("email:%s user:%s", s.Email, s.User)
} }
// EncryptedString encrypts the session state into a cookie string
func (s *SessionState) EncryptedString(c *cookie.Cipher) (string, error) { func (s *SessionState) EncryptedString(c *cookie.Cipher) (string, error) {
var err error var err error
if c == nil { if c == nil {
@ -84,6 +89,7 @@ func decodeSessionStatePlain(v string) (s *SessionState, err error) {
return &SessionState{User: user, Email: email}, nil return &SessionState{User: user, Email: email}, nil
} }
// DecodeSessionState decodes the session cookie string into a SessionState
func DecodeSessionState(v string, c *cookie.Cipher) (s *SessionState, err error) { func DecodeSessionState(v string, c *cookie.Cipher) (s *SessionState, err error) {
if c == nil { if c == nil {
return decodeSessionStatePlain(v) return decodeSessionStatePlain(v)

View File

@ -6,7 +6,7 @@ import (
"testing" "testing"
"time" "time"
"github.com/bitly/oauth2_proxy/cookie" "github.com/pusher/oauth2_proxy/cookie"
"github.com/stretchr/testify/assert" "github.com/stretchr/testify/assert"
) )

View File

@ -4,13 +4,16 @@ import (
"strings" "strings"
) )
// StringArray is a type alias for a slice of strings
type StringArray []string type StringArray []string
// Set appends a string to the StringArray
func (a *StringArray) Set(s string) error { func (a *StringArray) Set(s string) error {
*a = append(*a, s) *a = append(*a, s)
return nil return nil
} }
// String joins elements of the StringArray into a single comma separated string
func (a *StringArray) String() string { func (a *StringArray) String() string {
return strings.Join(*a, ",") return strings.Join(*a, ",")
} }

View File

@ -142,7 +142,7 @@ func getTemplates() *template.Template {
<footer> <footer>
{{ if eq .Footer "-" }} {{ if eq .Footer "-" }}
{{ else if eq .Footer ""}} {{ else if eq .Footer ""}}
Secured with <a href="https://github.com/bitly/oauth2_proxy#oauth2_proxy">OAuth2 Proxy</a> version {{.Version}} Secured with <a href="https://github.com/pusher/oauth2_proxy#oauth2_proxy">OAuth2 Proxy</a> version {{.Version}}
{{ else }} {{ else }}
{{.Footer}} {{.Footer}}
{{ end }} {{ end }}

14
test.sh
View File

@ -1,14 +0,0 @@
#!/bin/bash
EXIT_CODE=0
echo "gofmt"
diff -u <(echo -n) <(gofmt -d $(find . -type f -name '*.go' -not -path "./vendor/*")) || EXIT_CODE=1
for pkg in $(go list ./... | grep -v '/vendor/' ); do
echo "testing $pkg"
echo "go vet $pkg"
go vet "$pkg" || EXIT_CODE=1
echo "go test -v $pkg"
go test -v -timeout 90s "$pkg" || EXIT_CODE=1
echo "go test -v -race $pkg"
GOMAXPROCS=4 go test -v -timeout 90s0s -race "$pkg" || EXIT_CODE=1
done
exit $EXIT_CODE

View File

@ -10,11 +10,13 @@ import (
"unsafe" "unsafe"
) )
// UserMap holds information from the authenticated emails file
type UserMap struct { type UserMap struct {
usersFile string usersFile string
m unsafe.Pointer m unsafe.Pointer
} }
// NewUserMap parses the authenticated emails file into a new UserMap
func NewUserMap(usersFile string, done <-chan bool, onUpdate func()) *UserMap { func NewUserMap(usersFile string, done <-chan bool, onUpdate func()) *UserMap {
um := &UserMap{usersFile: usersFile} um := &UserMap{usersFile: usersFile}
m := make(map[string]bool) m := make(map[string]bool)
@ -30,23 +32,26 @@ func NewUserMap(usersFile string, done <-chan bool, onUpdate func()) *UserMap {
return um return um
} }
// IsValid checks if an email is allowed
func (um *UserMap) IsValid(email string) (result bool) { func (um *UserMap) IsValid(email string) (result bool) {
m := *(*map[string]bool)(atomic.LoadPointer(&um.m)) m := *(*map[string]bool)(atomic.LoadPointer(&um.m))
_, result = m[email] _, result = m[email]
return return
} }
// LoadAuthenticatedEmailsFile loads the authenticated emails file from disk
// and parses the contents as CSV
func (um *UserMap) LoadAuthenticatedEmailsFile() { func (um *UserMap) LoadAuthenticatedEmailsFile() {
r, err := os.Open(um.usersFile) r, err := os.Open(um.usersFile)
if err != nil { if err != nil {
log.Fatalf("failed opening authenticated-emails-file=%q, %s", um.usersFile, err) log.Fatalf("failed opening authenticated-emails-file=%q, %s", um.usersFile, err)
} }
defer r.Close() defer r.Close()
csv_reader := csv.NewReader(r) csvReader := csv.NewReader(r)
csv_reader.Comma = ',' csvReader.Comma = ','
csv_reader.Comment = '#' csvReader.Comment = '#'
csv_reader.TrimLeadingSpace = true csvReader.TrimLeadingSpace = true
records, err := csv_reader.ReadAll() records, err := csvReader.ReadAll()
if err != nil { if err != nil {
log.Printf("error reading authenticated-emails-file=%q, %s", um.usersFile, err) log.Printf("error reading authenticated-emails-file=%q, %s", um.usersFile, err)
return return
@ -91,6 +96,7 @@ func newValidatorImpl(domains []string, usersFile string,
return validator return validator
} }
// NewValidator constructs a function to validate email addresses
func NewValidator(domains []string, usersFile string) func(string) bool { func NewValidator(domains []string, usersFile string) func(string) bool {
return newValidatorImpl(domains, usersFile, nil, func() {}) return newValidatorImpl(domains, usersFile, nil, func() {})
} }

View File

@ -8,15 +8,15 @@ import (
) )
type ValidatorTest struct { type ValidatorTest struct {
auth_email_file *os.File authEmailFile *os.File
done chan bool done chan bool
update_seen bool updateSeen bool
} }
func NewValidatorTest(t *testing.T) *ValidatorTest { func NewValidatorTest(t *testing.T) *ValidatorTest {
vt := &ValidatorTest{} vt := &ValidatorTest{}
var err error var err error
vt.auth_email_file, err = ioutil.TempFile("", "test_auth_emails_") vt.authEmailFile, err = ioutil.TempFile("", "test_auth_emails_")
if err != nil { if err != nil {
t.Fatal("failed to create temp file: " + err.Error()) t.Fatal("failed to create temp file: " + err.Error())
} }
@ -26,27 +26,27 @@ func NewValidatorTest(t *testing.T) *ValidatorTest {
func (vt *ValidatorTest) TearDown() { func (vt *ValidatorTest) TearDown() {
vt.done <- true vt.done <- true
os.Remove(vt.auth_email_file.Name()) os.Remove(vt.authEmailFile.Name())
} }
func (vt *ValidatorTest) NewValidator(domains []string, func (vt *ValidatorTest) NewValidator(domains []string,
updated chan<- bool) func(string) bool { updated chan<- bool) func(string) bool {
return newValidatorImpl(domains, vt.auth_email_file.Name(), return newValidatorImpl(domains, vt.authEmailFile.Name(),
vt.done, func() { vt.done, func() {
if vt.update_seen == false { if vt.updateSeen == false {
updated <- true updated <- true
vt.update_seen = true vt.updateSeen = true
} }
}) })
} }
// This will close vt.auth_email_file. // This will close vt.authEmailFile.
func (vt *ValidatorTest) WriteEmails(t *testing.T, emails []string) { func (vt *ValidatorTest) WriteEmails(t *testing.T, emails []string) {
defer vt.auth_email_file.Close() defer vt.authEmailFile.Close()
vt.auth_email_file.WriteString(strings.Join(emails, "\n")) vt.authEmailFile.WriteString(strings.Join(emails, "\n"))
if err := vt.auth_email_file.Close(); err != nil { if err := vt.authEmailFile.Close(); err != nil {
t.Fatal("failed to close temp file " + t.Fatal("failed to close temp file " +
vt.auth_email_file.Name() + ": " + err.Error()) vt.authEmailFile.Name() + ": " + err.Error())
} }
} }

View File

@ -12,18 +12,18 @@ import (
func (vt *ValidatorTest) UpdateEmailFileViaCopyingOver( func (vt *ValidatorTest) UpdateEmailFileViaCopyingOver(
t *testing.T, emails []string) { t *testing.T, emails []string) {
orig_file := vt.auth_email_file origFile := vt.authEmailFile
var err error var err error
vt.auth_email_file, err = ioutil.TempFile("", "test_auth_emails_") vt.authEmailFile, err = ioutil.TempFile("", "test_auth_emails_")
if err != nil { if err != nil {
t.Fatal("failed to create temp file for copy: " + err.Error()) t.Fatal("failed to create temp file for copy: " + err.Error())
} }
vt.WriteEmails(t, emails) vt.WriteEmails(t, emails)
err = os.Rename(vt.auth_email_file.Name(), orig_file.Name()) err = os.Rename(vt.authEmailFile.Name(), origFile.Name())
if err != nil { if err != nil {
t.Fatal("failed to copy over temp file: " + err.Error()) t.Fatal("failed to copy over temp file: " + err.Error())
} }
vt.auth_email_file = orig_file vt.authEmailFile = origFile
} }
func TestValidatorOverwriteEmailListViaCopyingOver(t *testing.T) { func TestValidatorOverwriteEmailListViaCopyingOver(t *testing.T) {

View File

@ -10,8 +10,8 @@ import (
func (vt *ValidatorTest) UpdateEmailFile(t *testing.T, emails []string) { func (vt *ValidatorTest) UpdateEmailFile(t *testing.T, emails []string) {
var err error var err error
vt.auth_email_file, err = os.OpenFile( vt.authEmailFile, err = os.OpenFile(
vt.auth_email_file.Name(), os.O_WRONLY|os.O_CREATE, 0600) vt.authEmailFile.Name(), os.O_WRONLY|os.O_CREATE, 0600)
if err != nil { if err != nil {
t.Fatal("failed to re-open temp file for updates") t.Fatal("failed to re-open temp file for updates")
} }
@ -20,24 +20,24 @@ func (vt *ValidatorTest) UpdateEmailFile(t *testing.T, emails []string) {
func (vt *ValidatorTest) UpdateEmailFileViaRenameAndReplace( func (vt *ValidatorTest) UpdateEmailFileViaRenameAndReplace(
t *testing.T, emails []string) { t *testing.T, emails []string) {
orig_file := vt.auth_email_file origFile := vt.authEmailFile
var err error var err error
vt.auth_email_file, err = ioutil.TempFile("", "test_auth_emails_") vt.authEmailFile, err = ioutil.TempFile("", "test_auth_emails_")
if err != nil { if err != nil {
t.Fatal("failed to create temp file for rename and replace: " + t.Fatal("failed to create temp file for rename and replace: " +
err.Error()) err.Error())
} }
vt.WriteEmails(t, emails) vt.WriteEmails(t, emails)
moved_name := orig_file.Name() + "-moved" movedName := origFile.Name() + "-moved"
err = os.Rename(orig_file.Name(), moved_name) err = os.Rename(origFile.Name(), movedName)
err = os.Rename(vt.auth_email_file.Name(), orig_file.Name()) err = os.Rename(vt.authEmailFile.Name(), origFile.Name())
if err != nil { if err != nil {
t.Fatal("failed to rename and replace temp file: " + t.Fatal("failed to rename and replace temp file: " +
err.Error()) err.Error())
} }
vt.auth_email_file = orig_file vt.authEmailFile = origFile
os.Remove(moved_name) os.Remove(movedName)
} }
func TestValidatorOverwriteEmailListDirectly(t *testing.T) { func TestValidatorOverwriteEmailListDirectly(t *testing.T) {

View File

@ -1,3 +1,4 @@
package main package main
const VERSION = "2.2.1-alpha" // VERSION contains version information
var VERSION = "undefined"

View File

@ -8,16 +8,18 @@ import (
"path/filepath" "path/filepath"
"time" "time"
"gopkg.in/fsnotify.v1" fsnotify "gopkg.in/fsnotify/fsnotify.v1"
) )
// WaitForReplacement waits for a file to exist on disk and then starts a watch
// for the file
func WaitForReplacement(filename string, op fsnotify.Op, func WaitForReplacement(filename string, op fsnotify.Op,
watcher *fsnotify.Watcher) { watcher *fsnotify.Watcher) {
const sleep_interval = 50 * time.Millisecond const sleepInterval = 50 * time.Millisecond
// Avoid a race when fsnofity.Remove is preceded by fsnotify.Chmod. // Avoid a race when fsnofity.Remove is preceded by fsnotify.Chmod.
if op&fsnotify.Chmod != 0 { if op&fsnotify.Chmod != 0 {
time.Sleep(sleep_interval) time.Sleep(sleepInterval)
} }
for { for {
if _, err := os.Stat(filename); err == nil { if _, err := os.Stat(filename); err == nil {
@ -26,10 +28,11 @@ func WaitForReplacement(filename string, op fsnotify.Op,
return return
} }
} }
time.Sleep(sleep_interval) time.Sleep(sleepInterval)
} }
} }
// WatchForUpdates performs an action every time a file on disk is updated
func WatchForUpdates(filename string, done <-chan bool, action func()) { func WatchForUpdates(filename string, done <-chan bool, action func()) {
filename = filepath.Clean(filename) filename = filepath.Clean(filename)
watcher, err := fsnotify.NewWatcher() watcher, err := fsnotify.NewWatcher()
@ -56,7 +59,7 @@ func WatchForUpdates(filename string, done <-chan bool, action func()) {
} }
log.Printf("reloading after event: %s", event) log.Printf("reloading after event: %s", event)
action() action()
case err := <-watcher.Errors: case err = <-watcher.Errors:
log.Printf("error watching %s: %s", filename, err) log.Printf("error watching %s: %s", filename, err)
} }
} }