Merge pull request #7 from pusher/migration

Migration from Bitly to Pusher
This commit is contained in:
Joel Speed 2019-01-14 09:54:05 +00:00 committed by GitHub
commit e1f45dd941
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
55 changed files with 958 additions and 444 deletions

3
.github/CODEOWNERS vendored Normal file
View File

@ -0,0 +1,3 @@
# Default owner should be a Pusher cloud-team member unless overridden by later
# rules in this file
* @pusher/cloud-team

37
.github/ISSUE_TEMPLATE.md vendored Normal file
View File

@ -0,0 +1,37 @@
<!--- Provide a general summary of the issue in the Title above -->
## Expected Behavior
<!--- If you're describing a bug, tell us what should happen -->
<!--- If you're suggesting a change/improvement, tell us how it should work -->
## Current Behavior
<!--- If describing a bug, tell us what happens instead of the expected behavior -->
<!--- If suggesting a change/improvement, explain the difference from current behavior -->
## Possible Solution
<!--- Not obligatory, but suggest a fix/reason for the bug, -->
<!--- or ideas how to implement the addition or change -->
## Steps to Reproduce (for bugs)
<!--- Provide a link to a live example, or an unambiguous set of steps to -->
<!--- reproduce this bug. Include code to reproduce, if relevant -->
1. <!--- Step 1 --->
2. <!--- Step 2 --->
3. <!--- Step 3 --->
4. <!--- Step 4 --->
## Context
<!--- How has this issue affected you? What are you trying to accomplish? -->
<!--- Providing context helps us come up with a solution that is most useful in the real world -->
## Your Environment
<!--- Include as many relevant details about the environment you experienced the bug in -->
- Version used:

25
.github/PULL_REQUEST_TEMPLATE.md vendored Normal file
View File

@ -0,0 +1,25 @@
<!--- Provide a general summary of your changes in the Title above -->
## Description
<!--- Describe your changes in detail -->
## Motivation and Context
<!--- Why is this change required? What problem does it solve? -->
<!--- If it fixes an open issue, please link to the issue here. -->
## How Has This Been Tested?
<!--- Please describe in detail how you tested your changes. -->
<!--- Include details of your testing environment, and the tests you ran to -->
<!--- see how your change affects other areas of the code, etc. -->
## Checklist:
<!--- Go over all the following points, and put an `x` in all the boxes that apply. -->
<!--- If you're unsure about any of these, don't hesitate to ask. We're here to help! -->
- [ ] My change requires a change to the documentation or CHANGELOG.
- [ ] I have updated the documentation/CHANGELOG accordingly.
- [ ] I have created a feature (non-master) branch for my PR.

2
.gitignore vendored
View File

@ -3,7 +3,7 @@ vendor
dist
.godeps
*.exe
.env
# Go.gitignore
# Compiled Object files, Static and Dynamic libs (Shared Objects)

View File

@ -1,12 +1,16 @@
language: go
go:
- 1.8.x
- 1.9.x
script:
- 1.10.x
install:
# Fetch dependencies
- wget -O dep https://github.com/golang/dep/releases/download/v0.3.2/dep-linux-amd64
- chmod +x dep
- ./dep ensure
- ./test.sh
- mv dep $GOPATH/bin/dep
script:
- ./configure
# Run tests
- make test
sudo: false
notifications:
email: false

23
CHANGELOG.md Normal file
View File

@ -0,0 +1,23 @@
# Vx.x.x (Pre-release)
## Changes since v2.2:
- Move automated build to debian base image
- Add Makefile
- Update CI to run `make test`
- Update Dockerfile to use `make clean oauth2_proxy`
- Update `VERSION` parameter to be set by `ldflags` from Git Status
- Remove lint and test scripts
- Remove Go v1.8.x from Travis CI testing
- Add CODEOWNERS file
- Add CONTRIBUTING guide
- Add Issue and Pull Request templates
- Add Dockerfile
- Fix fsnotify import
- Update README to reflect new repository ownership
- Update CI scripts to separate linting and testing
- Now using `gometalinter` for linting
- Move Go import path from `github.com/bitly/oauth2_proxy` to `github.com/pusher/oauth2_proxy`
- Repository forked on 27/11/18
- README updated to include note that this repository is forked
- CHANGLOG created to track changes to repository from original fork

22
CONTRIBUTING.md Normal file
View File

@ -0,0 +1,22 @@
# Contributing
To develop on this project, please fork the repo and clone into your `$GOPATH`.
Dependencies are **not** checked in so please download those separately.
Download the dependencies using [`dep`](https://github.com/golang/dep).
```bash
cd $GOPATH/src/github.com # Create this directory if it doesn't exist
git clone git@github.com:<YOUR_FORK>/oauth2_proxy pusher/oauth2_proxy
make dep
```
## Pull Requests and Issues
We track bugs and issues using Github.
If you find a bug, please open an Issue.
If you want to fix a bug, please fork, create a feature branch, fix the bug and
open a PR back to this repo.
Please mention the open bug issue number within your PR if applicable.

16
Dockerfile Normal file
View File

@ -0,0 +1,16 @@
FROM golang:1.10 AS builder
WORKDIR /go/src/github.com/pusher/oauth2_proxy
COPY . .
# Fetch dependencies
RUN go get -u github.com/golang/dep/cmd/dep
RUN dep ensure --vendor-only
# Build image
RUN ./configure && make clean oauth2_proxy
# Copy binary to debian
FROM debian:stretch
COPY --from=builder /go/src/github.com/pusher/oauth2_proxy/oauth2_proxy /bin/oauth2_proxy
ENTRYPOINT ["/bin/oauth2_proxy"]

67
Gopkg.lock generated
View File

@ -2,118 +2,149 @@
[[projects]]
digest = "1:b24249f5a5e6fbe1eddc94b25973172339ccabeadef4779274f3ed0167c18812"
name = "cloud.google.com/go"
packages = ["compute/metadata"]
pruneopts = ""
revision = "2d3a6656c17a60b0815b7e06ab0be04eacb6e613"
version = "v0.16.0"
[[projects]]
digest = "1:289dd4d7abfb3ad2b5f728fbe9b1d5c1bf7d265a3eb9ef92869af1f7baba4c7a"
name = "github.com/BurntSushi/toml"
packages = ["."]
pruneopts = ""
revision = "b26d9c308763d68093482582cea63d69be07a0f0"
version = "v0.3.0"
[[projects]]
digest = "1:512883404c2a99156e410e9880e3bb35ecccc0c07c1159eb204b5f3ef3c431b3"
name = "github.com/bitly/go-simplejson"
packages = ["."]
pruneopts = ""
revision = "aabad6e819789e569bd6aabf444c935aa9ba1e44"
version = "v0.5.0"
[[projects]]
branch = "v2"
digest = "1:e5a238f8fa890e529d7e493849bbae8988c9e70344e4630cc4f9a11b00516afb"
name = "github.com/coreos/go-oidc"
packages = ["."]
pruneopts = ""
revision = "77e7f2010a464ade7338597afe650dfcffbe2ca8"
[[projects]]
digest = "1:56c130d885a4aacae1dd9c7b71cfe39912c7ebc1ff7d2b46083c8812996dc43b"
name = "github.com/davecgh/go-spew"
packages = ["spew"]
pruneopts = ""
revision = "346938d642f2ec3594ed81d874461961cd0faa76"
version = "v1.1.0"
[[projects]]
branch = "master"
digest = "1:3b760d3b93f994df8eb1d9ebfad17d3e9e37edcb7f7efaa15b427c0d7a64f4e4"
name = "github.com/golang/protobuf"
packages = ["proto"]
pruneopts = ""
revision = "1e59b77b52bf8e4b449a57e6f79f21226d571845"
[[projects]]
digest = "1:af67386ca553c04c6222f7b5b2f17bc97a5dfb3b81b706882c7fd8c72c30cf8f"
name = "github.com/mbland/hmacauth"
packages = ["."]
pruneopts = ""
revision = "107c17adcc5eccc9935cd67d9bc2feaf5255d2cb"
version = "1.0.2"
[[projects]]
branch = "master"
digest = "1:9408fb9c637c103010e5147469c232ce6b68edc840879cc730a2a15918e6cae8"
name = "github.com/mreiferson/go-options"
packages = ["."]
pruneopts = ""
revision = "77551d20752b54535462404ad9d877ebdb26e53d"
[[projects]]
digest = "1:256484dbbcd271f9ecebc6795b2df8cad4c458dd0f5fd82a8c2fa0c29f233411"
name = "github.com/pmezard/go-difflib"
packages = ["difflib"]
pruneopts = ""
revision = "792786c7400a136282c1664665ae0a8db921c6c2"
version = "v1.0.0"
[[projects]]
branch = "master"
digest = "1:386e12afcfd8964907c92dffd106860c0dedd71dbefae14397b77b724a13343b"
name = "github.com/pquerna/cachecontrol"
packages = [
".",
"cacheobject"
"cacheobject",
]
pruneopts = ""
revision = "0dec1b30a0215bb68605dfc568e8855066c9202d"
[[projects]]
digest = "1:3926a4ec9a4ff1a072458451aa2d9b98acd059a45b38f7335d31e06c3d6a0159"
name = "github.com/stretchr/testify"
packages = ["assert"]
pruneopts = ""
revision = "69483b4bd14f5845b5a1e55bca19e954e827f1d0"
version = "v1.1.4"
[[projects]]
branch = "master"
digest = "1:f6a006d27619a4d93bf9b66fe1999b8c8d1fa62bdc63af14f10fbe6fcaa2aa1a"
name = "golang.org/x/crypto"
packages = [
"bcrypt",
"blowfish",
"ed25519",
"ed25519/internal/edwards25519"
"ed25519/internal/edwards25519",
]
pruneopts = ""
revision = "9f005a07e0d31d45e6656d241bb5c0f2efd4bc94"
[[projects]]
branch = "master"
digest = "1:130b1bec86c62e121967ee0c69d9c263dc2d3ffe6c7c9a82aca4071c4d068861"
name = "golang.org/x/net"
packages = [
"context",
"context/ctxhttp"
"context/ctxhttp",
]
pruneopts = ""
revision = "9dfe39835686865bff950a07b394c12a98ddc811"
[[projects]]
branch = "master"
digest = "1:4a61176e8386727e4847b21a5a2625ce56b9c518bc543a28226503e701265db0"
name = "golang.org/x/oauth2"
packages = [
".",
"google",
"internal",
"jws",
"jwt"
"jwt",
]
pruneopts = ""
revision = "9ff8ebcc8e241d46f52ecc5bff0e5a2f2dbef402"
[[projects]]
branch = "master"
digest = "1:dc1fb726dbbe79c86369941eae1e3b431b8fc6f11dbd37f7899dc758a43cc3ed"
name = "google.golang.org/api"
packages = [
"admin/directory/v1",
"gensupport",
"googleapi",
"googleapi/internal/uritemplates"
"googleapi/internal/uritemplates",
]
pruneopts = ""
revision = "8791354e7ab150705ede13637a18c1fcc16b62e8"
[[projects]]
digest = "1:934fb8966f303ede63aa405e2c8d7f0a427a05ea8df335dfdc1833dd4d40756f"
name = "google.golang.org/appengine"
packages = [
".",
@ -125,30 +156,48 @@
"internal/modules",
"internal/remote_api",
"internal/urlfetch",
"urlfetch"
"urlfetch",
]
pruneopts = ""
revision = "150dc57a1b433e64154302bdc40b6bb8aefa313a"
version = "v1.0.0"
[[projects]]
name = "gopkg.in/fsnotify.v1"
digest = "1:cb5b2a45a3dd41c01ff779c54ae4c8aab0271d6d3b3f734c8a8bd2c890299ef2"
name = "gopkg.in/fsnotify/fsnotify.v1"
packages = ["."]
pruneopts = ""
revision = "836bfd95fecc0f1511dd66bdbf2b5b61ab8b00b6"
version = "v1.2.11"
[[projects]]
digest = "1:be4ed0a2b15944dd777a663681a39260ed05f9c4e213017ed2e2255622c8820c"
name = "gopkg.in/square/go-jose.v2"
packages = [
".",
"cipher",
"json"
"json",
]
pruneopts = ""
revision = "f8f38de21b4dcd69d0413faf231983f5fd6634b1"
version = "v2.1.3"
[solve-meta]
analyzer-name = "dep"
analyzer-version = 1
inputs-digest = "b502c41a61115d14d6379be26b0300f65d173bdad852f0170d387ebf2d7ec173"
input-imports = [
"github.com/BurntSushi/toml",
"github.com/bitly/go-simplejson",
"github.com/coreos/go-oidc",
"github.com/mbland/hmacauth",
"github.com/mreiferson/go-options",
"github.com/stretchr/testify/assert",
"golang.org/x/crypto/bcrypt",
"golang.org/x/oauth2",
"golang.org/x/oauth2/google",
"google.golang.org/api/admin/directory/v1",
"google.golang.org/api/googleapi",
"gopkg.in/fsnotify/fsnotify.v1",
]
solver-name = "gps-cdcl"
solver-version = 1

View File

@ -3,10 +3,6 @@
# for detailed Gopkg.toml documentation.
#
[[constraint]]
name = "github.com/18F/hmacauth"
version = "~1.0.1"
[[constraint]]
name = "github.com/BurntSushi/toml"
version = "~0.3.0"
@ -36,7 +32,7 @@
name = "google.golang.org/api"
[[constraint]]
name = "gopkg.in/fsnotify.v1"
name = "gopkg.in/fsnotify/fsnotify.v1"
version = "~1.2.0"
[[constraint]]

56
Makefile Normal file
View File

@ -0,0 +1,56 @@
include .env
BINARY := oauth2_proxy
VERSION := $(shell git describe --always --long --dirty --tags 2>/dev/null || echo "undefined")
.NOTPARALLEL:
.PHONY: all
all: dep lint $(BINARY)
.PHONY: clean
clean:
rm -rf release
rm -f $(BINARY)
.PHONY: distclean
distclean: clean
rm -rf vendor
BIN_DIR := $(GOPATH)/bin
GOMETALINTER := $(BIN_DIR)/gometalinter
$(GOMETALINTER):
$(GO) get -u github.com/alecthomas/gometalinter
gometalinter --install %> /dev/null
.PHONY: lint
lint: $(GOMETALINTER)
$(GOMETALINTER) --vendor --disable-all \
--enable=vet \
--enable=vetshadow \
--enable=golint \
--enable=ineffassign \
--enable=goconst \
--enable=deadcode \
--enable=gofmt \
--enable=goimports \
--tests ./...
.PHONY: dep
dep:
$(DEP) ensure --vendor-only
.PHONY: build
build: clean $(BINARY)
$(BINARY):
$(GO) build -ldflags="-X main.VERSION=${VERSION}" -o $(BINARY) github.com/pusher/oauth2_proxy
.PHONY: test
test: dep lint
$(GO) test -v -race $(go list ./... | grep -v /vendor/)
.PHONY: release
release: lint test
mkdir release
GOOS=darwin GOARCH=amd64 go build -ldflags="-X main.VERSION=${VERSION}" -o release/$(BINARY)-darwin-amd64 github.com/pusher/oauth2_proxy
GOOS=linux GOARCH=amd64 go build -ldflags="-X main.VERSION=${VERSION}" -o release/$(BINARY)-linux-amd64 github.com/pusher/oauth2_proxy

146
README.md
View File

@ -1,11 +1,13 @@
oauth2_proxy
=================
# oauth2_proxy
A reverse proxy and static file server that provides authentication using Providers (Google, GitHub, and others)
to validate accounts by email, domain or group.
[![Build Status](https://secure.travis-ci.org/bitly/oauth2_proxy.svg?branch=master)](http://travis-ci.org/bitly/oauth2_proxy)
**Note:** This repository was forked from [bitly/OAuth2_Proxy](https://github.com/bitly/oauth2_proxy) on 27/11/2018.
Versions v3.0.0 and up are from this fork and will have diverged from any changes in the original fork.
A list of changes can be seen in the [CHANGELOG](CHANGELOG.md).
[![Build Status](https://secure.travis-ci.org/pusher/oauth2_proxy.svg?branch=master)](http://travis-ci.org/pusher/oauth2_proxy)
![Sign In Page](https://cloud.githubusercontent.com/assets/45028/4970624/7feb7dd8-6886-11e4-93e0-c9904af44ea8.png)
@ -15,15 +17,24 @@ to validate accounts by email, domain or group.
## Installation
1. Download [Prebuilt Binary](https://github.com/bitly/oauth2_proxy/releases) (current release is `v2.2`) or build with `$ go get github.com/bitly/oauth2_proxy` which will put the binary in `$GOROOT/bin`
1. Choose how to deploy:
a. Download [Prebuilt Binary](https://github.com/pusher/oauth2_proxy/releases) (current release is `v2.2`)
b. Build with `$ go get github.com/pusher/oauth2_proxy` which will put the binary in `$GOROOT/bin`
c. Using the prebuilt docker image [quay.io/pusher/oauth2_proxy](https://quay.io/pusher/oauth2_proxy)
Prebuilt binaries can be validated by extracting the file and verifying it against the `sha256sum.txt` checksum file provided for each release starting with version `v2.3`.
```
sha256sum -c sha256sum.txt 2>&1 | grep OK
oauth2_proxy-2.3.linux-amd64: OK
```
2. Select a Provider and Register an OAuth Application with a Provider
3. Configure OAuth2 Proxy using config file, command line options, or environment variables
4. Configure SSL or Deploy behind a SSL endpoint (example provided for Nginx)
2. Select a Provider and Register an OAuth Application with a Provider
3. Configure OAuth2 Proxy using config file, command line options, or environment variables
4. Configure SSL or Deploy behind a SSL endpoint (example provided for Nginx)
## OAuth Provider Configuration
@ -31,12 +42,12 @@ You will need to register an OAuth application with a Provider (Google, GitHub o
Valid providers are :
* [Google](#google-auth-provider) *default*
* [Azure](#azure-auth-provider)
* [Facebook](#facebook-auth-provider)
* [GitHub](#github-auth-provider)
* [GitLab](#gitlab-auth-provider)
* [LinkedIn](#linkedin-auth-provider)
- [Google](#google-auth-provider) _default_
- [Azure](#azure-auth-provider)
- [Facebook](#facebook-auth-provider)
- [GitHub](#github-auth-provider)
- [GitLab](#gitlab-auth-provider)
- [LinkedIn](#linkedin-auth-provider)
The provider can be selected using the `provider` configuration value.
@ -44,61 +55,62 @@ The provider can be selected using the `provider` configuration value.
For Google, the registration steps are:
1. Create a new project: https://console.developers.google.com/project
2. Choose the new project from the top right project dropdown (only if another project is selected)
3. In the project Dashboard center pane, choose **"API Manager"**
4. In the left Nav pane, choose **"Credentials"**
5. In the center pane, choose **"OAuth consent screen"** tab. Fill in **"Product name shown to users"** and hit save.
6. In the center pane, choose **"Credentials"** tab.
* Open the **"New credentials"** drop down
* Choose **"OAuth client ID"**
* Choose **"Web application"**
* Application name is freeform, choose something appropriate
* Authorized JavaScript origins is your domain ex: `https://internal.yourcompany.com`
* Authorized redirect URIs is the location of oauth2/callback ex: `https://internal.yourcompany.com/oauth2/callback`
* Choose **"Create"**
4. Take note of the **Client ID** and **Client Secret**
1. Create a new project: https://console.developers.google.com/project
2. Choose the new project from the top right project dropdown (only if another project is selected)
3. In the project Dashboard center pane, choose **"API Manager"**
4. In the left Nav pane, choose **"Credentials"**
5. In the center pane, choose **"OAuth consent screen"** tab. Fill in **"Product name shown to users"** and hit save.
6. In the center pane, choose **"Credentials"** tab.
- Open the **"New credentials"** drop down
- Choose **"OAuth client ID"**
- Choose **"Web application"**
- Application name is freeform, choose something appropriate
- Authorized JavaScript origins is your domain ex: `https://internal.yourcompany.com`
- Authorized redirect URIs is the location of oauth2/callback ex: `https://internal.yourcompany.com/oauth2/callback`
- Choose **"Create"**
7. Take note of the **Client ID** and **Client Secret**
It's recommended to refresh sessions on a short interval (1h) with `cookie-refresh` setting which validates that the account is still authorized.
#### Restrict auth to specific Google groups on your domain. (optional)
1. Create a service account: https://developers.google.com/identity/protocols/OAuth2ServiceAccount and make sure to download the json file.
2. Make note of the Client ID for a future step.
3. Under "APIs & Auth", choose APIs.
4. Click on Admin SDK and then Enable API.
5. Follow the steps on https://developers.google.com/admin-sdk/directory/v1/guides/delegation#delegate_domain-wide_authority_to_your_service_account and give the client id from step 2 the following oauth scopes:
1. Create a service account: https://developers.google.com/identity/protocols/OAuth2ServiceAccount and make sure to download the json file.
2. Make note of the Client ID for a future step.
3. Under "APIs & Auth", choose APIs.
4. Click on Admin SDK and then Enable API.
5. Follow the steps on https://developers.google.com/admin-sdk/directory/v1/guides/delegation#delegate_domain-wide_authority_to_your_service_account and give the client id from step 2 the following oauth scopes:
```
https://www.googleapis.com/auth/admin.directory.group.readonly
https://www.googleapis.com/auth/admin.directory.user.readonly
```
6. Follow the steps on https://support.google.com/a/answer/60757 to enable Admin API access.
7. Create or choose an existing administrative email address on the Gmail domain to assign to the ```google-admin-email``` flag. This email will be impersonated by this client to make calls to the Admin SDK. See the note on the link from step 5 for the reason why.
8. Create or choose an existing email group and set that email to the ```google-group``` flag. You can pass multiple instances of this flag with different groups
and the user will be checked against all the provided groups.
9. Lock down the permissions on the json file downloaded from step 1 so only oauth2_proxy is able to read the file and set the path to the file in the ```google-service-account-json``` flag.
6. Follow the steps on https://support.google.com/a/answer/60757 to enable Admin API access.
7. Create or choose an existing administrative email address on the Gmail domain to assign to the `google-admin-email` flag. This email will be impersonated by this client to make calls to the Admin SDK. See the note on the link from step 5 for the reason why.
8. Create or choose an existing email group and set that email to the `google-group` flag. You can pass multiple instances of this flag with different groups
and the user will be checked against all the provided groups.
9. Lock down the permissions on the json file downloaded from step 1 so only oauth2_proxy is able to read the file and set the path to the file in the `google-service-account-json` flag.
10. Restart oauth2_proxy.
Note: The user is checked against the group members list on initial authentication and every time the token is refreshed ( about once an hour ).
### Azure Auth Provider
1. [Add an application](https://azure.microsoft.com/en-us/documentation/articles/active-directory-integrating-applications/) to your Azure Active Directory tenant.
2. On the App properties page provide the correct Sign-On URL ie `https://internal.yourcompany.com/oauth2/callback`
3. If applicable take note of your `TenantID` and provide it via the `--azure-tenant=<YOUR TENANT ID>` commandline option. Default the `common` tenant is used.
1. [Add an application](https://azure.microsoft.com/en-us/documentation/articles/active-directory-integrating-applications/) to your Azure Active Directory tenant.
2. On the App properties page provide the correct Sign-On URL ie `https://internal.yourcompany.com/oauth2/callback`
3. If applicable take note of your `TenantID` and provide it via the `--azure-tenant=<YOUR TENANT ID>` commandline option. Default the `common` tenant is used.
The Azure AD auth provider uses `openid` as it default scope. It uses `https://graph.windows.net` as a default protected resource. It call to `https://graph.windows.net/me` to get the email address of the user that logs in.
### Facebook Auth Provider
1. Create a new FB App from <https://developers.facebook.com/>
2. Under FB Login, set your Valid OAuth redirect URIs to `https://internal.yourcompany.com/oauth2/callback`
1. Create a new FB App from <https://developers.facebook.com/>
2. Under FB Login, set your Valid OAuth redirect URIs to `https://internal.yourcompany.com/oauth2/callback`
### GitHub Auth Provider
1. Create a new project: https://github.com/settings/developers
2. Under `Authorization callback URL` enter the correct url ie `https://internal.yourcompany.com/oauth2/callback`
1. Create a new project: https://github.com/settings/developers
2. Under `Authorization callback URL` enter the correct url ie `https://internal.yourcompany.com/oauth2/callback`
The GitHub auth provider supports two additional parameters to restrict authentication to Organization or Team level access. Restricting by org and team is normally accompanied with `--email-domain=*`
@ -121,17 +133,16 @@ If you are using self-hosted GitLab, make sure you set the following to the appr
-redeem-url="<your gitlab url>/oauth/token"
-validate-url="<your gitlab url>/api/v4/user"
### LinkedIn Auth Provider
For LinkedIn, the registration steps are:
1. Create a new project: https://www.linkedin.com/secure/developer
2. In the OAuth User Agreement section:
* In default scope, select r_basicprofile and r_emailaddress.
* In "OAuth 2.0 Redirect URLs", enter `https://internal.yourcompany.com/oauth2/callback`
3. Fill in the remaining required fields and Save.
4. Take note of the **Consumer Key / API Key** and **Consumer Secret / Secret Key**
1. Create a new project: https://www.linkedin.com/secure/developer
2. In the OAuth User Agreement section:
- In default scope, select r_basicprofile and r_emailaddress.
- In "OAuth 2.0 Redirect URLs", enter `https://internal.yourcompany.com/oauth2/callback`
3. Fill in the remaining required fields and Save.
4. Take note of the **Consumer Key / API Key** and **Consumer Secret / Secret Key**
### Microsoft Azure AD Provider
@ -143,9 +154,9 @@ Take note of your `TenantId` if applicable for your situation. The `TenantId` ca
OpenID Connect is a spec for OAUTH 2.0 + identity that is implemented by many major providers and several open source projects. This provider was originally built against CoreOS Dex and we will use it as an example.
1. Launch a Dex instance using the [getting started guide](https://github.com/coreos/dex/blob/master/Documentation/getting-started.md).
2. Setup oauth2_proxy with the correct provider and using the default ports and callbacks.
3. Login with the fixture use in the dex guide and run the oauth2_proxy with the following args:
1. Launch a Dex instance using the [getting started guide](https://github.com/coreos/dex/blob/master/Documentation/getting-started.md).
2. Setup oauth2_proxy with the correct provider and using the default ports and callbacks.
3. Login with the fixture use in the dex guide and run the oauth2_proxy with the following args:
-provider oidc
-client-id oauth2_proxy
@ -253,7 +264,7 @@ The following environment variables can be used in place of the corresponding co
There are two recommended configurations.
1) Configure SSL Termination with OAuth2 Proxy by providing a `--tls-cert=/path/to/cert.pem` and `--tls-key=/path/to/cert.key`.
1. Configure SSL Termination with OAuth2 Proxy by providing a `--tls-cert=/path/to/cert.pem` and `--tls-key=/path/to/cert.key`.
The command line to run `oauth2_proxy` in this configuration would look like this:
@ -270,8 +281,7 @@ The command line to run `oauth2_proxy` in this configuration would look like thi
--client-secret=...
```
2) Configure SSL Termination with [Nginx](http://nginx.org/) (example config below), Amazon ELB, Google Cloud Platform Load Balancing, or ....
2. Configure SSL Termination with [Nginx](http://nginx.org/) (example config below), Amazon ELB, Google Cloud Platform Load Balancing, or ....
Because `oauth2_proxy` listens on `127.0.0.1:4180` by default, to listen on all interfaces (needed when using an
external load balancer like Amazon ELB or Google Platform Load Balancing) use `--http-address="0.0.0.0:4180"` or
@ -321,12 +331,12 @@ The command line to run `oauth2_proxy` in this configuration would look like thi
OAuth2 Proxy responds directly to the following endpoints. All other endpoints will be proxied upstream when authenticated. The `/oauth2` prefix can be changed with the `--proxy-prefix` config variable.
* /robots.txt - returns a 200 OK response that disallows all User-agents from all paths; see [robotstxt.org](http://www.robotstxt.org/) for more info
* /ping - returns an 200 OK response
* /oauth2/sign_in - the login page, which also doubles as a sign out page (it clears cookies)
* /oauth2/start - a URL that will redirect to start the OAuth cycle
* /oauth2/callback - the URL used at the end of the OAuth cycle. The oauth app will be configured with this as the callback url.
* /oauth2/auth - only returns a 202 Accepted response or a 401 Unauthorized response; for use with the [Nginx `auth_request` directive](#nginx-auth-request)
- /robots.txt - returns a 200 OK response that disallows all User-agents from all paths; see [robotstxt.org](http://www.robotstxt.org/) for more info
- /ping - returns an 200 OK response
- /oauth2/sign_in - the login page, which also doubles as a sign out page (it clears cookies)
- /oauth2/start - a URL that will redirect to start the OAuth cycle
- /oauth2/callback - the URL used at the end of the OAuth cycle. The oauth app will be configured with this as the callback url.
- /oauth2/auth - only returns a 202 Accepted response or a 401 Unauthorized response; for use with the [Nginx `auth_request` directive](#nginx-auth-request)
## Request signatures
@ -341,9 +351,9 @@ in `oauthproxy.go`](./oauthproxy.go).
For more information about HMAC request signature validation, read the
following:
* [Amazon Web Services: Signing and Authenticating REST
- [Amazon Web Services: Signing and Authenticating REST
Requests](https://docs.aws.amazon.com/AmazonS3/latest/dev/RESTAuthentication.html)
* [rc3.org: Using HMAC to authenticate Web service
- [rc3.org: Using HMAC to authenticate Web service
requests](http://rc3.org/2011/12/02/using-hmac-to-authenticate-web-service-requests/)
## Logging Format
@ -417,3 +427,7 @@ server {
}
}
```
## Contributing
Please see our [Contributing](CONTRIBUTING.md) guidelines.

View File

@ -10,6 +10,7 @@ import (
"github.com/bitly/go-simplejson"
)
// Request parses the request body into a simplejson.Json object
func Request(req *http.Request) (*simplejson.Json, error) {
resp, err := http.DefaultClient.Do(req)
if err != nil {
@ -32,7 +33,8 @@ func Request(req *http.Request) (*simplejson.Json, error) {
return data, nil
}
func RequestJson(req *http.Request, v interface{}) error {
// RequestJSON parses the request body into the given interface
func RequestJSON(req *http.Request, v interface{}) error {
resp, err := http.DefaultClient.Do(req)
if err != nil {
log.Printf("%s %s %s", req.Method, req.URL, err)
@ -50,6 +52,7 @@ func RequestJson(req *http.Request, v interface{}) error {
return json.Unmarshal(body, v)
}
// RequestUnparsedResponse performs a GET and returns the raw response object
func RequestUnparsedResponse(url string, header http.Header) (resp *http.Response, err error) {
req, err := http.NewRequest("GET", url, nil)
if err != nil {

View File

@ -1,20 +1,21 @@
package api
import (
"github.com/bitly/go-simplejson"
"io/ioutil"
"net/http"
"net/http/httptest"
"strings"
"testing"
"github.com/bitly/go-simplejson"
"github.com/stretchr/testify/assert"
)
func testBackend(response_code int, payload string) *httptest.Server {
func testBackend(responseCode int, payload string) *httptest.Server {
return httptest.NewServer(http.HandlerFunc(
func(w http.ResponseWriter, r *http.Request) {
w.WriteHeader(response_code)
w.WriteHeader(responseCode)
w.Write([]byte(payload))
}))
}

137
configure vendored Executable file
View File

@ -0,0 +1,137 @@
#!/usr/bin/env bash
RED='\033[0;31m'
GREEN='\033[0;32m'
BLUE='\033[0;34m'
NC='\033[0m'
declare -A tools=()
declare -A desired=()
for arg in "$@"; do
case ${arg%%=*} in
"--with-go")
desired[go]="${arg##*=}"
;;
"--with-dep")
desired[dep]="${arg##*=}"
;;
"--help")
printf "${GREEN}$0${NC}\n"
printf " available options:\n"
printf " --with-dep=${BLUE}<path_to_dep_binary>${NC}\n"
printf " --with-go=${BLUE}<path_to_go_binary>${NC}\n"
exit 0
;;
*)
echo "Unknown option: $arg"
exit 2
;;
esac
done
vercomp () {
if [[ $1 == $2 ]]
then
return 0
fi
local IFS=.
local i ver1=($1) ver2=($2)
# fill empty fields in ver1 with zeros
for ((i=${#ver1[@]}; i<${#ver2[@]}; i++))
do
ver1[i]=0
done
for ((i=0; i<${#ver1[@]}; i++))
do
if [[ -z ${ver2[i]} ]]
then
# fill empty fields in ver2 with zeros
ver2[i]=0
fi
if ((10#${ver1[i]} > 10#${ver2[i]}))
then
return 1
fi
if ((10#${ver1[i]} < 10#${ver2[i]}))
then
return 2
fi
done
return 0
}
check_for() {
echo -n "Checking for $1... "
if ! [ -z "${desired[$1]}" ]; then
TOOL_PATH="${desired[$1]}"
else
TOOL_PATH=$(command -v $1)
fi
if ! [ -x "$TOOL_PATH" -a -f "$TOOL_PATH" ]; then
printf "${RED}not found${NC}\n"
cd -
exit 1
else
printf "${GREEN}found${NC}\n"
tools[$1]=$TOOL_PATH
fi
}
check_go_version() {
echo -n "Checking go version... "
GO_VERSION=$(${tools[go]} version | ${tools[awk]} '{where = match($0, /[0-9]\.[0-9]+\.[0-9]*/); if (where != 0) print substr($0, RSTART, RLENGTH)}')
vercomp $GO_VERSION 1.9
case $? in
0) ;&
1)
printf "${GREEN}"
echo $GO_VERSION
printf "${NC}"
;;
2)
printf "${RED}"
echo "$GO_VERSION < 1.9"
exit 1
;;
esac
}
check_docker_version() {
echo -n "Checking docker version... "
DOCKER_VERSION=$(${tools[docker]} version | ${tools[awk]})
}
check_go_env() {
echo -n "Checking \$GOPATH... "
if [ -z "$GOPATH" ]; then
printf "${RED}invalid${NC} - GOPATH not set\n"
exit 1
fi
printf "${GREEN}valid${NC} - $GOPATH\n"
}
cd ${0%/*}
if [ ! -f .env ]; then
rm .env
fi
check_for make
check_for awk
check_for go
check_go_version
check_go_env
check_for dep
echo
cat <<- EOF > .env
MAKE := "${tools[make]}"
GO := "${tools[go]}"
DEP := "${tools[dep]}"
EOF
echo "Environment configuration written to .env"
cd - > /dev/null

View File

@ -24,10 +24,11 @@ func TestEncodeAndDecodeAccessToken(t *testing.T) {
}
func TestEncodeAndDecodeAccessTokenB64(t *testing.T) {
const secret_b64 = "A3Xbr6fu6Al0HkgrP1ztjb-mYiwmxgNPP-XbNsz1WBk="
const secretBase64 = "A3Xbr6fu6Al0HkgrP1ztjb-mYiwmxgNPP-XbNsz1WBk="
const token = "my access token"
secret, err := base64.URLEncoding.DecodeString(secret_b64)
secret, err := base64.URLEncoding.DecodeString(secretBase64)
assert.Equal(t, nil, err)
c, err := NewCipher([]byte(secret))
assert.Equal(t, nil, err)

View File

@ -5,6 +5,7 @@ import (
"fmt"
)
// Nonce generates a random 16 byte string to be used as a nonce
func Nonce() (nonce string, err error) {
b := make([]byte, 16)
_, err = rand.Read(b)

View File

@ -6,8 +6,14 @@ import (
"strings"
)
// EnvOptions holds program options loaded from the process environment
type EnvOptions map[string]interface{}
// LoadEnvForStruct loads environment variables for each field in an options
// struct passed into it.
//
// Fields in the options struct must have an `env` and `cfg` tag to be read
// from the environment
func (cfg EnvOptions) LoadEnvForStruct(options interface{}) {
val := reflect.ValueOf(options).Elem()
typ := val.Type()

View File

@ -14,10 +14,12 @@ import (
// Lookup passwords in a htpasswd file
// Passwords must be generated with -B for bcrypt or -s for SHA1.
// HtpasswdFile represents the structure of an htpasswd file
type HtpasswdFile struct {
Users map[string]string
}
// NewHtpasswdFromFile constructs an HtpasswdFile from the file at the path given
func NewHtpasswdFromFile(path string) (*HtpasswdFile, error) {
r, err := os.Open(path)
if err != nil {
@ -27,13 +29,14 @@ func NewHtpasswdFromFile(path string) (*HtpasswdFile, error) {
return NewHtpasswd(r)
}
// NewHtpasswd consctructs an HtpasswdFile from an io.Reader (opened file)
func NewHtpasswd(file io.Reader) (*HtpasswdFile, error) {
csv_reader := csv.NewReader(file)
csv_reader.Comma = ':'
csv_reader.Comment = '#'
csv_reader.TrimLeadingSpace = true
csvReader := csv.NewReader(file)
csvReader.Comma = ':'
csvReader.Comment = '#'
csvReader.TrimLeadingSpace = true
records, err := csv_reader.ReadAll()
records, err := csvReader.ReadAll()
if err != nil {
return nil, err
}
@ -44,6 +47,7 @@ func NewHtpasswd(file io.Reader) (*HtpasswdFile, error) {
return h, nil
}
// Validate checks a users password against the HtpasswdFile entries
func (h *HtpasswdFile) Validate(user string, password string) bool {
realPassword, exists := h.Users[user]
if !exists {

View File

@ -20,6 +20,7 @@ func TestSHA(t *testing.T) {
func TestBcrypt(t *testing.T) {
hash1, err := bcrypt.GenerateFromPassword([]byte("password"), 1)
assert.Equal(t, err, nil)
hash2, err := bcrypt.GenerateFromPassword([]byte("top-secret"), 2)
assert.Equal(t, err, nil)

16
http.go
View File

@ -9,11 +9,13 @@ import (
"time"
)
// Server represents an HTTP server
type Server struct {
Handler http.Handler
Opts *Options
}
// ListenAndServe will serve traffic on HTTP or HTTPS depending on TLS options
func (s *Server) ListenAndServe() {
if s.Opts.TLSKeyFile != "" || s.Opts.TLSCertFile != "" {
s.ServeHTTPS()
@ -22,13 +24,14 @@ func (s *Server) ListenAndServe() {
}
}
// ServeHTTP constructs a net.Listener and starts handling HTTP requests
func (s *Server) ServeHTTP() {
httpAddress := s.Opts.HttpAddress
scheme := ""
HTTPAddress := s.Opts.HTTPAddress
var scheme string
i := strings.Index(httpAddress, "://")
i := strings.Index(HTTPAddress, "://")
if i > -1 {
scheme = httpAddress[0:i]
scheme = HTTPAddress[0:i]
}
var networkType string
@ -39,7 +42,7 @@ func (s *Server) ServeHTTP() {
networkType = scheme
}
slice := strings.SplitN(httpAddress, "//", 2)
slice := strings.SplitN(HTTPAddress, "//", 2)
listenAddr := slice[len(slice)-1]
listener, err := net.Listen(networkType, listenAddr)
@ -57,8 +60,9 @@ func (s *Server) ServeHTTP() {
log.Printf("HTTP: closing %s", listener.Addr())
}
// ServeHTTPS constructs a net.Listener and starts handling HTTPS requests
func (s *Server) ServeHTTPS() {
addr := s.Opts.HttpsAddress
addr := s.Opts.HTTPSAddress
config := &tls.Config{
MinVersion: tls.VersionTLS12,
MaxVersion: tls.VersionTLS12,

View File

@ -27,10 +27,13 @@ type responseLogger struct {
authInfo string
}
// Header returns the ResponseWriter's Header
func (l *responseLogger) Header() http.Header {
return l.w.Header()
}
// ExtractGAPMetadata extracts and removes GAP headers from the ResponseWriter's
// Header
func (l *responseLogger) ExtractGAPMetadata() {
upstream := l.w.Header().Get("GAP-Upstream-Address")
if upstream != "" {
@ -44,6 +47,7 @@ func (l *responseLogger) ExtractGAPMetadata() {
}
}
// Write writes the response using the ResponseWriter
func (l *responseLogger) Write(b []byte) (int, error) {
if l.status == 0 {
// The status will be StatusOK if WriteHeader has not been called yet
@ -55,16 +59,19 @@ func (l *responseLogger) Write(b []byte) (int, error) {
return size, err
}
// WriteHeader writes the status code for the Response
func (l *responseLogger) WriteHeader(s int) {
l.ExtractGAPMetadata()
l.w.WriteHeader(s)
l.status = s
}
// Status returns the response status code
func (l *responseLogger) Status() int {
return l.status
}
// Size returns teh response size
func (l *responseLogger) Size() int {
return l.size
}
@ -94,6 +101,7 @@ type loggingHandler struct {
logTemplate *template.Template
}
// LoggingHandler provides an http.Handler which logs requests to the HTTP server
func LoggingHandler(out io.Writer, h http.Handler, v bool, requestLoggingTpl string) http.Handler {
return loggingHandler{
writer: out,

View File

@ -84,7 +84,7 @@ func main() {
flagSet.Parse(os.Args[1:])
if *showVersion {
fmt.Printf("oauth2_proxy v%s (built with %s)\n", VERSION, runtime.Version())
fmt.Printf("oauth2_proxy %s (built with %s)\n", VERSION, runtime.Version())
return
}

View File

@ -14,14 +14,23 @@ import (
"strings"
"time"
"github.com/bitly/oauth2_proxy/cookie"
"github.com/bitly/oauth2_proxy/providers"
"github.com/mbland/hmacauth"
"github.com/pusher/oauth2_proxy/cookie"
"github.com/pusher/oauth2_proxy/providers"
)
const SignatureHeader = "GAP-Signature"
const (
// SignatureHeader is the name of the request header containing the GAP Signature
// Part of hmacauth
SignatureHeader = "GAP-Signature"
var SignatureHeaders []string = []string{
httpScheme = "http"
httpsScheme = "https"
)
// SignatureHeaders contains the headers to be signed by the hmac algorithm
// Part of hmacauth
var SignatureHeaders = []string{
"Content-Length",
"Content-Md5",
"Content-Type",
@ -34,13 +43,14 @@ var SignatureHeaders []string = []string{
"Gap-Auth",
}
// OAuthProxy is the main authentication proxy
type OAuthProxy struct {
CookieSeed string
CookieName string
CSRFCookieName string
CookieDomain string
CookieSecure bool
CookieHttpOnly bool
CookieHTTPOnly bool
CookieExpire time.Duration
CookieRefresh time.Duration
Validator func(string) bool
@ -74,12 +84,15 @@ type OAuthProxy struct {
Footer string
}
// UpstreamProxy represents an upstream server to proxy to
type UpstreamProxy struct {
upstream string
handler http.Handler
auth hmacauth.HmacAuth
}
// ServeHTTP proxies requests to the upstream provider while signing the
// request headers
func (u *UpstreamProxy) ServeHTTP(w http.ResponseWriter, r *http.Request) {
w.Header().Set("GAP-Upstream-Address", u.upstream)
if u.auth != nil {
@ -89,9 +102,12 @@ func (u *UpstreamProxy) ServeHTTP(w http.ResponseWriter, r *http.Request) {
u.handler.ServeHTTP(w, r)
}
// NewReverseProxy creates a new reverse proxy for proxying requests to upstream
// servers
func NewReverseProxy(target *url.URL) (proxy *httputil.ReverseProxy) {
return httputil.NewSingleHostReverseProxy(target)
}
func setProxyUpstreamHostHeader(proxy *httputil.ReverseProxy, target *url.URL) {
director := proxy.Director
proxy.Director = func(req *http.Request) {
@ -102,6 +118,7 @@ func setProxyUpstreamHostHeader(proxy *httputil.ReverseProxy, target *url.URL) {
req.URL.RawQuery = ""
}
}
func setProxyDirector(proxy *httputil.ReverseProxy) {
director := proxy.Director
proxy.Director = func(req *http.Request) {
@ -111,10 +128,13 @@ func setProxyDirector(proxy *httputil.ReverseProxy) {
req.URL.RawQuery = ""
}
}
// NewFileServer creates a http.Handler to serve files from the filesystem
func NewFileServer(path string, filesystemPath string) (proxy http.Handler) {
return http.StripPrefix(path, http.FileServer(http.Dir(filesystemPath)))
}
// NewOAuthProxy creates a new instance of OOuthProxy from the options provided
func NewOAuthProxy(opts *Options, validator func(string) bool) *OAuthProxy {
serveMux := http.NewServeMux()
var auth hmacauth.HmacAuth
@ -125,7 +145,7 @@ func NewOAuthProxy(opts *Options, validator func(string) bool) *OAuthProxy {
for _, u := range opts.proxyURLs {
path := u.Path
switch u.Scheme {
case "http", "https":
case httpScheme, httpsScheme:
u.Path = ""
log.Printf("mapping path %q => upstream %q", path, u)
proxy := NewReverseProxy(u)
@ -160,7 +180,7 @@ func NewOAuthProxy(opts *Options, validator func(string) bool) *OAuthProxy {
refresh = fmt.Sprintf("after %s", opts.CookieRefresh)
}
log.Printf("Cookie settings: name:%s secure(https):%v httponly:%v expiry:%s domain:%s refresh:%s", opts.CookieName, opts.CookieSecure, opts.CookieHttpOnly, opts.CookieExpire, opts.CookieDomain, refresh)
log.Printf("Cookie settings: name:%s secure(https):%v httponly:%v expiry:%s domain:%s refresh:%s", opts.CookieName, opts.CookieSecure, opts.CookieHTTPOnly, opts.CookieExpire, opts.CookieDomain, refresh)
var cipher *cookie.Cipher
if opts.PassAccessToken || (opts.CookieRefresh != time.Duration(0)) {
@ -177,7 +197,7 @@ func NewOAuthProxy(opts *Options, validator func(string) bool) *OAuthProxy {
CookieSeed: opts.CookieSecret,
CookieDomain: opts.CookieDomain,
CookieSecure: opts.CookieSecure,
CookieHttpOnly: opts.CookieHttpOnly,
CookieHTTPOnly: opts.CookieHTTPOnly,
CookieExpire: opts.CookieExpire,
CookieRefresh: opts.CookieRefresh,
Validator: validator,
@ -209,6 +229,8 @@ func NewOAuthProxy(opts *Options, validator func(string) bool) *OAuthProxy {
}
}
// GetRedirectURI returns the redirectURL that the upstream OAuth Provider will
// redirect clients to once authenticated
func (p *OAuthProxy) GetRedirectURI(host string) string {
// default to the request Host if not set
if p.redirectURL.Host != "" {
@ -218,9 +240,9 @@ func (p *OAuthProxy) GetRedirectURI(host string) string {
u = *p.redirectURL
if u.Scheme == "" {
if p.CookieSecure {
u.Scheme = "https"
u.Scheme = httpsScheme
} else {
u.Scheme = "http"
u.Scheme = httpScheme
}
}
u.Host = host
@ -254,6 +276,8 @@ func (p *OAuthProxy) redeemCode(host, code string) (s *providers.SessionState, e
return
}
// MakeSessionCookie creates an http.Cookie containing the authenticated user's
// authentication details
func (p *OAuthProxy) MakeSessionCookie(req *http.Request, value string, expiration time.Duration, now time.Time) *http.Cookie {
if value != "" {
value = cookie.SignedValue(p.CookieSeed, p.CookieName, value, now)
@ -265,6 +289,7 @@ func (p *OAuthProxy) MakeSessionCookie(req *http.Request, value string, expirati
return p.makeCookie(req, p.CookieName, value, expiration, now)
}
// MakeCSRFCookie creates a cookie for CSRF
func (p *OAuthProxy) MakeCSRFCookie(req *http.Request, value string, expiration time.Duration, now time.Time) *http.Cookie {
return p.makeCookie(req, p.CSRFCookieName, value, expiration, now)
}
@ -285,20 +310,25 @@ func (p *OAuthProxy) makeCookie(req *http.Request, name string, value string, ex
Value: value,
Path: "/",
Domain: p.CookieDomain,
HttpOnly: p.CookieHttpOnly,
HttpOnly: p.CookieHTTPOnly,
Secure: p.CookieSecure,
Expires: now.Add(expiration),
}
}
// ClearCSRFCookie creates a cookie to unset the CSRF cookie stored in the user's
// session
func (p *OAuthProxy) ClearCSRFCookie(rw http.ResponseWriter, req *http.Request) {
http.SetCookie(rw, p.MakeCSRFCookie(req, "", time.Hour*-1, time.Now()))
}
// SetCSRFCookie adds a CSRF cookie to the response
func (p *OAuthProxy) SetCSRFCookie(rw http.ResponseWriter, req *http.Request, val string) {
http.SetCookie(rw, p.MakeCSRFCookie(req, val, p.CookieExpire, time.Now()))
}
// ClearSessionCookie creates a cookie to unset the user's authentication cookie
// stored in the user's session
func (p *OAuthProxy) ClearSessionCookie(rw http.ResponseWriter, req *http.Request) {
clr := p.MakeSessionCookie(req, "", time.Hour*-1, time.Now())
http.SetCookie(rw, clr)
@ -311,10 +341,12 @@ func (p *OAuthProxy) ClearSessionCookie(rw http.ResponseWriter, req *http.Reques
}
}
// SetSessionCookie adds the user's session cookie to the response
func (p *OAuthProxy) SetSessionCookie(rw http.ResponseWriter, req *http.Request, val string) {
http.SetCookie(rw, p.MakeSessionCookie(req, val, p.CookieExpire, time.Now()))
}
// LoadCookiedSession reads the user's authentication details from the request
func (p *OAuthProxy) LoadCookiedSession(req *http.Request) (*providers.SessionState, time.Duration, error) {
var age time.Duration
c, err := req.Cookie(p.CookieName)
@ -336,6 +368,7 @@ func (p *OAuthProxy) LoadCookiedSession(req *http.Request) (*providers.SessionSt
return session, age, nil
}
// SaveSession creates a new session cookie value and sets this on the response
func (p *OAuthProxy) SaveSession(rw http.ResponseWriter, req *http.Request, s *providers.SessionState) error {
value, err := p.provider.CookieForSession(s, p.CookieCipher)
if err != nil {
@ -345,16 +378,19 @@ func (p *OAuthProxy) SaveSession(rw http.ResponseWriter, req *http.Request, s *p
return nil
}
// RobotsTxt disallows scraping pages from the OAuthProxy
func (p *OAuthProxy) RobotsTxt(rw http.ResponseWriter) {
rw.WriteHeader(http.StatusOK)
fmt.Fprintf(rw, "User-agent: *\nDisallow: /")
}
// PingPage responds 200 OK to requests
func (p *OAuthProxy) PingPage(rw http.ResponseWriter) {
rw.WriteHeader(http.StatusOK)
fmt.Fprintf(rw, "OK")
}
// ErrorPage writes an error response
func (p *OAuthProxy) ErrorPage(rw http.ResponseWriter, code int, title string, message string) {
log.Printf("ErrorPage %d %s %s", code, title, message)
rw.WriteHeader(code)
@ -370,16 +406,17 @@ func (p *OAuthProxy) ErrorPage(rw http.ResponseWriter, code int, title string, m
p.templates.ExecuteTemplate(rw, "error.html", t)
}
// SignInPage writes the sing in template to the response
func (p *OAuthProxy) SignInPage(rw http.ResponseWriter, req *http.Request, code int) {
p.ClearSessionCookie(rw, req)
rw.WriteHeader(code)
redirect_url := req.URL.RequestURI()
redirecURL := req.URL.RequestURI()
if req.Header.Get("X-Auth-Request-Redirect") != "" {
redirect_url = req.Header.Get("X-Auth-Request-Redirect")
redirecURL = req.Header.Get("X-Auth-Request-Redirect")
}
if redirect_url == p.SignInPath {
redirect_url = "/"
if redirecURL == p.SignInPath {
redirecURL = "/"
}
t := struct {
@ -394,7 +431,7 @@ func (p *OAuthProxy) SignInPage(rw http.ResponseWriter, req *http.Request, code
ProviderName: p.provider.Data().ProviderName,
SignInMessage: p.SignInMessage,
CustomLogin: p.displayCustomLoginForm(),
Redirect: redirect_url,
Redirect: redirecURL,
Version: VERSION,
ProxyPrefix: p.ProxyPrefix,
Footer: template.HTML(p.Footer),
@ -402,6 +439,7 @@ func (p *OAuthProxy) SignInPage(rw http.ResponseWriter, req *http.Request, code
p.templates.ExecuteTemplate(rw, "sign_in.html", t)
}
// ManualSignIn handles basic auth logins to the proxy
func (p *OAuthProxy) ManualSignIn(rw http.ResponseWriter, req *http.Request) (string, bool) {
if req.Method != "POST" || p.HtpasswdFile == nil {
return "", false
@ -419,6 +457,8 @@ func (p *OAuthProxy) ManualSignIn(rw http.ResponseWriter, req *http.Request) (st
return "", false
}
// GetRedirect reads the query parameter to get the URL to redirect clients to
// once authenticated with the OAuthProxy
func (p *OAuthProxy) GetRedirect(req *http.Request) (redirect string, err error) {
err = req.ParseForm()
if err != nil {
@ -433,11 +473,13 @@ func (p *OAuthProxy) GetRedirect(req *http.Request) (redirect string, err error)
return
}
// IsWhitelistedRequest is used to check if auth should be skipped for this request
func (p *OAuthProxy) IsWhitelistedRequest(req *http.Request) (ok bool) {
isPreflightRequestAllowed := p.skipAuthPreflight && req.Method == "OPTIONS"
return isPreflightRequestAllowed || p.IsWhitelistedPath(req.URL.Path)
}
// IsWhitelistedPath is used to check if the request path is allowed without auth
func (p *OAuthProxy) IsWhitelistedPath(path string) (ok bool) {
for _, u := range p.compiledRegex {
ok = u.MatchString(path)
@ -479,6 +521,7 @@ func (p *OAuthProxy) ServeHTTP(rw http.ResponseWriter, req *http.Request) {
}
}
// SignIn serves a page prompting users to sign in
func (p *OAuthProxy) SignIn(rw http.ResponseWriter, req *http.Request) {
redirect, err := p.GetRedirect(req)
if err != nil {
@ -500,11 +543,13 @@ func (p *OAuthProxy) SignIn(rw http.ResponseWriter, req *http.Request) {
}
}
// SignOut sends a response to clear the authentication cookie
func (p *OAuthProxy) SignOut(rw http.ResponseWriter, req *http.Request) {
p.ClearSessionCookie(rw, req)
http.Redirect(rw, req, "/", 302)
}
// OAuthStart starts the OAuth2 authentication flow
func (p *OAuthProxy) OAuthStart(rw http.ResponseWriter, req *http.Request) {
nonce, err := cookie.Nonce()
if err != nil {
@ -521,6 +566,8 @@ func (p *OAuthProxy) OAuthStart(rw http.ResponseWriter, req *http.Request) {
http.Redirect(rw, req, p.provider.GetLoginURL(redirectURI, fmt.Sprintf("%v:%v", nonce, redirect)), 302)
}
// OAuthCallback is the OAuth2 authentication flow callback that finishes the
// OAuth2 authentication flow
func (p *OAuthProxy) OAuthCallback(rw http.ResponseWriter, req *http.Request) {
remoteAddr := getRemoteAddr(req)
@ -582,6 +629,7 @@ func (p *OAuthProxy) OAuthCallback(rw http.ResponseWriter, req *http.Request) {
}
}
// AuthenticateOnly checks whether the user is currently logged in
func (p *OAuthProxy) AuthenticateOnly(rw http.ResponseWriter, req *http.Request) {
status := p.Authenticate(rw, req)
if status == http.StatusAccepted {
@ -591,6 +639,8 @@ func (p *OAuthProxy) AuthenticateOnly(rw http.ResponseWriter, req *http.Request)
}
}
// Proxy proxies the user request if the user is authenticated else it prompts
// them to authenticate
func (p *OAuthProxy) Proxy(rw http.ResponseWriter, req *http.Request) {
status := p.Authenticate(rw, req)
if status == http.StatusInternalServerError {
@ -607,6 +657,7 @@ func (p *OAuthProxy) Proxy(rw http.ResponseWriter, req *http.Request) {
}
}
// Authenticate checks whether a user is authenticated
func (p *OAuthProxy) Authenticate(rw http.ResponseWriter, req *http.Request) int {
var saveSession, clearSession, revalidated bool
remoteAddr := getRemoteAddr(req)
@ -620,7 +671,8 @@ func (p *OAuthProxy) Authenticate(rw http.ResponseWriter, req *http.Request) int
saveSession = true
}
if ok, err := p.provider.RefreshSessionIfNeeded(session); err != nil {
var ok bool
if ok, err = p.provider.RefreshSessionIfNeeded(session); err != nil {
log.Printf("%s removing session. error refreshing access token %s %s", remoteAddr, err, session)
clearSession = true
session = nil
@ -653,7 +705,7 @@ func (p *OAuthProxy) Authenticate(rw http.ResponseWriter, req *http.Request) int
}
if saveSession && session != nil {
err := p.SaveSession(rw, req, session)
err = p.SaveSession(rw, req, session)
if err != nil {
log.Printf("%s %s", remoteAddr, err)
return http.StatusInternalServerError
@ -706,6 +758,8 @@ func (p *OAuthProxy) Authenticate(rw http.ResponseWriter, req *http.Request) int
return http.StatusAccepted
}
// CheckBasicAuth checks the requests Authorization header for basic auth
// credentials and authenticates these against the proxies HtpasswdFile
func (p *OAuthProxy) CheckBasicAuth(req *http.Request) (*providers.SessionState, error) {
if p.HtpasswdFile == nil {
return nil, nil

View File

@ -15,8 +15,8 @@ import (
"testing"
"time"
"github.com/bitly/oauth2_proxy/providers"
"github.com/mbland/hmacauth"
"github.com/pusher/oauth2_proxy/providers"
"github.com/stretchr/testify/assert"
)
@ -98,28 +98,28 @@ type TestProvider struct {
ValidToken bool
}
func NewTestProvider(provider_url *url.URL, email_address string) *TestProvider {
func NewTestProvider(providerURL *url.URL, emailAddress string) *TestProvider {
return &TestProvider{
ProviderData: &providers.ProviderData{
ProviderName: "Test Provider",
LoginURL: &url.URL{
Scheme: "http",
Host: provider_url.Host,
Host: providerURL.Host,
Path: "/oauth/authorize",
},
RedeemURL: &url.URL{
Scheme: "http",
Host: provider_url.Host,
Host: providerURL.Host,
Path: "/oauth/token",
},
ProfileURL: &url.URL{
Scheme: "http",
Host: provider_url.Host,
Host: providerURL.Host,
Path: "/api/v1/profile",
},
Scope: "profile.email",
},
EmailAddress: email_address,
EmailAddress: emailAddress,
}
}
@ -132,11 +132,10 @@ func (tp *TestProvider) ValidateSessionState(session *providers.SessionState) bo
}
func TestBasicAuthPassword(t *testing.T) {
provider_server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
providerServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
log.Printf("%#v", r)
url := r.URL
payload := ""
switch url.Path {
var payload string
switch r.URL.Path {
case "/oauth/token":
payload = `{"access_token": "my_auth_token"}`
default:
@ -149,7 +148,7 @@ func TestBasicAuthPassword(t *testing.T) {
w.Write([]byte(payload))
}))
opts := NewOptions()
opts.Upstreams = append(opts.Upstreams, provider_server.URL)
opts.Upstreams = append(opts.Upstreams, providerServer.URL)
// The CookieSecret must be 32 bytes in order to create the AES
// cipher.
opts.CookieSecret = "xyzzyplughxyzzyplughxyzzyplughxp"
@ -161,13 +160,13 @@ func TestBasicAuthPassword(t *testing.T) {
opts.BasicAuthPassword = "This is a secure password"
opts.Validate()
provider_url, _ := url.Parse(provider_server.URL)
const email_address = "michael.bland@gsa.gov"
const user_name = "michael.bland"
providerURL, _ := url.Parse(providerServer.URL)
const emailAddress = "michael.bland@gsa.gov"
const username = "michael.bland"
opts.provider = NewTestProvider(provider_url, email_address)
opts.provider = NewTestProvider(providerURL, emailAddress)
proxy := NewOAuthProxy(opts, func(email string) bool {
return email == email_address
return email == emailAddress
})
rw := httptest.NewRecorder()
@ -182,10 +181,10 @@ func TestBasicAuthPassword(t *testing.T) {
cookieName := proxy.CookieName
var value string
key_prefix := cookieName + "="
keyPrefix := cookieName + "="
for _, field := range strings.Split(cookie, "; ") {
value = strings.TrimPrefix(field, key_prefix)
value = strings.TrimPrefix(field, keyPrefix)
if value != field {
break
} else {
@ -206,15 +205,15 @@ func TestBasicAuthPassword(t *testing.T) {
rw = httptest.NewRecorder()
proxy.ServeHTTP(rw, req)
expectedHeader := "Basic " + base64.StdEncoding.EncodeToString([]byte(user_name+":"+opts.BasicAuthPassword))
expectedHeader := "Basic " + base64.StdEncoding.EncodeToString([]byte(username+":"+opts.BasicAuthPassword))
assert.Equal(t, expectedHeader, rw.Body.String())
provider_server.Close()
providerServer.Close()
}
type PassAccessTokenTest struct {
provider_server *httptest.Server
proxy *OAuthProxy
opts *Options
providerServer *httptest.Server
proxy *OAuthProxy
opts *Options
}
type PassAccessTokenTestOptions struct {
@ -224,12 +223,11 @@ type PassAccessTokenTestOptions struct {
func NewPassAccessTokenTest(opts PassAccessTokenTestOptions) *PassAccessTokenTest {
t := &PassAccessTokenTest{}
t.provider_server = httptest.NewServer(
t.providerServer = httptest.NewServer(
http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
log.Printf("%#v", r)
url := r.URL
payload := ""
switch url.Path {
var payload string
switch r.URL.Path {
case "/oauth/token":
payload = `{"access_token": "my_auth_token"}`
default:
@ -243,7 +241,7 @@ func NewPassAccessTokenTest(opts PassAccessTokenTestOptions) *PassAccessTokenTes
}))
t.opts = NewOptions()
t.opts.Upstreams = append(t.opts.Upstreams, t.provider_server.URL)
t.opts.Upstreams = append(t.opts.Upstreams, t.providerServer.URL)
// The CookieSecret must be 32 bytes in order to create the AES
// cipher.
t.opts.CookieSecret = "xyzzyplughxyzzyplughxyzzyplughxp"
@ -253,21 +251,21 @@ func NewPassAccessTokenTest(opts PassAccessTokenTestOptions) *PassAccessTokenTes
t.opts.PassAccessToken = opts.PassAccessToken
t.opts.Validate()
provider_url, _ := url.Parse(t.provider_server.URL)
const email_address = "michael.bland@gsa.gov"
providerURL, _ := url.Parse(t.providerServer.URL)
const emailAddress = "michael.bland@gsa.gov"
t.opts.provider = NewTestProvider(provider_url, email_address)
t.opts.provider = NewTestProvider(providerURL, emailAddress)
t.proxy = NewOAuthProxy(t.opts, func(email string) bool {
return email == email_address
return email == emailAddress
})
return t
}
func (pat_test *PassAccessTokenTest) Close() {
pat_test.provider_server.Close()
func (patTest *PassAccessTokenTest) Close() {
patTest.providerServer.Close()
}
func (pat_test *PassAccessTokenTest) getCallbackEndpoint() (http_code int,
func (patTest *PassAccessTokenTest) getCallbackEndpoint() (httpCode int,
cookie string) {
rw := httptest.NewRecorder()
req, err := http.NewRequest("GET", "/oauth2/callback?code=callback_code&state=nonce:",
@ -275,18 +273,18 @@ func (pat_test *PassAccessTokenTest) getCallbackEndpoint() (http_code int,
if err != nil {
return 0, ""
}
req.AddCookie(pat_test.proxy.MakeCSRFCookie(req, "nonce", time.Hour, time.Now()))
pat_test.proxy.ServeHTTP(rw, req)
req.AddCookie(patTest.proxy.MakeCSRFCookie(req, "nonce", time.Hour, time.Now()))
patTest.proxy.ServeHTTP(rw, req)
return rw.Code, rw.HeaderMap["Set-Cookie"][1]
}
func (pat_test *PassAccessTokenTest) getRootEndpoint(cookie string) (http_code int, access_token string) {
cookieName := pat_test.proxy.CookieName
func (patTest *PassAccessTokenTest) getRootEndpoint(cookie string) (httpCode int, accessToken string) {
cookieName := patTest.proxy.CookieName
var value string
key_prefix := cookieName + "="
keyPrefix := cookieName + "="
for _, field := range strings.Split(cookie, "; ") {
value = strings.TrimPrefix(field, key_prefix)
value = strings.TrimPrefix(field, keyPrefix)
if value != field {
break
} else {
@ -310,18 +308,18 @@ func (pat_test *PassAccessTokenTest) getRootEndpoint(cookie string) (http_code i
})
rw := httptest.NewRecorder()
pat_test.proxy.ServeHTTP(rw, req)
patTest.proxy.ServeHTTP(rw, req)
return rw.Code, rw.Body.String()
}
func TestForwardAccessTokenUpstream(t *testing.T) {
pat_test := NewPassAccessTokenTest(PassAccessTokenTestOptions{
patTest := NewPassAccessTokenTest(PassAccessTokenTestOptions{
PassAccessToken: true,
})
defer pat_test.Close()
defer patTest.Close()
// A successful validation will redirect and set the auth cookie.
code, cookie := pat_test.getCallbackEndpoint()
code, cookie := patTest.getCallbackEndpoint()
if code != 302 {
t.Fatalf("expected 302; got %d", code)
}
@ -330,7 +328,7 @@ func TestForwardAccessTokenUpstream(t *testing.T) {
// Now we make a regular request; the access_token from the cookie is
// forwarded as the "X-Forwarded-Access-Token" header. The token is
// read by the test provider server and written in the response body.
code, payload := pat_test.getRootEndpoint(cookie)
code, payload := patTest.getRootEndpoint(cookie)
if code != 200 {
t.Fatalf("expected 200; got %d", code)
}
@ -338,13 +336,13 @@ func TestForwardAccessTokenUpstream(t *testing.T) {
}
func TestDoNotForwardAccessTokenUpstream(t *testing.T) {
pat_test := NewPassAccessTokenTest(PassAccessTokenTestOptions{
patTest := NewPassAccessTokenTest(PassAccessTokenTestOptions{
PassAccessToken: false,
})
defer pat_test.Close()
defer patTest.Close()
// A successful validation will redirect and set the auth cookie.
code, cookie := pat_test.getCallbackEndpoint()
code, cookie := patTest.getCallbackEndpoint()
if code != 302 {
t.Fatalf("expected 302; got %d", code)
}
@ -352,7 +350,7 @@ func TestDoNotForwardAccessTokenUpstream(t *testing.T) {
// Now we make a regular request, but the access token header should
// not be present.
code, payload := pat_test.getRootEndpoint(cookie)
code, payload := patTest.getRootEndpoint(cookie)
if code != 200 {
t.Fatalf("expected 200; got %d", code)
}
@ -360,49 +358,49 @@ func TestDoNotForwardAccessTokenUpstream(t *testing.T) {
}
type SignInPageTest struct {
opts *Options
proxy *OAuthProxy
sign_in_regexp *regexp.Regexp
sign_in_provider_regexp *regexp.Regexp
opts *Options
proxy *OAuthProxy
signInRegexp *regexp.Regexp
signInProviderRegexp *regexp.Regexp
}
const signInRedirectPattern = `<input type="hidden" name="rd" value="(.*)">`
const signInSkipProvider = `>Found<`
func NewSignInPageTest(skipProvider bool) *SignInPageTest {
var sip_test SignInPageTest
var sipTest SignInPageTest
sip_test.opts = NewOptions()
sip_test.opts.CookieSecret = "foobar"
sip_test.opts.ClientID = "bazquux"
sip_test.opts.ClientSecret = "xyzzyplugh"
sip_test.opts.SkipProviderButton = skipProvider
sip_test.opts.Validate()
sipTest.opts = NewOptions()
sipTest.opts.CookieSecret = "foobar"
sipTest.opts.ClientID = "bazquux"
sipTest.opts.ClientSecret = "xyzzyplugh"
sipTest.opts.SkipProviderButton = skipProvider
sipTest.opts.Validate()
sip_test.proxy = NewOAuthProxy(sip_test.opts, func(email string) bool {
sipTest.proxy = NewOAuthProxy(sipTest.opts, func(email string) bool {
return true
})
sip_test.sign_in_regexp = regexp.MustCompile(signInRedirectPattern)
sip_test.sign_in_provider_regexp = regexp.MustCompile(signInSkipProvider)
sipTest.signInRegexp = regexp.MustCompile(signInRedirectPattern)
sipTest.signInProviderRegexp = regexp.MustCompile(signInSkipProvider)
return &sip_test
return &sipTest
}
func (sip_test *SignInPageTest) GetEndpoint(endpoint string) (int, string) {
func (sipTest *SignInPageTest) GetEndpoint(endpoint string) (int, string) {
rw := httptest.NewRecorder()
req, _ := http.NewRequest("GET", endpoint, strings.NewReader(""))
sip_test.proxy.ServeHTTP(rw, req)
sipTest.proxy.ServeHTTP(rw, req)
return rw.Code, rw.Body.String()
}
func TestSignInPageIncludesTargetRedirect(t *testing.T) {
sip_test := NewSignInPageTest(false)
sipTest := NewSignInPageTest(false)
const endpoint = "/some/random/endpoint"
code, body := sip_test.GetEndpoint(endpoint)
code, body := sipTest.GetEndpoint(endpoint)
assert.Equal(t, 403, code)
match := sip_test.sign_in_regexp.FindStringSubmatch(body)
match := sipTest.signInRegexp.FindStringSubmatch(body)
if match == nil {
t.Fatal("Did not find pattern in body: " +
signInRedirectPattern + "\nBody:\n" + body)
@ -414,11 +412,11 @@ func TestSignInPageIncludesTargetRedirect(t *testing.T) {
}
func TestSignInPageDirectAccessRedirectsToRoot(t *testing.T) {
sip_test := NewSignInPageTest(false)
code, body := sip_test.GetEndpoint("/oauth2/sign_in")
sipTest := NewSignInPageTest(false)
code, body := sipTest.GetEndpoint("/oauth2/sign_in")
assert.Equal(t, 200, code)
match := sip_test.sign_in_regexp.FindStringSubmatch(body)
match := sipTest.signInRegexp.FindStringSubmatch(body)
if match == nil {
t.Fatal("Did not find pattern in body: " +
signInRedirectPattern + "\nBody:\n" + body)
@ -429,13 +427,13 @@ func TestSignInPageDirectAccessRedirectsToRoot(t *testing.T) {
}
func TestSignInPageSkipProvider(t *testing.T) {
sip_test := NewSignInPageTest(true)
sipTest := NewSignInPageTest(true)
const endpoint = "/some/random/endpoint"
code, body := sip_test.GetEndpoint(endpoint)
code, body := sipTest.GetEndpoint(endpoint)
assert.Equal(t, 302, code)
match := sip_test.sign_in_provider_regexp.FindStringSubmatch(body)
match := sipTest.signInProviderRegexp.FindStringSubmatch(body)
if match == nil {
t.Fatal("Did not find pattern in body: " +
signInSkipProvider + "\nBody:\n" + body)
@ -443,13 +441,13 @@ func TestSignInPageSkipProvider(t *testing.T) {
}
func TestSignInPageSkipProviderDirect(t *testing.T) {
sip_test := NewSignInPageTest(true)
sipTest := NewSignInPageTest(true)
const endpoint = "/sign_in"
code, body := sip_test.GetEndpoint(endpoint)
code, body := sipTest.GetEndpoint(endpoint)
assert.Equal(t, 302, code)
match := sip_test.sign_in_provider_regexp.FindStringSubmatch(body)
match := sipTest.signInProviderRegexp.FindStringSubmatch(body)
if match == nil {
t.Fatal("Did not find pattern in body: " +
signInSkipProvider + "\nBody:\n" + body)
@ -457,50 +455,50 @@ func TestSignInPageSkipProviderDirect(t *testing.T) {
}
type ProcessCookieTest struct {
opts *Options
proxy *OAuthProxy
rw *httptest.ResponseRecorder
req *http.Request
provider TestProvider
response_code int
validate_user bool
opts *Options
proxy *OAuthProxy
rw *httptest.ResponseRecorder
req *http.Request
provider TestProvider
responseCode int
validateUser bool
}
type ProcessCookieTestOpts struct {
provider_validate_cookie_response bool
providerValidateCookieResponse bool
}
func NewProcessCookieTest(opts ProcessCookieTestOpts) *ProcessCookieTest {
var pc_test ProcessCookieTest
var pcTest ProcessCookieTest
pc_test.opts = NewOptions()
pc_test.opts.ClientID = "bazquux"
pc_test.opts.ClientSecret = "xyzzyplugh"
pc_test.opts.CookieSecret = "0123456789abcdefabcd"
pcTest.opts = NewOptions()
pcTest.opts.ClientID = "bazquux"
pcTest.opts.ClientSecret = "xyzzyplugh"
pcTest.opts.CookieSecret = "0123456789abcdefabcd"
// First, set the CookieRefresh option so proxy.AesCipher is created,
// needed to encrypt the access_token.
pc_test.opts.CookieRefresh = time.Hour
pc_test.opts.Validate()
pcTest.opts.CookieRefresh = time.Hour
pcTest.opts.Validate()
pc_test.proxy = NewOAuthProxy(pc_test.opts, func(email string) bool {
return pc_test.validate_user
pcTest.proxy = NewOAuthProxy(pcTest.opts, func(email string) bool {
return pcTest.validateUser
})
pc_test.proxy.provider = &TestProvider{
ValidToken: opts.provider_validate_cookie_response,
pcTest.proxy.provider = &TestProvider{
ValidToken: opts.providerValidateCookieResponse,
}
// Now, zero-out proxy.CookieRefresh for the cases that don't involve
// access_token validation.
pc_test.proxy.CookieRefresh = time.Duration(0)
pc_test.rw = httptest.NewRecorder()
pc_test.req, _ = http.NewRequest("GET", "/", strings.NewReader(""))
pc_test.validate_user = true
return &pc_test
pcTest.proxy.CookieRefresh = time.Duration(0)
pcTest.rw = httptest.NewRecorder()
pcTest.req, _ = http.NewRequest("GET", "/", strings.NewReader(""))
pcTest.validateUser = true
return &pcTest
}
func NewProcessCookieTestWithDefaults() *ProcessCookieTest {
return NewProcessCookieTest(ProcessCookieTestOpts{
provider_validate_cookie_response: true,
providerValidateCookieResponse: true,
})
}
@ -522,12 +520,12 @@ func (p *ProcessCookieTest) LoadCookiedSession() (*providers.SessionState, time.
}
func TestLoadCookiedSession(t *testing.T) {
pc_test := NewProcessCookieTestWithDefaults()
pcTest := NewProcessCookieTestWithDefaults()
startSession := &providers.SessionState{Email: "michael.bland@gsa.gov", AccessToken: "my_access_token"}
pc_test.SaveSession(startSession, time.Now())
pcTest.SaveSession(startSession, time.Now())
session, _, err := pc_test.LoadCookiedSession()
session, _, err := pcTest.LoadCookiedSession()
assert.Equal(t, nil, err)
assert.Equal(t, startSession.Email, session.Email)
assert.Equal(t, "michael.bland", session.User)
@ -535,9 +533,9 @@ func TestLoadCookiedSession(t *testing.T) {
}
func TestProcessCookieNoCookieError(t *testing.T) {
pc_test := NewProcessCookieTestWithDefaults()
pcTest := NewProcessCookieTestWithDefaults()
session, _, err := pc_test.LoadCookiedSession()
session, _, err := pcTest.LoadCookiedSession()
assert.Equal(t, "Cookie \"_oauth2_proxy\" not present", err.Error())
if session != nil {
t.Errorf("expected nil session. got %#v", session)
@ -545,14 +543,14 @@ func TestProcessCookieNoCookieError(t *testing.T) {
}
func TestProcessCookieRefreshNotSet(t *testing.T) {
pc_test := NewProcessCookieTestWithDefaults()
pc_test.proxy.CookieExpire = time.Duration(23) * time.Hour
pcTest := NewProcessCookieTestWithDefaults()
pcTest.proxy.CookieExpire = time.Duration(23) * time.Hour
reference := time.Now().Add(time.Duration(-2) * time.Hour)
startSession := &providers.SessionState{Email: "michael.bland@gsa.gov", AccessToken: "my_access_token"}
pc_test.SaveSession(startSession, reference)
pcTest.SaveSession(startSession, reference)
session, age, err := pc_test.LoadCookiedSession()
session, age, err := pcTest.LoadCookiedSession()
assert.Equal(t, nil, err)
if age < time.Duration(-2)*time.Hour {
t.Errorf("cookie too young %v", age)
@ -561,13 +559,13 @@ func TestProcessCookieRefreshNotSet(t *testing.T) {
}
func TestProcessCookieFailIfCookieExpired(t *testing.T) {
pc_test := NewProcessCookieTestWithDefaults()
pc_test.proxy.CookieExpire = time.Duration(24) * time.Hour
pcTest := NewProcessCookieTestWithDefaults()
pcTest.proxy.CookieExpire = time.Duration(24) * time.Hour
reference := time.Now().Add(time.Duration(25) * time.Hour * -1)
startSession := &providers.SessionState{Email: "michael.bland@gsa.gov", AccessToken: "my_access_token"}
pc_test.SaveSession(startSession, reference)
pcTest.SaveSession(startSession, reference)
session, _, err := pc_test.LoadCookiedSession()
session, _, err := pcTest.LoadCookiedSession()
assert.NotEqual(t, nil, err)
if session != nil {
t.Errorf("expected nil session %#v", session)
@ -575,14 +573,14 @@ func TestProcessCookieFailIfCookieExpired(t *testing.T) {
}
func TestProcessCookieFailIfRefreshSetAndCookieExpired(t *testing.T) {
pc_test := NewProcessCookieTestWithDefaults()
pc_test.proxy.CookieExpire = time.Duration(24) * time.Hour
pcTest := NewProcessCookieTestWithDefaults()
pcTest.proxy.CookieExpire = time.Duration(24) * time.Hour
reference := time.Now().Add(time.Duration(25) * time.Hour * -1)
startSession := &providers.SessionState{Email: "michael.bland@gsa.gov", AccessToken: "my_access_token"}
pc_test.SaveSession(startSession, reference)
pcTest.SaveSession(startSession, reference)
pc_test.proxy.CookieRefresh = time.Hour
session, _, err := pc_test.LoadCookiedSession()
pcTest.proxy.CookieRefresh = time.Hour
session, _, err := pcTest.LoadCookiedSession()
assert.NotEqual(t, nil, err)
if session != nil {
t.Errorf("expected nil session %#v", session)
@ -590,10 +588,10 @@ func TestProcessCookieFailIfRefreshSetAndCookieExpired(t *testing.T) {
}
func NewAuthOnlyEndpointTest() *ProcessCookieTest {
pc_test := NewProcessCookieTestWithDefaults()
pc_test.req, _ = http.NewRequest("GET",
pc_test.opts.ProxyPrefix+"/auth", nil)
return pc_test
pcTest := NewProcessCookieTestWithDefaults()
pcTest.req, _ = http.NewRequest("GET",
pcTest.opts.ProxyPrefix+"/auth", nil)
return pcTest
}
func TestAuthOnlyEndpointAccepted(t *testing.T) {
@ -636,7 +634,7 @@ func TestAuthOnlyEndpointUnauthorizedOnEmailValidationFailure(t *testing.T) {
startSession := &providers.SessionState{
Email: "michael.bland@gsa.gov", AccessToken: "my_access_token"}
test.SaveSession(startSession, time.Now())
test.validate_user = false
test.validateUser = false
test.proxy.ServeHTTP(test.rw, test.req)
assert.Equal(t, http.StatusUnauthorized, test.rw.Code)
@ -645,33 +643,33 @@ func TestAuthOnlyEndpointUnauthorizedOnEmailValidationFailure(t *testing.T) {
}
func TestAuthOnlyEndpointSetXAuthRequestHeaders(t *testing.T) {
var pc_test ProcessCookieTest
var pcTest ProcessCookieTest
pc_test.opts = NewOptions()
pc_test.opts.SetXAuthRequest = true
pc_test.opts.Validate()
pcTest.opts = NewOptions()
pcTest.opts.SetXAuthRequest = true
pcTest.opts.Validate()
pc_test.proxy = NewOAuthProxy(pc_test.opts, func(email string) bool {
return pc_test.validate_user
pcTest.proxy = NewOAuthProxy(pcTest.opts, func(email string) bool {
return pcTest.validateUser
})
pc_test.proxy.provider = &TestProvider{
pcTest.proxy.provider = &TestProvider{
ValidToken: true,
}
pc_test.validate_user = true
pcTest.validateUser = true
pc_test.rw = httptest.NewRecorder()
pc_test.req, _ = http.NewRequest("GET",
pc_test.opts.ProxyPrefix+"/auth", nil)
pcTest.rw = httptest.NewRecorder()
pcTest.req, _ = http.NewRequest("GET",
pcTest.opts.ProxyPrefix+"/auth", nil)
startSession := &providers.SessionState{
User: "oauth_user", Email: "oauth_user@example.com", AccessToken: "oauth_token"}
pc_test.SaveSession(startSession, time.Now())
pcTest.SaveSession(startSession, time.Now())
pc_test.proxy.ServeHTTP(pc_test.rw, pc_test.req)
assert.Equal(t, http.StatusAccepted, pc_test.rw.Code)
assert.Equal(t, "oauth_user", pc_test.rw.HeaderMap["X-Auth-Request-User"][0])
assert.Equal(t, "oauth_user@example.com", pc_test.rw.HeaderMap["X-Auth-Request-Email"][0])
pcTest.proxy.ServeHTTP(pcTest.rw, pcTest.req)
assert.Equal(t, http.StatusAccepted, pcTest.rw.Code)
assert.Equal(t, "oauth_user", pcTest.rw.HeaderMap["X-Auth-Request-User"][0])
assert.Equal(t, "oauth_user@example.com", pcTest.rw.HeaderMap["X-Auth-Request-Email"][0])
}
func TestAuthSkippedForPreflightRequests(t *testing.T) {
@ -689,8 +687,8 @@ func TestAuthSkippedForPreflightRequests(t *testing.T) {
opts.SkipAuthPreflight = true
opts.Validate()
upstream_url, _ := url.Parse(upstream.URL)
opts.provider = NewTestProvider(upstream_url, "")
upstreamURL, _ := url.Parse(upstream.URL)
opts.provider = NewTestProvider(upstreamURL, "")
proxy := NewOAuthProxy(opts, func(string) bool { return false })
rw := httptest.NewRecorder()
@ -723,7 +721,7 @@ func (v *SignatureAuthenticator) Authenticate(w http.ResponseWriter, r *http.Req
type SignatureTest struct {
opts *Options
upstream *httptest.Server
upstream_host string
upstreamHost string
provider *httptest.Server
header http.Header
rw *httptest.ResponseRecorder
@ -740,20 +738,20 @@ func NewSignatureTest() *SignatureTest {
authenticator := &SignatureAuthenticator{}
upstream := httptest.NewServer(
http.HandlerFunc(authenticator.Authenticate))
upstream_url, _ := url.Parse(upstream.URL)
upstreamURL, _ := url.Parse(upstream.URL)
opts.Upstreams = append(opts.Upstreams, upstream.URL)
providerHandler := func(w http.ResponseWriter, r *http.Request) {
w.Write([]byte(`{"access_token": "my_auth_token"}`))
}
provider := httptest.NewServer(http.HandlerFunc(providerHandler))
provider_url, _ := url.Parse(provider.URL)
opts.provider = NewTestProvider(provider_url, "mbland@acm.org")
providerURL, _ := url.Parse(provider.URL)
opts.provider = NewTestProvider(providerURL, "mbland@acm.org")
return &SignatureTest{
opts,
upstream,
upstream_url.Host,
upstreamURL.Host,
provider,
make(http.Header),
httptest.NewRecorder(),

View File

@ -13,16 +13,17 @@ import (
"strings"
"time"
"github.com/bitly/oauth2_proxy/providers"
oidc "github.com/coreos/go-oidc"
"github.com/mbland/hmacauth"
"github.com/pusher/oauth2_proxy/providers"
)
// Configuration Options that can be set by Command Line Flag, or Config File
// Options holds Configuration Options that can be set by Command Line Flag,
// or Config File
type Options struct {
ProxyPrefix string `flag:"proxy-prefix" cfg:"proxy-prefix"`
HttpAddress string `flag:"http-address" cfg:"http_address"`
HttpsAddress string `flag:"https-address" cfg:"https_address"`
HTTPAddress string `flag:"http-address" cfg:"http_address"`
HTTPSAddress string `flag:"https-address" cfg:"https_address"`
RedirectURL string `flag:"redirect-url" cfg:"redirect_url"`
ClientID string `flag:"client-id" cfg:"client_id" env:"OAUTH2_PROXY_CLIENT_ID"`
ClientSecret string `flag:"client-secret" cfg:"client_secret" env:"OAUTH2_PROXY_CLIENT_SECRET"`
@ -48,7 +49,7 @@ type Options struct {
CookieExpire time.Duration `flag:"cookie-expire" cfg:"cookie_expire" env:"OAUTH2_PROXY_COOKIE_EXPIRE"`
CookieRefresh time.Duration `flag:"cookie-refresh" cfg:"cookie_refresh" env:"OAUTH2_PROXY_COOKIE_REFRESH"`
CookieSecure bool `flag:"cookie-secure" cfg:"cookie_secure"`
CookieHttpOnly bool `flag:"cookie-httponly" cfg:"cookie_httponly"`
CookieHTTPOnly bool `flag:"cookie-httponly" cfg:"cookie_httponly"`
Upstreams []string `flag:"upstream" cfg:"upstreams"`
SkipAuthRegex []string `flag:"skip-auth-regex" cfg:"skip_auth_regex"`
@ -88,20 +89,22 @@ type Options struct {
oidcVerifier *oidc.IDTokenVerifier
}
// SignatureData holds hmacauth signature hash and key
type SignatureData struct {
hash crypto.Hash
key string
}
// NewOptions constructs a new Options with defaulted values
func NewOptions() *Options {
return &Options{
ProxyPrefix: "/oauth2",
HttpAddress: "127.0.0.1:4180",
HttpsAddress: ":443",
HTTPAddress: "127.0.0.1:4180",
HTTPSAddress: ":443",
DisplayHtpasswdForm: true,
CookieName: "_oauth2_proxy",
CookieSecure: true,
CookieHttpOnly: true,
CookieHTTPOnly: true,
CookieExpire: time.Duration(168) * time.Hour,
CookieRefresh: time.Duration(0),
SetXAuthRequest: false,
@ -116,15 +119,17 @@ func NewOptions() *Options {
}
}
func parseURL(to_parse string, urltype string, msgs []string) (*url.URL, []string) {
parsed, err := url.Parse(to_parse)
func parseURL(toParse string, urltype string, msgs []string) (*url.URL, []string) {
parsed, err := url.Parse(toParse)
if err != nil {
return nil, append(msgs, fmt.Sprintf(
"error parsing %s-url=%q %s", urltype, to_parse, err))
"error parsing %s-url=%q %s", urltype, toParse, err))
}
return parsed, msgs
}
// Validate checks that required options are set and validates those that they
// are of the correct format
func (o *Options) Validate() error {
if o.SSLInsecureSkipVerify {
// TODO: Accept a certificate bundle.
@ -190,17 +195,17 @@ func (o *Options) Validate() error {
msgs = parseProviderInfo(o, msgs)
if o.PassAccessToken || (o.CookieRefresh != time.Duration(0)) {
valid_cookie_secret_size := false
validCookieSecretSize := false
for _, i := range []int{16, 24, 32} {
if len(secretBytes(o.CookieSecret)) == i {
valid_cookie_secret_size = true
validCookieSecretSize = true
}
}
var decoded bool
if string(secretBytes(o.CookieSecret)) != o.CookieSecret {
decoded = true
}
if valid_cookie_secret_size == false {
if validCookieSecretSize == false {
var suffix string
if decoded {
suffix = fmt.Sprintf(" note: cookie secret was base64 decoded from %q", o.CookieSecret)
@ -294,12 +299,13 @@ func parseSignatureKey(o *Options, msgs []string) []string {
}
algorithm, secretKey := components[0], components[1]
if hash, err := hmacauth.DigestNameToCryptoHash(algorithm); err != nil {
var hash crypto.Hash
var err error
if hash, err = hmacauth.DigestNameToCryptoHash(algorithm); err != nil {
return append(msgs, "unsupported signature hash algorithm: "+
o.SignatureKey)
} else {
o.signatureData = &SignatureData{hash, secretKey}
}
o.signatureData = &SignatureData{hash, secretKey}
return msgs
}

View File

@ -88,9 +88,9 @@ func TestProxyURLs(t *testing.T) {
o.Upstreams = append(o.Upstreams, "http://127.0.0.1:8081")
assert.Equal(t, nil, o.Validate())
expected := []*url.URL{
&url.URL{Scheme: "http", Host: "127.0.0.1:8080", Path: "/"},
{Scheme: "http", Host: "127.0.0.1:8080", Path: "/"},
// note the '/' was added
&url.URL{Scheme: "http", Host: "127.0.0.1:8081", Path: "/"},
{Scheme: "http", Host: "127.0.0.1:8081", Path: "/"},
}
assert.Equal(t, expected, o.proxyURLs)
}

View File

@ -3,18 +3,21 @@ package providers
import (
"errors"
"fmt"
"github.com/bitly/go-simplejson"
"github.com/bitly/oauth2_proxy/api"
"log"
"net/http"
"net/url"
"github.com/bitly/go-simplejson"
"github.com/pusher/oauth2_proxy/api"
)
// AzureProvider represents an Azure based Identity Provider
type AzureProvider struct {
*ProviderData
Tenant string
}
// NewAzureProvider initiates a new AzureProvider
func NewAzureProvider(p *ProviderData) *AzureProvider {
p.ProviderName = "Azure"
@ -39,6 +42,7 @@ func NewAzureProvider(p *ProviderData) *AzureProvider {
return &AzureProvider{ProviderData: p}
}
// Configure defaults the AzureProvider configuration options
func (p *AzureProvider) Configure(tenant string) {
p.Tenant = tenant
if tenant == "" {
@ -60,9 +64,9 @@ func (p *AzureProvider) Configure(tenant string) {
}
}
func getAzureHeader(access_token string) http.Header {
func getAzureHeader(accessToken string) http.Header {
header := make(http.Header)
header.Set("Authorization", fmt.Sprintf("Bearer %s", access_token))
header.Set("Authorization", fmt.Sprintf("Bearer %s", accessToken))
return header
}
@ -83,6 +87,7 @@ func getEmailFromJSON(json *simplejson.Json) (string, error) {
return email, err
}
// GetEmailAddress returns the Account email address
func (p *AzureProvider) GetEmailAddress(s *SessionState) (string, error) {
var email string
var err error

View File

@ -110,8 +110,7 @@ func testAzureBackend(payload string) *httptest.Server {
return httptest.NewServer(http.HandlerFunc(
func(w http.ResponseWriter, r *http.Request) {
url := r.URL
if url.Path != path || url.RawQuery != query {
if r.URL.Path != path || r.URL.RawQuery != query {
w.WriteHeader(404)
} else if r.Header.Get("Authorization") != "Bearer imaginary_access_token" {
w.WriteHeader(403)

View File

@ -6,13 +6,15 @@ import (
"net/http"
"net/url"
"github.com/bitly/oauth2_proxy/api"
"github.com/pusher/oauth2_proxy/api"
)
// FacebookProvider represents an Facebook based Identity Provider
type FacebookProvider struct {
*ProviderData
}
// NewFacebookProvider initiates a new FacebookProvider
func NewFacebookProvider(p *ProviderData) *FacebookProvider {
p.ProviderName = "Facebook"
if p.LoginURL.String() == "" {
@ -43,14 +45,15 @@ func NewFacebookProvider(p *ProviderData) *FacebookProvider {
return &FacebookProvider{ProviderData: p}
}
func getFacebookHeader(access_token string) http.Header {
func getFacebookHeader(accessToken string) http.Header {
header := make(http.Header)
header.Set("Accept", "application/json")
header.Set("x-li-format", "json")
header.Set("Authorization", fmt.Sprintf("Bearer %s", access_token))
header.Set("Authorization", fmt.Sprintf("Bearer %s", accessToken))
return header
}
// GetEmailAddress returns the Account email address
func (p *FacebookProvider) GetEmailAddress(s *SessionState) (string, error) {
if s.AccessToken == "" {
return "", errors.New("missing access token")
@ -65,7 +68,7 @@ func (p *FacebookProvider) GetEmailAddress(s *SessionState) (string, error) {
Email string
}
var r result
err = api.RequestJson(req, &r)
err = api.RequestJSON(req, &r)
if err != nil {
return "", err
}
@ -75,6 +78,7 @@ func (p *FacebookProvider) GetEmailAddress(s *SessionState) (string, error) {
return r.Email, nil
}
// ValidateSessionState validates the AccessToken
func (p *FacebookProvider) ValidateSessionState(s *SessionState) bool {
return validateToken(p, s.AccessToken, getFacebookHeader(s.AccessToken))
}

View File

@ -12,12 +12,14 @@ import (
"strings"
)
// GitHubProvider represents an GitHub based Identity Provider
type GitHubProvider struct {
*ProviderData
Org string
Team string
}
// NewGitHubProvider initiates a new GitHubProvider
func NewGitHubProvider(p *ProviderData) *GitHubProvider {
p.ProviderName = "GitHub"
if p.LoginURL == nil || p.LoginURL.String() == "" {
@ -47,6 +49,8 @@ func NewGitHubProvider(p *ProviderData) *GitHubProvider {
}
return &GitHubProvider{ProviderData: p}
}
// SetOrgTeam adds GitHub org reading parameters to the OAuth2 scope
func (p *GitHubProvider) SetOrgTeam(org, team string) {
p.Org = org
p.Team = team
@ -106,7 +110,7 @@ func (p *GitHubProvider) hasOrg(accessToken string) (bool, error) {
}
orgs = append(orgs, op...)
pn += 1
pn++
}
var presentOrgs []string
@ -186,7 +190,7 @@ func (p *GitHubProvider) hasOrgAndTeam(accessToken string) (bool, error) {
log.Printf("Missing Team:%q from Org:%q in teams: %v", p.Team, p.Org, presentTeams)
} else {
var allOrgs []string
for org, _ := range presentOrgs {
for org := range presentOrgs {
allOrgs = append(allOrgs, org)
}
log.Printf("Missing Organization:%q in %#v", p.Org, allOrgs)
@ -194,6 +198,7 @@ func (p *GitHubProvider) hasOrgAndTeam(accessToken string) (bool, error) {
return false, nil
}
// GetEmailAddress returns the Account email address
func (p *GitHubProvider) GetEmailAddress(s *SessionState) (string, error) {
var emails []struct {
@ -251,6 +256,7 @@ func (p *GitHubProvider) GetEmailAddress(s *SessionState) (string, error) {
return "", nil
}
// GetUserName returns the Account user name
func (p *GitHubProvider) GetUserName(s *SessionState) (string, error) {
var user struct {
Login string `json:"login"`

View File

@ -29,19 +29,18 @@ func testGitHubProvider(hostname string) *GitHubProvider {
func testGitHubBackend(payload []string) *httptest.Server {
pathToQueryMap := map[string][]string{
"/user": []string{""},
"/user/emails": []string{""},
"/user/orgs": []string{"limit=200&page=1", "limit=200&page=2", "limit=200&page=3"},
"/user": {""},
"/user/emails": {""},
"/user/orgs": {"limit=200&page=1", "limit=200&page=2", "limit=200&page=3"},
}
return httptest.NewServer(http.HandlerFunc(
func(w http.ResponseWriter, r *http.Request) {
url := r.URL
query, ok := pathToQueryMap[url.Path]
query, ok := pathToQueryMap[r.URL.Path]
validQuery := false
index := 0
for i, q := range query {
if q == url.RawQuery {
if q == r.URL.RawQuery {
validQuery = true
index = i
}

View File

@ -5,13 +5,15 @@ import (
"net/http"
"net/url"
"github.com/bitly/oauth2_proxy/api"
"github.com/pusher/oauth2_proxy/api"
)
// GitLabProvider represents an GitLab based Identity Provider
type GitLabProvider struct {
*ProviderData
}
// NewGitLabProvider initiates a new GitLabProvider
func NewGitLabProvider(p *ProviderData) *GitLabProvider {
p.ProviderName = "GitLab"
if p.LoginURL == nil || p.LoginURL.String() == "" {
@ -41,6 +43,7 @@ func NewGitLabProvider(p *ProviderData) *GitLabProvider {
return &GitLabProvider{ProviderData: p}
}
// GetEmailAddress returns the Account email address
func (p *GitLabProvider) GetEmailAddress(s *SessionState) (string, error) {
req, err := http.NewRequest("GET",

View File

@ -33,8 +33,7 @@ func testGitLabBackend(payload string) *httptest.Server {
return httptest.NewServer(http.HandlerFunc(
func(w http.ResponseWriter, r *http.Request) {
url := r.URL
if url.Path != path || url.RawQuery != query {
if r.URL.Path != path || r.URL.RawQuery != query {
w.WriteHeader(404)
} else {
w.WriteHeader(200)
@ -87,8 +86,8 @@ func TestGitLabProviderGetEmailAddress(t *testing.T) {
b := testGitLabBackend("{\"email\": \"michael.bland@gsa.gov\"}")
defer b.Close()
b_url, _ := url.Parse(b.URL)
p := testGitLabProvider(b_url.Host)
bURL, _ := url.Parse(b.URL)
p := testGitLabProvider(bURL.Host)
session := &SessionState{AccessToken: "imaginary_access_token"}
email, err := p.GetEmailAddress(session)
@ -102,8 +101,8 @@ func TestGitLabProviderGetEmailAddressFailedRequest(t *testing.T) {
b := testGitLabBackend("unused payload")
defer b.Close()
b_url, _ := url.Parse(b.URL)
p := testGitLabProvider(b_url.Host)
bURL, _ := url.Parse(b.URL)
p := testGitLabProvider(bURL.Host)
// We'll trigger a request failure by using an unexpected access
// token. Alternatively, we could allow the parsing of the payload as
@ -118,8 +117,8 @@ func TestGitLabProviderGetEmailAddressEmailNotPresentInPayload(t *testing.T) {
b := testGitLabBackend("{\"foo\": \"bar\"}")
defer b.Close()
b_url, _ := url.Parse(b.URL)
p := testGitLabProvider(b_url.Host)
bURL, _ := url.Parse(b.URL)
p := testGitLabProvider(bURL.Host)
session := &SessionState{AccessToken: "imaginary_access_token"}
email, err := p.GetEmailAddress(session)

View File

@ -20,6 +20,7 @@ import (
"google.golang.org/api/googleapi"
)
// GoogleProvider represents an Google based Identity Provider
type GoogleProvider struct {
*ProviderData
RedeemRefreshURL *url.URL
@ -28,6 +29,7 @@ type GoogleProvider struct {
GroupValidator func(string) bool
}
// NewGoogleProvider initiates a new GoogleProvider
func NewGoogleProvider(p *ProviderData) *GoogleProvider {
p.ProviderName = "Google"
if p.LoginURL.String() == "" {
@ -62,7 +64,7 @@ func NewGoogleProvider(p *ProviderData) *GoogleProvider {
}
}
func emailFromIdToken(idToken string) (string, error) {
func emailFromIDToken(idToken string) (string, error) {
// id_token is a base64 encode ID token payload
// https://developers.google.com/accounts/docs/OAuth2Login#obtainuserinfo
@ -90,6 +92,7 @@ func emailFromIdToken(idToken string) (string, error) {
return email.Email, nil
}
// Redeem exchanges the OAuth2 authentication token for an ID token
func (p *GoogleProvider) Redeem(redirectURL, code string) (s *SessionState, err error) {
if code == "" {
err = errors.New("missing code")
@ -129,14 +132,14 @@ func (p *GoogleProvider) Redeem(redirectURL, code string) (s *SessionState, err
AccessToken string `json:"access_token"`
RefreshToken string `json:"refresh_token"`
ExpiresIn int64 `json:"expires_in"`
IdToken string `json:"id_token"`
IDToken string `json:"id_token"`
}
err = json.Unmarshal(body, &jsonResponse)
if err != nil {
return
}
var email string
email, err = emailFromIdToken(jsonResponse.IdToken)
email, err = emailFromIDToken(jsonResponse.IDToken)
if err != nil {
return
}
@ -249,6 +252,8 @@ func (p *GoogleProvider) ValidateGroup(email string) bool {
return p.GroupValidator(email)
}
// RefreshSessionIfNeeded checks if the session has expired and uses the
// RefreshToken to fetch a new ID token if required
func (p *GoogleProvider) RefreshSessionIfNeeded(s *SessionState) (bool, error) {
if s == nil || s.ExpiresOn.After(time.Now()) || s.RefreshToken == "" {
return false, nil

View File

@ -81,7 +81,7 @@ type redeemResponse struct {
AccessToken string `json:"access_token"`
RefreshToken string `json:"refresh_token"`
ExpiresIn int64 `json:"expires_in"`
IdToken string `json:"id_token"`
IDToken string `json:"id_token"`
}
func TestGoogleProviderGetEmailAddress(t *testing.T) {
@ -90,7 +90,7 @@ func TestGoogleProviderGetEmailAddress(t *testing.T) {
AccessToken: "a1234",
ExpiresIn: 10,
RefreshToken: "refresh12345",
IdToken: "ignored prefix." + base64.URLEncoding.EncodeToString([]byte(`{"email": "michael.bland@gsa.gov", "email_verified":true}`)),
IDToken: "ignored prefix." + base64.URLEncoding.EncodeToString([]byte(`{"email": "michael.bland@gsa.gov", "email_verified":true}`)),
})
assert.Equal(t, nil, err)
var server *httptest.Server
@ -127,7 +127,7 @@ func TestGoogleProviderGetEmailAddressInvalidEncoding(t *testing.T) {
p := newGoogleProvider()
body, err := json.Marshal(redeemResponse{
AccessToken: "a1234",
IdToken: "ignored prefix." + `{"email": "michael.bland@gsa.gov"}`,
IDToken: "ignored prefix." + `{"email": "michael.bland@gsa.gov"}`,
})
assert.Equal(t, nil, err)
var server *httptest.Server
@ -146,7 +146,7 @@ func TestGoogleProviderGetEmailAddressInvalidJson(t *testing.T) {
body, err := json.Marshal(redeemResponse{
AccessToken: "a1234",
IdToken: "ignored prefix." + base64.URLEncoding.EncodeToString([]byte(`{"email": michael.bland@gsa.gov}`)),
IDToken: "ignored prefix." + base64.URLEncoding.EncodeToString([]byte(`{"email": michael.bland@gsa.gov}`)),
})
assert.Equal(t, nil, err)
var server *httptest.Server
@ -165,7 +165,7 @@ func TestGoogleProviderGetEmailAddressEmailMissing(t *testing.T) {
p := newGoogleProvider()
body, err := json.Marshal(redeemResponse{
AccessToken: "a1234",
IdToken: "ignored prefix." + base64.URLEncoding.EncodeToString([]byte(`{"not_email": "missing"}`)),
IDToken: "ignored prefix." + base64.URLEncoding.EncodeToString([]byte(`{"not_email": "missing"}`)),
})
assert.Equal(t, nil, err)
var server *httptest.Server

View File

@ -6,7 +6,7 @@ import (
"net/http"
"net/url"
"github.com/bitly/oauth2_proxy/api"
"github.com/pusher/oauth2_proxy/api"
)
// stripToken is a helper function to obfuscate "access_token"
@ -46,13 +46,13 @@ func stripParam(param, endpoint string) string {
}
// validateToken returns true if token is valid
func validateToken(p Provider, access_token string, header http.Header) bool {
if access_token == "" || p.Data().ValidateURL == nil {
func validateToken(p Provider, accessToken string, header http.Header) bool {
if accessToken == "" || p.Data().ValidateURL == nil {
return false
}
endpoint := p.Data().ValidateURL.String()
if len(header) == 0 {
params := url.Values{"access_token": {access_token}}
params := url.Values{"access_token": {accessToken}}
endpoint = endpoint + "?" + params.Encode()
}
resp, err := api.RequestUnparsedResponse(endpoint, header)
@ -72,8 +72,3 @@ func validateToken(p Provider, access_token string, header http.Header) bool {
log.Printf("token validation request failed: status %d - %s", resp.StatusCode, body)
return false
}
func updateURL(url *url.URL, hostname string) {
url.Scheme = "http"
url.Host = hostname
}

View File

@ -10,6 +10,11 @@ import (
"github.com/stretchr/testify/assert"
)
func updateURL(url *url.URL, hostname string) {
url.Scheme = "http"
url.Host = hostname
}
type ValidateSessionStateTestProvider struct {
*ProviderData
}
@ -25,28 +30,28 @@ func (tp *ValidateSessionStateTestProvider) ValidateSessionState(s *SessionState
}
type ValidateSessionStateTest struct {
backend *httptest.Server
response_code int
provider *ValidateSessionStateTestProvider
header http.Header
backend *httptest.Server
responseCode int
provider *ValidateSessionStateTestProvider
header http.Header
}
func NewValidateSessionStateTest() *ValidateSessionStateTest {
var vt_test ValidateSessionStateTest
var vtTest ValidateSessionStateTest
vt_test.backend = httptest.NewServer(
vtTest.backend = httptest.NewServer(
http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
if r.URL.Path != "/oauth/tokeninfo" {
w.WriteHeader(500)
w.Write([]byte("unknown URL"))
}
token_param := r.FormValue("access_token")
if token_param == "" {
tokenParam := r.FormValue("access_token")
if tokenParam == "" {
missing := false
received_headers := r.Header
for k, _ := range vt_test.header {
received := received_headers.Get(k)
expected := vt_test.header.Get(k)
receivedHeaders := r.Header
for k := range vtTest.header {
received := receivedHeaders.Get(k)
expected := vtTest.header.Get(k)
if received == "" || received != expected {
missing = true
}
@ -56,68 +61,68 @@ func NewValidateSessionStateTest() *ValidateSessionStateTest {
w.Write([]byte("no token param and missing or incorrect headers"))
}
}
w.WriteHeader(vt_test.response_code)
w.WriteHeader(vtTest.responseCode)
w.Write([]byte("only code matters; contents disregarded"))
}))
backend_url, _ := url.Parse(vt_test.backend.URL)
vt_test.provider = &ValidateSessionStateTestProvider{
backendURL, _ := url.Parse(vtTest.backend.URL)
vtTest.provider = &ValidateSessionStateTestProvider{
ProviderData: &ProviderData{
ValidateURL: &url.URL{
Scheme: "http",
Host: backend_url.Host,
Host: backendURL.Host,
Path: "/oauth/tokeninfo",
},
},
}
vt_test.response_code = 200
return &vt_test
vtTest.responseCode = 200
return &vtTest
}
func (vt_test *ValidateSessionStateTest) Close() {
vt_test.backend.Close()
func (vtTest *ValidateSessionStateTest) Close() {
vtTest.backend.Close()
}
func TestValidateSessionStateValidToken(t *testing.T) {
vt_test := NewValidateSessionStateTest()
defer vt_test.Close()
assert.Equal(t, true, validateToken(vt_test.provider, "foobar", nil))
vtTest := NewValidateSessionStateTest()
defer vtTest.Close()
assert.Equal(t, true, validateToken(vtTest.provider, "foobar", nil))
}
func TestValidateSessionStateValidTokenWithHeaders(t *testing.T) {
vt_test := NewValidateSessionStateTest()
defer vt_test.Close()
vt_test.header = make(http.Header)
vt_test.header.Set("Authorization", "Bearer foobar")
vtTest := NewValidateSessionStateTest()
defer vtTest.Close()
vtTest.header = make(http.Header)
vtTest.header.Set("Authorization", "Bearer foobar")
assert.Equal(t, true,
validateToken(vt_test.provider, "foobar", vt_test.header))
validateToken(vtTest.provider, "foobar", vtTest.header))
}
func TestValidateSessionStateEmptyToken(t *testing.T) {
vt_test := NewValidateSessionStateTest()
defer vt_test.Close()
assert.Equal(t, false, validateToken(vt_test.provider, "", nil))
vtTest := NewValidateSessionStateTest()
defer vtTest.Close()
assert.Equal(t, false, validateToken(vtTest.provider, "", nil))
}
func TestValidateSessionStateEmptyValidateURL(t *testing.T) {
vt_test := NewValidateSessionStateTest()
defer vt_test.Close()
vt_test.provider.Data().ValidateURL = nil
assert.Equal(t, false, validateToken(vt_test.provider, "foobar", nil))
vtTest := NewValidateSessionStateTest()
defer vtTest.Close()
vtTest.provider.Data().ValidateURL = nil
assert.Equal(t, false, validateToken(vtTest.provider, "foobar", nil))
}
func TestValidateSessionStateRequestNetworkFailure(t *testing.T) {
vt_test := NewValidateSessionStateTest()
vtTest := NewValidateSessionStateTest()
// Close immediately to simulate a network failure
vt_test.Close()
assert.Equal(t, false, validateToken(vt_test.provider, "foobar", nil))
vtTest.Close()
assert.Equal(t, false, validateToken(vtTest.provider, "foobar", nil))
}
func TestValidateSessionStateExpiredToken(t *testing.T) {
vt_test := NewValidateSessionStateTest()
defer vt_test.Close()
vt_test.response_code = 401
assert.Equal(t, false, validateToken(vt_test.provider, "foobar", nil))
vtTest := NewValidateSessionStateTest()
defer vtTest.Close()
vtTest.responseCode = 401
assert.Equal(t, false, validateToken(vtTest.provider, "foobar", nil))
}
func TestStripTokenNotPresent(t *testing.T) {

View File

@ -6,13 +6,15 @@ import (
"net/http"
"net/url"
"github.com/bitly/oauth2_proxy/api"
"github.com/pusher/oauth2_proxy/api"
)
// LinkedInProvider represents an LinkedIn based Identity Provider
type LinkedInProvider struct {
*ProviderData
}
// NewLinkedInProvider initiates a new LinkedInProvider
func NewLinkedInProvider(p *ProviderData) *LinkedInProvider {
p.ProviderName = "LinkedIn"
if p.LoginURL.String() == "" {
@ -39,14 +41,15 @@ func NewLinkedInProvider(p *ProviderData) *LinkedInProvider {
return &LinkedInProvider{ProviderData: p}
}
func getLinkedInHeader(access_token string) http.Header {
func getLinkedInHeader(accessToken string) http.Header {
header := make(http.Header)
header.Set("Accept", "application/json")
header.Set("x-li-format", "json")
header.Set("Authorization", fmt.Sprintf("Bearer %s", access_token))
header.Set("Authorization", fmt.Sprintf("Bearer %s", accessToken))
return header
}
// GetEmailAddress returns the Account email address
func (p *LinkedInProvider) GetEmailAddress(s *SessionState) (string, error) {
if s.AccessToken == "" {
return "", errors.New("missing access token")
@ -69,6 +72,7 @@ func (p *LinkedInProvider) GetEmailAddress(s *SessionState) (string, error) {
return email, nil
}
// ValidateSessionState validates the AccessToken
func (p *LinkedInProvider) ValidateSessionState(s *SessionState) bool {
return validateToken(p, s.AccessToken, getLinkedInHeader(s.AccessToken))
}

View File

@ -31,8 +31,7 @@ func testLinkedInBackend(payload string) *httptest.Server {
return httptest.NewServer(http.HandlerFunc(
func(w http.ResponseWriter, r *http.Request) {
url := r.URL
if url.Path != path {
if r.URL.Path != path {
w.WriteHeader(404)
} else if r.Header.Get("Authorization") != "Bearer imaginary_access_token" {
w.WriteHeader(403)
@ -95,8 +94,8 @@ func TestLinkedInProviderGetEmailAddress(t *testing.T) {
b := testLinkedInBackend(`"user@linkedin.com"`)
defer b.Close()
b_url, _ := url.Parse(b.URL)
p := testLinkedInProvider(b_url.Host)
bURL, _ := url.Parse(b.URL)
p := testLinkedInProvider(bURL.Host)
session := &SessionState{AccessToken: "imaginary_access_token"}
email, err := p.GetEmailAddress(session)
@ -108,8 +107,8 @@ func TestLinkedInProviderGetEmailAddressFailedRequest(t *testing.T) {
b := testLinkedInBackend("unused payload")
defer b.Close()
b_url, _ := url.Parse(b.URL)
p := testLinkedInProvider(b_url.Host)
bURL, _ := url.Parse(b.URL)
p := testLinkedInProvider(bURL.Host)
// We'll trigger a request failure by using an unexpected access
// token. Alternatively, we could allow the parsing of the payload as
@ -124,8 +123,8 @@ func TestLinkedInProviderGetEmailAddressEmailNotPresentInPayload(t *testing.T) {
b := testLinkedInBackend("{\"foo\": \"bar\"}")
defer b.Close()
b_url, _ := url.Parse(b.URL)
p := testLinkedInProvider(b_url.Host)
bURL, _ := url.Parse(b.URL)
p := testLinkedInProvider(bURL.Host)
session := &SessionState{AccessToken: "imaginary_access_token"}
email, err := p.GetEmailAddress(session)

View File

@ -10,17 +10,20 @@ import (
oidc "github.com/coreos/go-oidc"
)
// OIDCProvider represents an OIDC based Identity Provider
type OIDCProvider struct {
*ProviderData
Verifier *oidc.IDTokenVerifier
}
// NewOIDCProvider initiates a new OIDCProvider
func NewOIDCProvider(p *ProviderData) *OIDCProvider {
p.ProviderName = "OpenID Connect"
return &OIDCProvider{ProviderData: p}
}
// Redeem exchanges the OAuth2 authentication token for an ID token
func (p *OIDCProvider) Redeem(redirectURL, code string) (s *SessionState, err error) {
ctx := context.Background()
c := oauth2.Config{
@ -73,6 +76,11 @@ func (p *OIDCProvider) Redeem(redirectURL, code string) (s *SessionState, err er
return
}
// RefreshSessionIfNeeded checks if the session has expired and uses the
// RefreshToken to fetch a new ID token if required
//
// WARNGING: This implementation is broken and does not check with the upstream
// OIDC provider before refreshing the session
func (p *OIDCProvider) RefreshSessionIfNeeded(s *SessionState) (bool, error) {
if s == nil || s.ExpiresOn.After(time.Now()) || s.RefreshToken == "" {
return false, nil

View File

@ -4,6 +4,8 @@ import (
"net/url"
)
// ProviderData contains information required to configure all implementations
// of OAuth2 providers
type ProviderData struct {
ProviderName string
ClientID string
@ -17,4 +19,5 @@ type ProviderData struct {
ApprovalPrompt string
}
// Data returns the ProviderData
func (p *ProviderData) Data() *ProviderData { return p }

View File

@ -9,9 +9,10 @@ import (
"net/http"
"net/url"
"github.com/bitly/oauth2_proxy/cookie"
"github.com/pusher/oauth2_proxy/cookie"
)
// Redeem provides a default implementation of the OAuth2 token redemption process
func (p *ProviderData) Redeem(redirectURL, code string) (s *SessionState, err error) {
if code == "" {
err = errors.New("missing code")
@ -102,6 +103,7 @@ func (p *ProviderData) SessionFromCookie(v string, c *cookie.Cipher) (s *Session
return DecodeSessionState(v, c)
}
// GetEmailAddress returns the Account email address
func (p *ProviderData) GetEmailAddress(s *SessionState) (string, error) {
return "", errors.New("not implemented")
}
@ -117,11 +119,13 @@ func (p *ProviderData) ValidateGroup(email string) bool {
return true
}
// ValidateSessionState validates the AccessToken
func (p *ProviderData) ValidateSessionState(s *SessionState) bool {
return validateToken(p, s.AccessToken, nil)
}
// RefreshSessionIfNeeded
// RefreshSessionIfNeeded should refresh the user's session if required and
// do nothing if a refresh is not required
func (p *ProviderData) RefreshSessionIfNeeded(s *SessionState) (bool, error) {
return false, nil
}

View File

@ -1,9 +1,10 @@
package providers
import (
"github.com/bitly/oauth2_proxy/cookie"
"github.com/pusher/oauth2_proxy/cookie"
)
// Provider represents an upstream identity provider implementation
type Provider interface {
Data() *ProviderData
GetEmailAddress(*SessionState) (string, error)
@ -17,6 +18,7 @@ type Provider interface {
CookieForSession(*SessionState, *cookie.Cipher) (string, error)
}
// New provides a new Provider based on the configured provider string
func New(provider string, p *ProviderData) Provider {
switch provider {
case "linkedin":

View File

@ -6,9 +6,10 @@ import (
"strings"
"time"
"github.com/bitly/oauth2_proxy/cookie"
"github.com/pusher/oauth2_proxy/cookie"
)
// SessionState is used to store information about the currently authenticated user session
type SessionState struct {
AccessToken string
ExpiresOn time.Time
@ -17,6 +18,7 @@ type SessionState struct {
User string
}
// IsExpired checks whether the session has expired
func (s *SessionState) IsExpired() bool {
if !s.ExpiresOn.IsZero() && s.ExpiresOn.Before(time.Now()) {
return true
@ -24,6 +26,7 @@ func (s *SessionState) IsExpired() bool {
return false
}
// String constructs a summary of the session state
func (s *SessionState) String() string {
o := fmt.Sprintf("Session{%s", s.accountInfo())
if s.AccessToken != "" {
@ -38,6 +41,7 @@ func (s *SessionState) String() string {
return o + "}"
}
// EncodeSessionState returns string representation of the current session
func (s *SessionState) EncodeSessionState(c *cookie.Cipher) (string, error) {
if c == nil || s.AccessToken == "" {
return s.accountInfo(), nil
@ -49,6 +53,7 @@ func (s *SessionState) accountInfo() string {
return fmt.Sprintf("email:%s user:%s", s.Email, s.User)
}
// EncryptedString encrypts the session state into a cookie string
func (s *SessionState) EncryptedString(c *cookie.Cipher) (string, error) {
var err error
if c == nil {
@ -84,6 +89,7 @@ func decodeSessionStatePlain(v string) (s *SessionState, err error) {
return &SessionState{User: user, Email: email}, nil
}
// DecodeSessionState decodes the session cookie string into a SessionState
func DecodeSessionState(v string, c *cookie.Cipher) (s *SessionState, err error) {
if c == nil {
return decodeSessionStatePlain(v)

View File

@ -6,7 +6,7 @@ import (
"testing"
"time"
"github.com/bitly/oauth2_proxy/cookie"
"github.com/pusher/oauth2_proxy/cookie"
"github.com/stretchr/testify/assert"
)

View File

@ -4,13 +4,16 @@ import (
"strings"
)
// StringArray is a type alias for a slice of strings
type StringArray []string
// Set appends a string to the StringArray
func (a *StringArray) Set(s string) error {
*a = append(*a, s)
return nil
}
// String joins elements of the StringArray into a single comma separated string
func (a *StringArray) String() string {
return strings.Join(*a, ",")
}

View File

@ -142,7 +142,7 @@ func getTemplates() *template.Template {
<footer>
{{ if eq .Footer "-" }}
{{ else if eq .Footer ""}}
Secured with <a href="https://github.com/bitly/oauth2_proxy#oauth2_proxy">OAuth2 Proxy</a> version {{.Version}}
Secured with <a href="https://github.com/pusher/oauth2_proxy#oauth2_proxy">OAuth2 Proxy</a> version {{.Version}}
{{ else }}
{{.Footer}}
{{ end }}

14
test.sh
View File

@ -1,14 +0,0 @@
#!/bin/bash
EXIT_CODE=0
echo "gofmt"
diff -u <(echo -n) <(gofmt -d $(find . -type f -name '*.go' -not -path "./vendor/*")) || EXIT_CODE=1
for pkg in $(go list ./... | grep -v '/vendor/' ); do
echo "testing $pkg"
echo "go vet $pkg"
go vet "$pkg" || EXIT_CODE=1
echo "go test -v $pkg"
go test -v -timeout 90s "$pkg" || EXIT_CODE=1
echo "go test -v -race $pkg"
GOMAXPROCS=4 go test -v -timeout 90s0s -race "$pkg" || EXIT_CODE=1
done
exit $EXIT_CODE

View File

@ -10,11 +10,13 @@ import (
"unsafe"
)
// UserMap holds information from the authenticated emails file
type UserMap struct {
usersFile string
m unsafe.Pointer
}
// NewUserMap parses the authenticated emails file into a new UserMap
func NewUserMap(usersFile string, done <-chan bool, onUpdate func()) *UserMap {
um := &UserMap{usersFile: usersFile}
m := make(map[string]bool)
@ -30,23 +32,26 @@ func NewUserMap(usersFile string, done <-chan bool, onUpdate func()) *UserMap {
return um
}
// IsValid checks if an email is allowed
func (um *UserMap) IsValid(email string) (result bool) {
m := *(*map[string]bool)(atomic.LoadPointer(&um.m))
_, result = m[email]
return
}
// LoadAuthenticatedEmailsFile loads the authenticated emails file from disk
// and parses the contents as CSV
func (um *UserMap) LoadAuthenticatedEmailsFile() {
r, err := os.Open(um.usersFile)
if err != nil {
log.Fatalf("failed opening authenticated-emails-file=%q, %s", um.usersFile, err)
}
defer r.Close()
csv_reader := csv.NewReader(r)
csv_reader.Comma = ','
csv_reader.Comment = '#'
csv_reader.TrimLeadingSpace = true
records, err := csv_reader.ReadAll()
csvReader := csv.NewReader(r)
csvReader.Comma = ','
csvReader.Comment = '#'
csvReader.TrimLeadingSpace = true
records, err := csvReader.ReadAll()
if err != nil {
log.Printf("error reading authenticated-emails-file=%q, %s", um.usersFile, err)
return
@ -91,6 +96,7 @@ func newValidatorImpl(domains []string, usersFile string,
return validator
}
// NewValidator constructs a function to validate email addresses
func NewValidator(domains []string, usersFile string) func(string) bool {
return newValidatorImpl(domains, usersFile, nil, func() {})
}

View File

@ -8,15 +8,15 @@ import (
)
type ValidatorTest struct {
auth_email_file *os.File
done chan bool
update_seen bool
authEmailFile *os.File
done chan bool
updateSeen bool
}
func NewValidatorTest(t *testing.T) *ValidatorTest {
vt := &ValidatorTest{}
var err error
vt.auth_email_file, err = ioutil.TempFile("", "test_auth_emails_")
vt.authEmailFile, err = ioutil.TempFile("", "test_auth_emails_")
if err != nil {
t.Fatal("failed to create temp file: " + err.Error())
}
@ -26,27 +26,27 @@ func NewValidatorTest(t *testing.T) *ValidatorTest {
func (vt *ValidatorTest) TearDown() {
vt.done <- true
os.Remove(vt.auth_email_file.Name())
os.Remove(vt.authEmailFile.Name())
}
func (vt *ValidatorTest) NewValidator(domains []string,
updated chan<- bool) func(string) bool {
return newValidatorImpl(domains, vt.auth_email_file.Name(),
return newValidatorImpl(domains, vt.authEmailFile.Name(),
vt.done, func() {
if vt.update_seen == false {
if vt.updateSeen == false {
updated <- true
vt.update_seen = true
vt.updateSeen = true
}
})
}
// This will close vt.auth_email_file.
// This will close vt.authEmailFile.
func (vt *ValidatorTest) WriteEmails(t *testing.T, emails []string) {
defer vt.auth_email_file.Close()
vt.auth_email_file.WriteString(strings.Join(emails, "\n"))
if err := vt.auth_email_file.Close(); err != nil {
defer vt.authEmailFile.Close()
vt.authEmailFile.WriteString(strings.Join(emails, "\n"))
if err := vt.authEmailFile.Close(); err != nil {
t.Fatal("failed to close temp file " +
vt.auth_email_file.Name() + ": " + err.Error())
vt.authEmailFile.Name() + ": " + err.Error())
}
}

View File

@ -12,18 +12,18 @@ import (
func (vt *ValidatorTest) UpdateEmailFileViaCopyingOver(
t *testing.T, emails []string) {
orig_file := vt.auth_email_file
origFile := vt.authEmailFile
var err error
vt.auth_email_file, err = ioutil.TempFile("", "test_auth_emails_")
vt.authEmailFile, err = ioutil.TempFile("", "test_auth_emails_")
if err != nil {
t.Fatal("failed to create temp file for copy: " + err.Error())
}
vt.WriteEmails(t, emails)
err = os.Rename(vt.auth_email_file.Name(), orig_file.Name())
err = os.Rename(vt.authEmailFile.Name(), origFile.Name())
if err != nil {
t.Fatal("failed to copy over temp file: " + err.Error())
}
vt.auth_email_file = orig_file
vt.authEmailFile = origFile
}
func TestValidatorOverwriteEmailListViaCopyingOver(t *testing.T) {

View File

@ -10,8 +10,8 @@ import (
func (vt *ValidatorTest) UpdateEmailFile(t *testing.T, emails []string) {
var err error
vt.auth_email_file, err = os.OpenFile(
vt.auth_email_file.Name(), os.O_WRONLY|os.O_CREATE, 0600)
vt.authEmailFile, err = os.OpenFile(
vt.authEmailFile.Name(), os.O_WRONLY|os.O_CREATE, 0600)
if err != nil {
t.Fatal("failed to re-open temp file for updates")
}
@ -20,24 +20,24 @@ func (vt *ValidatorTest) UpdateEmailFile(t *testing.T, emails []string) {
func (vt *ValidatorTest) UpdateEmailFileViaRenameAndReplace(
t *testing.T, emails []string) {
orig_file := vt.auth_email_file
origFile := vt.authEmailFile
var err error
vt.auth_email_file, err = ioutil.TempFile("", "test_auth_emails_")
vt.authEmailFile, err = ioutil.TempFile("", "test_auth_emails_")
if err != nil {
t.Fatal("failed to create temp file for rename and replace: " +
err.Error())
}
vt.WriteEmails(t, emails)
moved_name := orig_file.Name() + "-moved"
err = os.Rename(orig_file.Name(), moved_name)
err = os.Rename(vt.auth_email_file.Name(), orig_file.Name())
movedName := origFile.Name() + "-moved"
err = os.Rename(origFile.Name(), movedName)
err = os.Rename(vt.authEmailFile.Name(), origFile.Name())
if err != nil {
t.Fatal("failed to rename and replace temp file: " +
err.Error())
}
vt.auth_email_file = orig_file
os.Remove(moved_name)
vt.authEmailFile = origFile
os.Remove(movedName)
}
func TestValidatorOverwriteEmailListDirectly(t *testing.T) {

View File

@ -1,3 +1,4 @@
package main
const VERSION = "2.2.1-alpha"
// VERSION contains version information
var VERSION = "undefined"

View File

@ -8,16 +8,18 @@ import (
"path/filepath"
"time"
"gopkg.in/fsnotify.v1"
fsnotify "gopkg.in/fsnotify/fsnotify.v1"
)
// WaitForReplacement waits for a file to exist on disk and then starts a watch
// for the file
func WaitForReplacement(filename string, op fsnotify.Op,
watcher *fsnotify.Watcher) {
const sleep_interval = 50 * time.Millisecond
const sleepInterval = 50 * time.Millisecond
// Avoid a race when fsnofity.Remove is preceded by fsnotify.Chmod.
if op&fsnotify.Chmod != 0 {
time.Sleep(sleep_interval)
time.Sleep(sleepInterval)
}
for {
if _, err := os.Stat(filename); err == nil {
@ -26,10 +28,11 @@ func WaitForReplacement(filename string, op fsnotify.Op,
return
}
}
time.Sleep(sleep_interval)
time.Sleep(sleepInterval)
}
}
// WatchForUpdates performs an action every time a file on disk is updated
func WatchForUpdates(filename string, done <-chan bool, action func()) {
filename = filepath.Clean(filename)
watcher, err := fsnotify.NewWatcher()
@ -56,7 +59,7 @@ func WatchForUpdates(filename string, done <-chan bool, action func()) {
}
log.Printf("reloading after event: %s", event)
action()
case err := <-watcher.Errors:
case err = <-watcher.Errors:
log.Printf("error watching %s: %s", filename, err)
}
}