initial release

This commit is contained in:
Bobby DeSimone 2019-01-02 12:13:36 -08:00
commit d56c889224
No known key found for this signature in database
GPG key ID: AEE4CF12FE86D07E
62 changed files with 8229 additions and 0 deletions

96
.gitignore vendored Normal file
View file

@ -0,0 +1,96 @@
pem
env
coverage.txt
*.pem
# Compiled Object files, Static and Dynamic libs (Shared Objects)
*.o
*.a
*.so
# Folders
_obj
_test
.cover
# Architecture specific extensions/prefixes
*.[568vq]
[568vq].out
*.cgo1.go
*.cgo2.c
_cgo_defun.c
_cgo_gotypes.go
_cgo_export.*
_testmain.go
*.exe
*.test
*.prof
# Other dirs
/bin/
/pkg/
# Without this, the *.[568vq] above ignores this folder.
!**/graphrbac/1.6
# Ruby
website/vendor
website/.bundle
website/build
website/tmp
# Vagrant
.vagrant/
Vagrantfile
# Configs
*.hcl
!command/agent/config/test-fixtures/config.hcl
!command/agent/config/test-fixtures/config-embedded-type.hcl
.DS_Store
.idea
.vscode
dist/*
tags
# Editor backups
*~
*.sw[a-z]
# IntelliJ IDEA project files
.idea
*.ipr
*.iml
# compiled output
ui/dist
ui/tmp
ui/root
http/bindata_assetfs.go
# dependencies
ui/node_modules
ui/bower_components
# misc
ui/.DS_Store
ui/.sass-cache
ui/connect.lock
ui/coverage/*
ui/libpeerconnection.log
ui/npm-debug.log
ui/testem.log
# used for JS acceptance tests
ui/tests/helpers/vault-keys.js
ui/vault-ui-integration-server.pid
# for building static assets
node_modules
package-lock.json

18
.travis.yml Normal file
View file

@ -0,0 +1,18 @@
---
language: go
go:
- 1.x
- tip
matrix:
allow_failures:
- go: tip
fast_finish: true
install:
- go get github.com/golang/lint/golint
- go get honnef.co/go/tools/cmd/staticcheck
script:
- env GO111MODULE=on make all
- env GO111MODULE=on make cover
- env GO111MODULE=on make release
# after_success:
# - bash <(curl -s https://codecov.io/bash)

88
3RD-PARTY Normal file
View file

@ -0,0 +1,88 @@
Third Party Licenses
Go
SPDX-License-Identifier: BSD-3-Clause
https://golang.org/LICENSE
Copyright (c) 2009 The Go Authors. All rights reserved.
Redistribution and use in source and binary forms, with or without
modification, are permitted provided that the following conditions are
met:
* Redistributions of source code must retain the above copyright
notice, this list of conditions and the following disclaimer.
* Redistributions in binary form must reproduce the above
copyright notice, this list of conditions and the following disclaimer
in the documentation and/or other materials provided with the
distribution.
* Neither the name of Google Inc. nor the names of its
contributors may be used to endorse or promote products derived from
this software without specific prior written permission.
THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
"AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT
OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL,
SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT
LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE,
DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY
THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
(INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
Upspin
SPDX-License-Identifier: BSD-3-Clause
https://github.com/upspin/upspin/blob/master/LICENSE
Copyright (c) 2016 The Upspin Authors. All rights reserved.
Redistribution and use in source and binary forms, with or without
modification, are permitted provided that the following conditions are
met:
* Redistributions of source code must retain the above copyright
notice, this list of conditions and the following disclaimer.
* Redistributions in binary form must reproduce the above
copyright notice, this list of conditions and the following disclaimer
in the documentation and/or other materials provided with the
distribution.
* Neither the name of Google Inc. nor the names of its
contributors may be used to endorse or promote products derived from
this software without specific prior written permission.
THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
"AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT
OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL,
SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT
LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE,
DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY
THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
(INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
buzzfeed/sso (fork of bitly/oauth2_proxy)
SPDX-License-Identifier: MIT
https://github.com/buzzfeed/sso/blob/master/LICENSE
https://github.com/bitly/oauth2_proxy/blob/master/LICENSE
Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:
The above copyright notice and this permission notice shall be included in
all copies or substantial portions of the Software.
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
THE SOFTWARE.

14
Dockerfile Normal file
View file

@ -0,0 +1,14 @@
FROM golang:alpine as build
RUN apk --update --no-cache add ca-certificates git make
ENV CGO_ENABLED=0
ENV GO111MODULE=on
WORKDIR /go/src/github.com/pomerium/pomerium
COPY . .
RUN make
FROM scratch
COPY --from=build /etc/ssl/certs/ca-certificates.crt /etc/ssl/certs/
WORKDIR /pomerium
COPY --from=build /bin/* /bin/

201
LICENSE Normal file
View file

@ -0,0 +1,201 @@
Apache License
Version 2.0, January 2004
http://www.apache.org/licenses/
TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
1. Definitions.
"License" shall mean the terms and conditions for use, reproduction,
and distribution as defined by Sections 1 through 9 of this document.
"Licensor" shall mean the copyright owner or entity authorized by
the copyright owner that is granting the License.
"Legal Entity" shall mean the union of the acting entity and all
other entities that control, are controlled by, or are under common
control with that entity. For the purposes of this definition,
"control" means (i) the power, direct or indirect, to cause the
direction or management of such entity, whether by contract or
otherwise, or (ii) ownership of fifty percent (50%) or more of the
outstanding shares, or (iii) beneficial ownership of such entity.
"You" (or "Your") shall mean an individual or Legal Entity
exercising permissions granted by this License.
"Source" form shall mean the preferred form for making modifications,
including but not limited to software source code, documentation
source, and configuration files.
"Object" form shall mean any form resulting from mechanical
transformation or translation of a Source form, including but
not limited to compiled object code, generated documentation,
and conversions to other media types.
"Work" shall mean the work of authorship, whether in Source or
Object form, made available under the License, as indicated by a
copyright notice that is included in or attached to the work
(an example is provided in the Appendix below).
"Derivative Works" shall mean any work, whether in Source or Object
form, that is based on (or derived from) the Work and for which the
editorial revisions, annotations, elaborations, or other modifications
represent, as a whole, an original work of authorship. For the purposes
of this License, Derivative Works shall not include works that remain
separable from, or merely link (or bind by name) to the interfaces of,
the Work and Derivative Works thereof.
"Contribution" shall mean any work of authorship, including
the original version of the Work and any modifications or additions
to that Work or Derivative Works thereof, that is intentionally
submitted to Licensor for inclusion in the Work by the copyright owner
or by an individual or Legal Entity authorized to submit on behalf of
the copyright owner. For the purposes of this definition, "submitted"
means any form of electronic, verbal, or written communication sent
to the Licensor or its representatives, including but not limited to
communication on electronic mailing lists, source code control systems,
and issue tracking systems that are managed by, or on behalf of, the
Licensor for the purpose of discussing and improving the Work, but
excluding communication that is conspicuously marked or otherwise
designated in writing by the copyright owner as "Not a Contribution."
"Contributor" shall mean Licensor and any individual or Legal Entity
on behalf of whom a Contribution has been received by Licensor and
subsequently incorporated within the Work.
2. Grant of Copyright License. Subject to the terms and conditions of
this License, each Contributor hereby grants to You a perpetual,
worldwide, non-exclusive, no-charge, royalty-free, irrevocable
copyright license to reproduce, prepare Derivative Works of,
publicly display, publicly perform, sublicense, and distribute the
Work and such Derivative Works in Source or Object form.
3. Grant of Patent License. Subject to the terms and conditions of
this License, each Contributor hereby grants to You a perpetual,
worldwide, non-exclusive, no-charge, royalty-free, irrevocable
(except as stated in this section) patent license to make, have made,
use, offer to sell, sell, import, and otherwise transfer the Work,
where such license applies only to those patent claims licensable
by such Contributor that are necessarily infringed by their
Contribution(s) alone or by combination of their Contribution(s)
with the Work to which such Contribution(s) was submitted. If You
institute patent litigation against any entity (including a
cross-claim or counterclaim in a lawsuit) alleging that the Work
or a Contribution incorporated within the Work constitutes direct
or contributory patent infringement, then any patent licenses
granted to You under this License for that Work shall terminate
as of the date such litigation is filed.
4. Redistribution. You may reproduce and distribute copies of the
Work or Derivative Works thereof in any medium, with or without
modifications, and in Source or Object form, provided that You
meet the following conditions:
(a) You must give any other recipients of the Work or
Derivative Works a copy of this License; and
(b) You must cause any modified files to carry prominent notices
stating that You changed the files; and
(c) You must retain, in the Source form of any Derivative Works
that You distribute, all copyright, patent, trademark, and
attribution notices from the Source form of the Work,
excluding those notices that do not pertain to any part of
the Derivative Works; and
(d) If the Work includes a "NOTICE" text file as part of its
distribution, then any Derivative Works that You distribute must
include a readable copy of the attribution notices contained
within such NOTICE file, excluding those notices that do not
pertain to any part of the Derivative Works, in at least one
of the following places: within a NOTICE text file distributed
as part of the Derivative Works; within the Source form or
documentation, if provided along with the Derivative Works; or,
within a display generated by the Derivative Works, if and
wherever such third-party notices normally appear. The contents
of the NOTICE file are for informational purposes only and
do not modify the License. You may add Your own attribution
notices within Derivative Works that You distribute, alongside
or as an addendum to the NOTICE text from the Work, provided
that such additional attribution notices cannot be construed
as modifying the License.
You may add Your own copyright statement to Your modifications and
may provide additional or different license terms and conditions
for use, reproduction, or distribution of Your modifications, or
for any such Derivative Works as a whole, provided Your use,
reproduction, and distribution of the Work otherwise complies with
the conditions stated in this License.
5. Submission of Contributions. Unless You explicitly state otherwise,
any Contribution intentionally submitted for inclusion in the Work
by You to the Licensor shall be under the terms and conditions of
this License, without any additional terms or conditions.
Notwithstanding the above, nothing herein shall supersede or modify
the terms of any separate license agreement you may have executed
with Licensor regarding such Contributions.
6. Trademarks. This License does not grant permission to use the trade
names, trademarks, service marks, or product names of the Licensor,
except as required for reasonable and customary use in describing the
origin of the Work and reproducing the content of the NOTICE file.
7. Disclaimer of Warranty. Unless required by applicable law or
agreed to in writing, Licensor provides the Work (and each
Contributor provides its Contributions) on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
implied, including, without limitation, any warranties or conditions
of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
PARTICULAR PURPOSE. You are solely responsible for determining the
appropriateness of using or redistributing the Work and assume any
risks associated with Your exercise of permissions under this License.
8. Limitation of Liability. In no event and under no legal theory,
whether in tort (including negligence), contract, or otherwise,
unless required by applicable law (such as deliberate and grossly
negligent acts) or agreed to in writing, shall any Contributor be
liable to You for damages, including any direct, indirect, special,
incidental, or consequential damages of any character arising as a
result of this License or out of the use or inability to use the
Work (including but not limited to damages for loss of goodwill,
work stoppage, computer failure or malfunction, or any and all
other commercial damages or losses), even if such Contributor
has been advised of the possibility of such damages.
9. Accepting Warranty or Additional Liability. While redistributing
the Work or Derivative Works thereof, You may choose to offer,
and charge a fee for, acceptance of support, warranty, indemnity,
or other liability obligations and/or rights consistent with this
License. However, in accepting such obligations, You may act only
on Your own behalf and on Your sole responsibility, not on behalf
of any other Contributor, and only if You agree to indemnify,
defend, and hold each Contributor harmless for any liability
incurred by, or claims asserted against, such Contributor by reason
of your accepting any such warranty or additional liability.
END OF TERMS AND CONDITIONS
APPENDIX: How to apply the Apache License to your work.
To apply the Apache License to your work, attach the following
boilerplate notice, with the fields enclosed by brackets "[]"
replaced with your own identifying information. (Don't include
the brackets!) The text should be enclosed in the appropriate
comment syntax for the file format. We also recommend that a
file or class name and description of purpose be included on the
same "printed page" as the copyright notice for easier
identification within third-party archives.
Copyright 2019 Pomerium
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.

113
Makefile Normal file
View file

@ -0,0 +1,113 @@
# Setup name variables for the package/tool
PREFIX?=$(shell pwd)
NAME := pomerium
PKG := github.com/pomerium/$(NAME)
BUILDDIR := ${PREFIX}/dist
BINDIR := ${PREFIX}/bin
GO111MODULE=on
CGO_ENABLED := 0
# Set any default go build tags
BUILDTAGS :=
# Populate version variables
# Add to compile time flags
VERSION := $(shell cat VERSION)
GITCOMMIT := $(shell git rev-parse --short HEAD)
GITUNTRACKEDCHANGES := $(shell git status --porcelain --untracked-files=no)
BUILDMETA:=
ifneq ($(GITUNTRACKEDCHANGES),)
BUILDMETA := dirty
endif
CTIMEVAR=-X $(PKG)/internal/version.GitCommit=$(GITCOMMIT) \
-X $(PKG)/internal/version.Version=$(VERSION) \
-X $(PKG)/internal/version.BuildMeta=$(BUILDMETA) \
-X $(PKG)/internal/version.ProjectName=$(NAME) \
-X $(PKG)/internal/version.ProjectURL=$(PKG)
GO_LDFLAGS=-ldflags "-w $(CTIMEVAR)"
GOOSARCHES = linux/amd64 darwin/amd64 windows/amd64
.PHONY: all
all: clean build fmt lint vet test ## Runs a clean, build, fmt, lint, test, and vet.
.PHONY: build
build: ## Builds dynamic executables and/or packages.
@echo "==> $@"
@CGO_ENABLED=0 GO111MODULE=on go build -tags "$(BUILDTAGS)" ${GO_LDFLAGS} -o $(BINDIR)/$(NAME) ./cmd/"$(NAME)"
.PHONY: fmt
fmt: ## Verifies all files have been `gofmt`ed.
@echo "==> $@"
@gofmt -s -l . | grep -v '.pb.go:' | grep -v vendor | tee /dev/stderr
.PHONY: lint
lint: ## Verifies `golint` passes.
@echo "==> $@"
@golint ./... | grep -v '.pb.go:' | grep -v vendor | tee /dev/stderr
.PHONY: staticcheck
staticcheck: ## Verifies `staticcheck` passes
@echo "+ $@"
@staticcheck $(shell go list ./... | grep -v vendor) | grep -v '.pb.go:' | tee /dev/stderr
.PHONY: vet
vet: ## Verifies `go vet` passes.
@echo "==> $@"
@go vet $(shell go list ./... | grep -v vendor) | grep -v '.pb.go:' | tee /dev/stderr
.PHONY: test
test: ## Runs the go tests.
@echo "==> $@"
@go test -tags "$(BUILDTAGS)" $(shell go list ./... | grep -v vendor)
.PHONY: cover
cover: ## Runs go test with coverage
@echo "" > coverage.txt
@for d in $(shell go list ./... | grep -v vendor); do \
go test -race -coverprofile=profile.out -covermode=atomic "$$d"; \
if [ -f profile.out ]; then \
cat profile.out >> coverage.txt; \
rm profile.out; \
fi; \
done;
.PHONY: clean
clean: ## Cleanup any build binaries or packages.
@echo "==> $@"
$(RM) -r $(BINDIR)
$(RM) -r $(BUILDDIR)
define buildpretty
mkdir -p $(BUILDDIR)/$(1)/$(2);
GOOS=$(1) GOARCH=$(2) CGO_ENABLED=0 GO111MODULE=on go build \
-o $(BUILDDIR)/$(1)/$(2)/$(NAME) \
${GO_LDFLAGS_STATIC} ./cmd/$(NAME);
md5sum $(BUILDDIR)/$(1)/$(2)/$(NAME) > $(BUILDDIR)/$(1)/$(2)/$(NAME).md5;
sha256sum $(BUILDDIR)/$(1)/$(2)/$(NAME) > $(BUILDDIR)/$(1)/$(2)/$(NAME).sha256;
endef
.PHONY: cross
cross: ## Builds the cross-compiled binaries, creating a clean directory structure (eg. GOOS/GOARCH/binary)
@echo "+ $@"
$(foreach GOOSARCH,$(GOOSARCHES), $(call buildpretty,$(subst /,,$(dir $(GOOSARCH))),$(notdir $(GOOSARCH))))
define buildrelease
GOOS=$(1) GOARCH=$(2) CGO_ENABLED=0 GO111MODULE=on go build \
-o $(BUILDDIR)/$(NAME)-$(1)-$(2) \
${GO_LDFLAGS_STATIC} ./cmd/$(NAME);
md5sum $(BUILDDIR)/$(NAME)-$(1)-$(2) > $(BUILDDIR)/$(NAME)-$(1)-$(2).md5;
sha256sum $(BUILDDIR)/$(NAME)-$(1)-$(2) > $(BUILDDIR)/$(NAME)-$(1)-$(2).sha256;
endef
.PHONY: release
release: ## Builds the cross-compiled binaries, naming them in such a way for release (eg. binary-GOOS-GOARCH)
@echo "+ $@"
$(foreach GOOSARCH,$(GOOSARCHES), $(call buildrelease,$(subst /,,$(dir $(GOOSARCH))),$(notdir $(GOOSARCH))))
.PHONY: help
help:
@grep -E '^[a-zA-Z_-]+:.*?## .*$$' $(MAKEFILE_LIST) | sort | awk 'BEGIN {FS = ":.*?## "}; {printf "\033[36m%-30s\033[0m %s\n", $$1, $$2}'

35
README.md Normal file
View file

@ -0,0 +1,35 @@
<img height="200" src="./docs/logo.png" alt="logo" align="right" >
# Pomerium : identity-aware access proxy
[![Travis CI](https://travis-ci.org/pomerium/pomerium.svg?branch=master)](https://travis-ci.org/pomerium/pomerium)
[![Go Report Card](https://goreportcard.com/badge/github.com/pomerium/pomerium)](https://goreportcard.com/report/github.com/pomerium/pomerium)
[![LICENSE](https://img.shields.io/github/license/pomerium/pomerium.svg?style=flat-square)](https://github.com/pomerium/pomerium/blob/master/LICENSE)
Pomerium is a tool for managing secure access to internal applications and resources.
Use Pomerium to:
- provide a unified ingress gateway to internal corporate applications.
- enforce dynamic access policies based on context, identity, and device state.
- aggregate logging and telemetry data.
To learn more about zero-trust / BeyondCorp, check out [awesome-zero-trust].
## Getting started
For instructions on getting started with Pomerium, see our getting started docs.
## To start developing Pomerium
Assuming you have a working [Go environment].
```sh
$ go get -d github.com/pomerium/pomerium
$ cd $GOPATH/src/github.com/pomerium/pomerium
$ make
$ source ./env # see env.example
$ ./bin/pomerium -debug
```
[awesome-zero-trust]: https://github.com/pomerium/awesome-zero-trust
[Go environment]: https://golang.org/doc/install

1
VERSION Normal file
View file

@ -0,0 +1 @@
v0.0.1

View file

@ -0,0 +1,254 @@
package authenticate // import "github.com/pomerium/pomerium/authenticate"
import (
"encoding/base64"
"errors"
"fmt"
"html/template"
"net/url"
"strings"
"time"
"github.com/pomerium/envconfig"
"github.com/pomerium/pomerium/authenticate/providers"
"github.com/pomerium/pomerium/internal/aead"
"github.com/pomerium/pomerium/internal/sessions"
"github.com/pomerium/pomerium/internal/templates"
)
// Options permits the configuration of the authentication service
type Options struct {
// e.g.
Host string `envconfig:"HOST"`
//
ProxyClientID string `envconfig:"PROXY_CLIENT_ID"`
ProxyClientSecret string `envconfig:"PROXY_CLIENT_SECRET"`
// Coarse authorization based on user email domain
EmailDomains []string `envconfig:"SSO_EMAIL_DOMAIN"`
ProxyRootDomains []string `envconfig:"PROXY_ROOT_DOMAIN"`
// Session/Cookie management
CookieName string
CookieSecret string `envconfig:"COOKIE_SECRET"`
CookieDomain string `envconfig:"COOKIE_DOMAIN"`
CookieExpire time.Duration `envconfig:"COOKIE_EXPIRE" default:"168h"`
CookieRefresh time.Duration `envconfig:"COOKIE_REFRESH" default:"1h"`
CookieSecure bool `envconfig:"COOKIE_SECURE" default:"true"`
CookieHTTPOnly bool `envconfig:"COOKIE_HTTP_ONLY" default:"true"`
AuthCodeSecret string `envconfig:"AUTH_CODE_SECRET"`
SessionLifetimeTTL time.Duration `envconfig:"SESSION_LIFETIME_TTL" default:"720h"`
// Authentication provider configuration vars
RedirectURL *url.URL `envconfig:"IDP_REDIRECT_URL" ` // e.g. auth.example.com/oauth/callback
ClientID string `envconfig:"IDP_CLIENT_ID"` // IdP ClientID
ClientSecret string `envconfig:"IDP_CLIENT_SECRET"` // IdP Secret
Provider string `envconfig:"IDP_PROVIDER"` //Provider name e.g. "oidc","okta","google",etc
ProviderURL *url.URL `envconfig:"IDP_PROVIDER_URL"`
Scopes []string `envconfig:"IDP_SCOPE" default:"openid,email,profile"`
// todo(bdd) : can delete?`
ApprovalPrompt string `envconfig:"IDP_APPROVAL_PROMPT" default:"consent"`
RequestLogging bool `envconfig:"REQUEST_LOGGING" default:"true"`
RequestTimeout time.Duration `envconfig:"REQUEST_TIMEOUT" default:"2s"`
}
var defaultOptions = &Options{
EmailDomains: []string{"*"},
CookieName: "_pomerium_authenticate",
CookieSecure: true,
CookieHTTPOnly: true,
CookieExpire: time.Duration(168) * time.Hour,
CookieRefresh: time.Duration(1) * time.Hour,
RequestTimeout: time.Duration(2) * time.Second,
SessionLifetimeTTL: time.Duration(720) * time.Hour,
ApprovalPrompt: "consent",
Scopes: []string{"openid", "email", "profile"},
}
// OptionsFromEnvConfig builds the authentication service's configuration
// options from provided environmental variables
func OptionsFromEnvConfig() (*Options, error) {
o := defaultOptions
if err := envconfig.Process("", o); err != nil {
return nil, err
}
return o, nil
}
// Validate checks to see if configuration values are valid for authentication service.
// The checks do not modify the internal state of the Option structure. Function returns
// on first error found.
func (o *Options) Validate() error {
if o.ProviderURL == nil {
return errors.New("missing setting: identity provider url")
}
if o.RedirectURL == nil {
return errors.New("missing setting: identity provider redirect url")
}
redirectPath := "/oauth2/callback"
if o.RedirectURL.Path != redirectPath {
return fmt.Errorf("setting redirect-url was %s path should be %s", o.RedirectURL.Path, redirectPath)
}
if o.ClientID == "" {
return errors.New("missing setting: client id")
}
if o.ClientSecret == "" {
return errors.New("missing setting: client secret")
}
if len(o.EmailDomains) == 0 {
return errors.New("missing setting email domain")
}
if len(o.ProxyRootDomains) == 0 {
return errors.New("missing setting: proxy root domain")
}
if o.ProxyClientID == "" {
return errors.New("missing setting: proxy client id")
}
if o.ProxyClientSecret == "" {
return errors.New("missing setting: proxy client secret")
}
decodedCookieSecret, err := base64.StdEncoding.DecodeString(o.CookieSecret)
if err != nil {
return fmt.Errorf("authenticate options: cookie secret invalid"+
"must be a base64-encoded, 256 bit key e.g. `head -c32 /dev/urandom | base64`"+
"got %q", err)
}
validCookieSecretLength := false
for _, i := range []int{32, 64} {
if len(decodedCookieSecret) == i {
validCookieSecretLength = true
}
}
if !validCookieSecretLength {
return fmt.Errorf("authenticate options: invalid cookie secret strength want 32 to 64 bytes, got %d bytes", len(decodedCookieSecret))
}
if o.CookieRefresh >= o.CookieExpire {
return fmt.Errorf("cookie_refresh (%s) must be less than cookie_expire (%s)", o.CookieRefresh.String(), o.CookieExpire.String())
}
return nil
}
// Authenticator stores all the information associated with proxying the request.
type Authenticator struct {
Validator func(string) bool
EmailDomains []string
ProxyRootDomains []string
Host string
CookieSecure bool
ProxyClientID string
ProxyClientSecret string
SessionLifetimeTTL time.Duration
decodedCookieSecret []byte
templates *template.Template
// sesion related
csrfStore sessions.CSRFStore
sessionStore sessions.SessionStore
cipher aead.Cipher
redirectURL *url.URL
provider providers.Provider
}
// NewAuthenticator creates a Authenticator struct and applies the optional functions slice to the struct.
func NewAuthenticator(opts *Options, optionFuncs ...func(*Authenticator) error) (*Authenticator, error) {
if opts == nil {
return nil, errors.New("options cannot be nil")
}
if err := opts.Validate(); err != nil {
return nil, err
}
decodedAuthCodeSecret, err := base64.StdEncoding.DecodeString(opts.AuthCodeSecret)
if err != nil {
return nil, err
}
cipher, err := aead.NewMiscreantCipher([]byte(decodedAuthCodeSecret))
if err != nil {
return nil, err
}
decodedCookieSecret, err := base64.StdEncoding.DecodeString(opts.CookieSecret)
if err != nil {
return nil, err
}
cookieStore, err := sessions.NewCookieStore(opts.CookieName,
sessions.CreateMiscreantCookieCipher(decodedCookieSecret),
func(c *sessions.CookieStore) error {
c.CookieDomain = opts.CookieDomain
c.CookieHTTPOnly = opts.CookieHTTPOnly
c.CookieExpire = opts.CookieExpire
c.CookieSecure = opts.CookieSecure
return nil
})
if err != nil {
return nil, err
}
p := &Authenticator{
ProxyClientID: opts.ProxyClientID,
ProxyClientSecret: opts.ProxyClientSecret,
EmailDomains: opts.EmailDomains,
ProxyRootDomains: dotPrependDomains(opts.ProxyRootDomains),
CookieSecure: opts.CookieSecure,
redirectURL: opts.RedirectURL,
templates: templates.New(),
csrfStore: cookieStore,
sessionStore: cookieStore,
cipher: cipher,
}
// p.ServeMux = p.Handler()
p.provider, err = newProvider(opts)
if err != nil {
return nil, err
}
// apply the option functions
for _, optFunc := range optionFuncs {
err := optFunc(p)
if err != nil {
return nil, err
}
}
return p, nil
}
func newProvider(opts *Options) (providers.Provider, error) {
pd := &providers.ProviderData{
RedirectURL: opts.RedirectURL,
ProviderName: opts.Provider,
ClientID: opts.ClientID,
ClientSecret: opts.ClientSecret,
ApprovalPrompt: opts.ApprovalPrompt,
SessionLifetimeTTL: opts.SessionLifetimeTTL,
ProviderURL: opts.ProviderURL,
Scopes: opts.Scopes,
}
np, err := providers.New(opts.Provider, pd)
if err != nil {
return nil, err
}
return providers.NewSingleFlightProvider(np), nil
}
func dotPrependDomains(d []string) []string {
for i := range d {
if d[i] != "" && !strings.HasPrefix(d[i], ".") {
d[i] = fmt.Sprintf(".%s", d[i])
}
}
return d
}

File diff suppressed because it is too large Load diff

View file

@ -0,0 +1,329 @@
// Package circuit implements the Circuit Breaker pattern.
// https://docs.microsoft.com/en-us/azure/architecture/patterns/circuit-breaker
package circuit // import "github.com/pomerium/pomerium/internal/circuit"
import (
"fmt"
"math"
"math/rand"
"sync"
"time"
"github.com/benbjohnson/clock"
)
// State is a type that represents a state of Breaker.
type State int
// These constants are states of Breaker.
const (
StateClosed State = iota
StateHalfOpen
StateOpen
)
type (
// ShouldTripFunc is a function that takes in a Counts and returns true if the circuit breaker should be tripped.
ShouldTripFunc func(Counts) bool
// ShouldResetFunc is a function that takes in a Counts and returns true if the circuit breaker should be reset.
ShouldResetFunc func(Counts) bool
// BackoffDurationFunc is a function that takes in a Counts and returns the backoff duration
BackoffDurationFunc func(Counts) time.Duration
// StateChangeHook is a function that represents a state change.
StateChangeHook func(prev, to State)
// BackoffHook is a function that represents backoff.
BackoffHook func(duration time.Duration, reset time.Time)
)
var (
// DefaultShouldTripFunc is a default ShouldTripFunc.
DefaultShouldTripFunc = func(counts Counts) bool {
// Trip into Open after three consecutive failures
return counts.ConsecutiveFailures >= 3
}
// DefaultShouldResetFunc is a default ShouldResetFunc.
DefaultShouldResetFunc = func(counts Counts) bool {
// Reset after three consecutive successes
return counts.ConsecutiveSuccesses >= 3
}
// DefaultBackoffDurationFunc is an exponential backoff function
DefaultBackoffDurationFunc = ExponentialBackoffDuration(time.Duration(100)*time.Second, time.Duration(500)*time.Millisecond)
)
// ErrOpenState is returned when the b state is open
type ErrOpenState struct{}
func (e *ErrOpenState) Error() string { return "circuit breaker is open" }
// ExponentialBackoffDuration returns a function that uses exponential backoff and full jitter
func ExponentialBackoffDuration(maxBackoff, baseTimeout time.Duration) func(Counts) time.Duration {
return func(counts Counts) time.Duration {
// Full Jitter from https://aws.amazon.com/blogs/architecture/exponential-backoff-and-jitter/
// sleep = random_between(0, min(cap, base * 2 ** attempt))
backoff := math.Min(float64(maxBackoff), float64(baseTimeout)*math.Exp2(float64(counts.ConsecutiveFailures)))
jittered := rand.Float64() * backoff
return time.Duration(jittered)
}
}
// String implements stringer interface.
func (s State) String() string {
switch s {
case StateClosed:
return "closed"
case StateHalfOpen:
return "half-open"
case StateOpen:
return "open"
default:
return fmt.Sprintf("unknown state: %d", s)
}
}
// Counts holds the numbers of requests and their successes/failures.
type Counts struct {
CurrentRequests int
ConsecutiveSuccesses int
ConsecutiveFailures int
}
func (c *Counts) onRequest() {
c.CurrentRequests++
}
func (c *Counts) afterRequest() {
c.CurrentRequests--
}
func (c *Counts) onSuccess() {
c.ConsecutiveSuccesses++
c.ConsecutiveFailures = 0
}
func (c *Counts) onFailure() {
c.ConsecutiveFailures++
c.ConsecutiveSuccesses = 0
}
func (c *Counts) clear() {
c.ConsecutiveSuccesses = 0
c.ConsecutiveFailures = 0
}
// Options configures Breaker:
//
// HalfOpenConcurrentRequests specifies how many concurrent requests to allow while
// the circuit is in the half-open state
//
// ShouldTripFunc specifies when the circuit should trip from the closed state to
// the open state. It takes a Counts struct and returns a bool.
//
// ShouldResetFunc specifies when the circuit should be reset from the half-open state
// to the closed state and allow all requests. It takes a Counts struct and returns a bool.
//
// BackoffDurationFunc specifies how long to set the backoff duration. It takes a
// counts struct and returns a time.Duration
//
// OnStateChange is called whenever the state of the Breaker changes.
//
// OnBackoff is called whenever a backoff is set with the backoff duration and reset time
//
// TestClock is used to mock the clock during tests
type Options struct {
HalfOpenConcurrentRequests int
ShouldTripFunc ShouldTripFunc
ShouldResetFunc ShouldResetFunc
BackoffDurationFunc BackoffDurationFunc
// hooks
OnStateChange StateChangeHook
OnBackoff BackoffHook
// used in tests
TestClock clock.Clock
}
// Breaker is a state machine to prevent sending requests that are likely to fail.
type Breaker struct {
halfOpenRequests int
shouldTripFunc ShouldTripFunc
shouldResetFunc ShouldResetFunc
backoffDurationFunc BackoffDurationFunc
// hooks
onStateChange StateChangeHook
onBackoff BackoffHook
// used primarly for mocking tests
clock clock.Clock
mutex sync.Mutex
state State
counts Counts
backoffExpires time.Time
generation int
}
// NewBreaker returns a new Breaker configured with the given Settings.
func NewBreaker(opts *Options) *Breaker {
b := new(Breaker)
if opts == nil {
opts = &Options{}
}
// set hooks
b.onStateChange = opts.OnStateChange
b.onBackoff = opts.OnBackoff
b.halfOpenRequests = 1
if opts.HalfOpenConcurrentRequests > 0 {
b.halfOpenRequests = opts.HalfOpenConcurrentRequests
}
b.backoffDurationFunc = DefaultBackoffDurationFunc
if opts.BackoffDurationFunc != nil {
b.backoffDurationFunc = opts.BackoffDurationFunc
}
b.shouldTripFunc = DefaultShouldTripFunc
if opts.ShouldTripFunc != nil {
b.shouldTripFunc = opts.ShouldTripFunc
}
b.shouldResetFunc = DefaultShouldResetFunc
if opts.ShouldResetFunc != nil {
b.shouldResetFunc = opts.ShouldResetFunc
}
b.clock = clock.New()
if opts.TestClock != nil {
b.clock = opts.TestClock
}
b.setState(StateClosed)
return b
}
// Call runs the given function if the Breaker allows the call.
// Call returns an error instantly if the Breaker rejects the request.
// Otherwise, Call returns the result of the request.
func (b *Breaker) Call(f func() (interface{}, error)) (interface{}, error) {
generation, err := b.beforeRequest()
if err != nil {
return nil, err
}
result, err := f()
b.afterRequest(err == nil, generation)
return result, err
}
func (b *Breaker) beforeRequest() (int, error) {
b.mutex.Lock()
defer b.mutex.Unlock()
state, generation := b.currentState()
switch state {
case StateOpen:
return generation, &ErrOpenState{}
case StateHalfOpen:
if b.counts.CurrentRequests >= b.halfOpenRequests {
return generation, &ErrOpenState{}
}
}
b.counts.onRequest()
return generation, nil
}
func (b *Breaker) afterRequest(success bool, prevGeneration int) {
b.mutex.Lock()
defer b.mutex.Unlock()
b.counts.afterRequest()
state, generation := b.currentState()
if prevGeneration != generation {
return
}
if success {
b.onSuccess(state)
return
}
b.onFailure(state)
}
func (b *Breaker) onSuccess(state State) {
b.counts.onSuccess()
switch state {
case StateHalfOpen:
if b.shouldResetFunc(b.counts) {
b.setState(StateClosed)
b.counts.clear()
}
}
}
func (b *Breaker) onFailure(state State) {
b.counts.onFailure()
switch state {
case StateClosed:
if b.shouldTripFunc(b.counts) {
b.setState(StateOpen)
b.counts.clear()
b.setBackoff()
}
case StateOpen:
b.setBackoff()
case StateHalfOpen:
b.setState(StateOpen)
b.setBackoff()
}
}
func (b *Breaker) setBackoff() {
backoffDuration := b.backoffDurationFunc(b.counts)
backoffExpires := b.clock.Now().Add(backoffDuration)
b.backoffExpires = backoffExpires
if b.onBackoff != nil {
b.onBackoff(backoffDuration, backoffExpires)
}
}
func (b *Breaker) currentState() (State, int) {
switch b.state {
case StateOpen:
if b.clock.Now().After(b.backoffExpires) {
b.setState(StateHalfOpen)
}
}
return b.state, b.generation
}
func (b *Breaker) newGeneration() {
b.generation++
}
func (b *Breaker) setState(state State) {
if b.state == state {
return
}
b.newGeneration()
prev := b.state
b.state = state
if b.onStateChange != nil {
b.onStateChange(prev, state)
}
}

View file

@ -0,0 +1,187 @@
package circuit // import "github.com/pomerium/pomerium/internal/circuit"
import (
"errors"
"sync"
"testing"
"time"
"github.com/benbjohnson/clock"
)
var errFailed = errors.New("failed")
func fail() (interface{}, error) {
return nil, errFailed
}
func succeed() (interface{}, error) {
return nil, nil
}
func TestCircuitBreaker(t *testing.T) {
mock := clock.NewMock()
threshold := 3
timeout := time.Duration(2) * time.Second
trip := func(c Counts) bool { return c.ConsecutiveFailures > threshold }
reset := func(c Counts) bool { return c.ConsecutiveSuccesses > threshold }
backoff := func(c Counts) time.Duration { return timeout }
stateChange := func(p, c State) { t.Logf("state change from %s to %s\n", p, c) }
cb := NewBreaker(&Options{
TestClock: mock,
ShouldTripFunc: trip,
ShouldResetFunc: reset,
BackoffDurationFunc: backoff,
OnStateChange: stateChange,
})
state, _ := cb.currentState()
if state != StateClosed {
t.Fatalf("expected state to start %s, got %s", StateClosed, state)
}
for i := 0; i <= threshold; i++ {
_, err := cb.Call(fail)
if err == nil {
t.Fatalf("expected to error, got nil")
}
state, _ := cb.currentState()
t.Logf("iteration %#v", i)
if i == threshold {
// we expect this to be the case to trip the circuit
if state != StateOpen {
t.Fatalf("expected state to be %s, got %s", StateOpen, state)
}
} else if state != StateClosed {
// this is a normal failure case
t.Fatalf("expected state to be %s, got %s", StateClosed, state)
}
}
_, err := cb.Call(fail)
switch err.(type) {
case *ErrOpenState:
// this is the expected case
break
default:
t.Errorf("%#v", cb.counts)
t.Fatalf("expected to get open state failure, got %s", err)
}
// we advance time by the timeout and a hair
mock.Add(timeout + time.Duration(1)*time.Millisecond)
state, _ = cb.currentState()
if state != StateHalfOpen {
t.Fatalf("expected state to be %s, got %s", StateHalfOpen, state)
}
for i := 0; i <= threshold; i++ {
_, err := cb.Call(succeed)
if err != nil {
t.Fatalf("expected to get no error, got %s", err)
}
state, _ := cb.currentState()
t.Logf("iteration %#v", i)
if i == threshold {
// we expect this to be the case that ressets the circuit
if state != StateClosed {
t.Fatalf("expected state to be %s, got %s", StateClosed, state)
}
} else if state != StateHalfOpen {
t.Fatalf("expected state to be %s, got %s", StateHalfOpen, state)
}
}
state, _ = cb.currentState()
if state != StateClosed {
t.Fatalf("expected state to be %s, got %s", StateClosed, state)
}
}
func TestExponentialBackOffFunc(t *testing.T) {
baseTimeout := time.Duration(1) * time.Millisecond
// Note Expected is an upper range case
cases := []struct {
FailureCount int
Expected time.Duration
}{
{
FailureCount: 0,
Expected: time.Duration(1) * time.Millisecond,
},
{
FailureCount: 1,
Expected: time.Duration(2) * time.Millisecond,
},
{
FailureCount: 2,
Expected: time.Duration(4) * time.Millisecond,
},
{
FailureCount: 3,
Expected: time.Duration(8) * time.Millisecond,
},
{
FailureCount: 4,
Expected: time.Duration(16) * time.Millisecond,
},
{
FailureCount: 5,
Expected: time.Duration(32) * time.Millisecond,
},
{
FailureCount: 6,
Expected: time.Duration(64) * time.Millisecond,
},
{
FailureCount: 7,
Expected: time.Duration(128) * time.Millisecond,
},
{
FailureCount: 8,
Expected: time.Duration(256) * time.Millisecond,
},
{
FailureCount: 9,
Expected: time.Duration(512) * time.Millisecond,
},
{
FailureCount: 10,
Expected: time.Duration(1024) * time.Millisecond,
},
}
f := ExponentialBackoffDuration(time.Duration(1)*time.Hour, baseTimeout)
for _, tc := range cases {
got := f(Counts{ConsecutiveFailures: tc.FailureCount})
t.Logf("got backoff %#v", got)
if got > tc.Expected {
t.Errorf("got %#v but expected less than %#v", got, tc.Expected)
}
}
}
func TestCircuitBreakerClosedParallel(t *testing.T) {
cb := NewBreaker(nil)
numReqs := 10000
wg := &sync.WaitGroup{}
routine := func(wg *sync.WaitGroup) {
for i := 0; i < numReqs; i++ {
cb.Call(succeed)
}
wg.Done()
}
numRoutines := 10
for i := 0; i < numRoutines; i++ {
wg.Add(1)
go routine(wg)
}
total := numReqs * numRoutines
wg.Wait()
if cb.counts.ConsecutiveSuccesses != total {
t.Fatalf("expected to get total requests %d, got %d", total, cb.counts.ConsecutiveSuccesses)
}
}

629
authenticate/handlers.go Normal file
View file

@ -0,0 +1,629 @@
package authenticate // import "github.com/pomerium/pomerium/authenticate"
import (
"encoding/base64"
"encoding/json"
"fmt"
"net/http"
"net/url"
"strings"
"time"
"github.com/pomerium/pomerium/internal/aead"
"github.com/pomerium/pomerium/internal/httputil"
"github.com/pomerium/pomerium/internal/log"
m "github.com/pomerium/pomerium/internal/middleware"
"github.com/pomerium/pomerium/internal/sessions"
"github.com/pomerium/pomerium/internal/version"
)
var securityHeaders = map[string]string{
"Strict-Transport-Security": "max-age=31536000",
"X-Frame-Options": "DENY",
"X-Content-Type-Options": "nosniff",
"X-XSS-Protection": "1; mode=block",
"Content-Security-Policy": "default-src 'none'; style-src 'self' 'sha256-pSTVzZsFAqd2U3QYu+BoBDtuJWaPM/+qMy/dBRrhb5Y='; img-src 'self';",
"Referrer-Policy": "Same-origin",
}
// Handler returns the Http.Handlers for authentication, callback, and refresh
func (p *Authenticator) Handler() http.Handler {
mux := http.NewServeMux()
// we setup global endpoints that should respond to any hostname
mux.HandleFunc("/ping", m.WithMethods(p.PingPage, "GET"))
serviceMux := http.NewServeMux()
// standard rest and healthcheck endpoints
serviceMux.HandleFunc("/ping", m.WithMethods(p.PingPage, "GET"))
serviceMux.HandleFunc("/robots.txt", m.WithMethods(p.RobotsTxt, "GET"))
// Identity Provider (IdP) endpoints and callbacks
serviceMux.HandleFunc("/start", m.WithMethods(p.OAuthStart, "GET"))
serviceMux.HandleFunc("/oauth2/callback", m.WithMethods(p.OAuthCallback, "GET"))
// authenticator-server endpoints, todo(bdd): make gRPC
serviceMux.HandleFunc("/sign_in", m.WithMethods(m.ValidateClientID(p.validateSignature(p.SignIn), p.ProxyClientID), "GET"))
serviceMux.HandleFunc("/sign_out", m.WithMethods(p.validateSignature(p.SignOut), "GET", "POST"))
serviceMux.HandleFunc("/profile", m.WithMethods(p.validateExisting(p.GetProfile), "GET"))
serviceMux.HandleFunc("/validate", m.WithMethods(p.validateExisting(p.ValidateToken), "GET"))
serviceMux.HandleFunc("/redeem", m.WithMethods(p.validateExisting(p.Redeem), "POST"))
serviceMux.HandleFunc("/refresh", m.WithMethods(p.validateExisting(p.Refresh), "POST"))
// NOTE: we have to include trailing slash for the router to match the host header
host := p.Host
if !strings.HasSuffix(host, "/") {
host = fmt.Sprintf("%s/", host)
}
mux.Handle(host, serviceMux) // setup our service mux to only handle our required host header
return m.SetHeaders(mux, securityHeaders)
}
// validateSignature wraps a common collection of middlewares to validate signatures
func (p *Authenticator) validateSignature(f http.HandlerFunc) http.HandlerFunc {
return validateRedirectURI(validateSignature(f, p.ProxyClientSecret), p.ProxyRootDomains)
}
// validateSignature wraps a common collection of middlewares to validate
// a (presumably) existing user session
func (p *Authenticator) validateExisting(f http.HandlerFunc) http.HandlerFunc {
return m.ValidateClientID(m.ValidateClientSecret(f, p.ProxyClientSecret), p.ProxyClientID)
}
// RobotsTxt handles the /robots.txt route.
func (p *Authenticator) RobotsTxt(rw http.ResponseWriter, req *http.Request) {
rw.WriteHeader(http.StatusOK)
fmt.Fprintf(rw, "User-agent: *\nDisallow: /")
}
// PingPage handles the /ping route
func (p *Authenticator) PingPage(rw http.ResponseWriter, req *http.Request) {
rw.WriteHeader(http.StatusOK)
fmt.Fprintf(rw, "OK")
}
// SignInPage directs the user to the sign in page
func (p *Authenticator) SignInPage(rw http.ResponseWriter, req *http.Request, code int) {
requestLog := log.WithRequest(req, "authenticate.SignInPage")
rw.WriteHeader(code)
redirectURL := p.redirectURL.ResolveReference(req.URL)
// validateRedirectURI middleware already ensures that this is a valid URL
destinationURL, _ := url.Parse(redirectURL.Query().Get("redirect_uri"))
t := struct {
ProviderName string
EmailDomains []string
Redirect string
Destination string
Version string
}{
ProviderName: p.provider.Data().ProviderName,
EmailDomains: p.EmailDomains,
Redirect: redirectURL.String(),
Destination: destinationURL.Host,
Version: version.FullVersion(),
}
requestLog.Info().
Str("ProviderName", p.provider.Data().ProviderName).
Str("Redirect", redirectURL.String()).
Str("Destination", destinationURL.Host).
Str("EmailDomains", strings.Join(p.EmailDomains, ", ")).
Msg("authenticate.SignInPage")
p.templates.ExecuteTemplate(rw, "sign_in.html", t)
}
func (p *Authenticator) authenticate(rw http.ResponseWriter, req *http.Request) (*sessions.SessionState, error) {
requestLog := log.WithRequest(req, "authenticate.authenticate")
session, err := p.sessionStore.LoadSession(req)
if err != nil {
log.Error().Err(err).Msg("authenticate.authenticate")
p.sessionStore.ClearSession(rw, req)
return nil, err
}
// ensure sessions lifetime has not expired
if session.LifetimePeriodExpired() {
requestLog.Warn().Msg("lifetime expired")
p.sessionStore.ClearSession(rw, req)
return nil, sessions.ErrLifetimeExpired
}
// check if session refresh period is up
if session.RefreshPeriodExpired() {
ok, err := p.provider.RefreshSessionIfNeeded(session)
if err != nil {
requestLog.Error().Err(err).Msg("failed to refresh session")
p.sessionStore.ClearSession(rw, req)
return nil, err
}
if !ok {
requestLog.Error().Msg("user unauthorized after refresh")
p.sessionStore.ClearSession(rw, req)
return nil, httputil.ErrUserNotAuthorized
}
// update refresh'd session in cookie
err = p.sessionStore.SaveSession(rw, req, session)
if err != nil {
// We refreshed the session successfully, but failed to save it.
// This could be from failing to encode the session properly.
// But, we clear the session cookie and reject the request
requestLog.Error().Err(err).Msg("could not save refreshed session")
p.sessionStore.ClearSession(rw, req)
return nil, err
}
} else {
// The session has not exceeded it's lifetime or requires refresh
ok := p.provider.ValidateSessionState(session)
if !ok {
requestLog.Error().Msg("invalid session state")
p.sessionStore.ClearSession(rw, req)
return nil, httputil.ErrUserNotAuthorized
}
err = p.sessionStore.SaveSession(rw, req, session)
if err != nil {
requestLog.Error().Err(err).Msg("failed to save valid session")
p.sessionStore.ClearSession(rw, req)
return nil, err
}
}
if !p.Validator(session.Email) {
requestLog.Error().Msg("invalid email user")
return nil, httputil.ErrUserNotAuthorized
}
return session, nil
}
// SignIn handles the /sign_in endpoint. It attempts to authenticate the user,
// and if the user is not authenticated, it renders a sign in page.
func (p *Authenticator) SignIn(rw http.ResponseWriter, req *http.Request) {
// We attempt to authenticate the user. If they cannot be authenticated, we render a sign-in
// page.
//
// If the user is authenticated, we redirect back to the proxy application
// at the `redirect_uri`, with a temporary token.
//
// TODO: It is possible for a user to visit this page without a redirect destination.
// Should we allow the user to authenticate? If not, what should be the proposed workflow?
session, err := p.authenticate(rw, req)
switch err {
case nil:
// User is authenticated, redirect back to the proxy application
// with the necessary state
p.ProxyOAuthRedirect(rw, req, session)
case http.ErrNoCookie:
log.Error().Err(err).Msg("authenticate.SignIn : err no cookie")
p.SignInPage(rw, req, http.StatusOK)
case sessions.ErrLifetimeExpired, sessions.ErrInvalidSession:
log.Error().Err(err).Msg("authenticate.SignIn : invalid cookie cookie")
p.sessionStore.ClearSession(rw, req)
p.SignInPage(rw, req, http.StatusOK)
default:
log.Error().Err(err).Msg("authenticate.SignIn : unknown error cookie")
httputil.ErrorResponse(rw, req, err.Error(), httputil.CodeForError(err))
}
}
// ProxyOAuthRedirect redirects the user back to sso proxy's redirection endpoint.
func (p *Authenticator) ProxyOAuthRedirect(rw http.ResponseWriter, req *http.Request, session *sessions.SessionState) {
// This workflow corresponds to Section 3.1.2 of the OAuth2 RFC.
// See https://tools.ietf.org/html/rfc6749#section-3.1.2 for more specific information.
//
// We redirect the user back to the proxy application's redirection endpoint; in the
// sso proxy, this is the `/oauth/callback` endpoint.
//
// We must provide the proxy with a temporary authorization code via the `code` parameter,
// which they can use to redeem an access token for subsequent API calls.
//
// We must also include the original `state` parameter received from the proxy application.
err := req.ParseForm()
if err != nil {
httputil.ErrorResponse(rw, req, err.Error(), http.StatusInternalServerError)
return
}
state := req.Form.Get("state")
if state == "" {
httputil.ErrorResponse(rw, req, "no state parameter supplied", http.StatusForbidden)
return
}
redirectURI := req.Form.Get("redirect_uri")
if redirectURI == "" {
httputil.ErrorResponse(rw, req, "no redirect_uri parameter supplied", http.StatusForbidden)
return
}
redirectURL, err := url.Parse(redirectURI)
if err != nil {
httputil.ErrorResponse(rw, req, "malformed redirect_uri parameter passed", http.StatusBadRequest)
return
}
encrypted, err := sessions.MarshalSession(session, p.cipher)
if err != nil {
httputil.ErrorResponse(rw, req, err.Error(), http.StatusInternalServerError)
return
}
http.Redirect(rw, req, getAuthCodeRedirectURL(redirectURL, state, string(encrypted)), http.StatusFound)
}
func getAuthCodeRedirectURL(redirectURL *url.URL, state, authCode string) string {
u, _ := url.Parse(redirectURL.String())
params, _ := url.ParseQuery(u.RawQuery)
params.Set("code", authCode)
params.Set("state", state)
u.RawQuery = params.Encode()
if u.Scheme == "" {
u.Scheme = "https"
}
return u.String()
}
// SignOut signs the user out.
func (p *Authenticator) SignOut(rw http.ResponseWriter, req *http.Request) {
redirectURI := req.Form.Get("redirect_uri")
if req.Method == "GET" {
p.SignOutPage(rw, req, "")
return
}
session, err := p.sessionStore.LoadSession(req)
switch err {
case nil:
break
case http.ErrNoCookie: // if there's no cookie in the session we can just redirect
http.Redirect(rw, req, redirectURI, http.StatusFound)
return
default:
// a different error, clear the session cookie and redirect
log.Error().Err(err).Msg("authenticate.SignOut : error loading cookie session")
p.sessionStore.ClearSession(rw, req)
http.Redirect(rw, req, redirectURI, http.StatusFound)
return
}
err = p.provider.Revoke(session)
if err != nil {
log.Error().Err(err).Msg("authenticate.SignOut : error revoking session")
p.SignOutPage(rw, req, "An error occurred during sign out. Please try again.")
return
}
p.sessionStore.ClearSession(rw, req)
http.Redirect(rw, req, redirectURI, http.StatusFound)
}
// SignOutPage renders a sign out page with a message
func (p *Authenticator) SignOutPage(rw http.ResponseWriter, req *http.Request, message string) {
// validateRedirectURI middleware already ensures that this is a valid URL
redirectURI := req.Form.Get("redirect_uri")
session, err := p.sessionStore.LoadSession(req)
if err != nil {
http.Redirect(rw, req, redirectURI, http.StatusFound)
return
}
signature := req.Form.Get("sig")
timestamp := req.Form.Get("ts")
destinationURL, _ := url.Parse(redirectURI)
// An error message indicates that an internal server error occurred
if message != "" {
rw.WriteHeader(http.StatusInternalServerError)
}
t := struct {
Redirect string
Signature string
Timestamp string
Message string
Destination string
Email string
Version string
}{
Redirect: redirectURI,
Signature: signature,
Timestamp: timestamp,
Message: message,
Destination: destinationURL.Host,
Email: session.Email,
Version: version.FullVersion(),
}
p.templates.ExecuteTemplate(rw, "sign_out.html", t)
return
}
// OAuthStart starts the authentication process by redirecting to the provider. It provides a
// `redirectURI`, allowing the provider to redirect back to the sso proxy after authentication.
func (p *Authenticator) OAuthStart(rw http.ResponseWriter, req *http.Request) {
nonce := fmt.Sprintf("%x", aead.GenerateKey())
p.csrfStore.SetCSRF(rw, req, nonce)
authRedirectURL, err := url.Parse(req.URL.Query().Get("redirect_uri"))
if err != nil || !validRedirectURI(authRedirectURL.String(), p.ProxyRootDomains) {
httputil.ErrorResponse(rw, req, "Invalid redirect parameter", http.StatusBadRequest)
return
}
proxyRedirectURL, err := url.Parse(authRedirectURL.Query().Get("redirect_uri"))
if err != nil || !validRedirectURI(proxyRedirectURL.String(), p.ProxyRootDomains) {
httputil.ErrorResponse(rw, req, "Invalid redirect parameter", http.StatusBadRequest)
return
}
proxyRedirectSig := authRedirectURL.Query().Get("sig")
ts := authRedirectURL.Query().Get("ts")
if !validSignature(proxyRedirectURL.String(), proxyRedirectSig, ts, p.ProxyClientSecret) {
httputil.ErrorResponse(rw, req, "Invalid redirect parameter", http.StatusBadRequest)
return
}
state := base64.URLEncoding.EncodeToString([]byte(fmt.Sprintf("%v:%v", nonce, authRedirectURL.String())))
signInURL := p.provider.GetSignInURL(state)
http.Redirect(rw, req, signInURL, http.StatusFound)
}
func (p *Authenticator) redeemCode(host, code string) (*sessions.SessionState, error) {
session, err := p.provider.Redeem(code)
if err != nil {
return nil, err
}
if session.Email == "" {
return nil, fmt.Errorf("no email included in session")
}
return session, nil
}
// getOAuthCallback completes the oauth cycle from an identity provider's callback
func (p *Authenticator) getOAuthCallback(rw http.ResponseWriter, req *http.Request) (string, error) {
requestLog := log.WithRequest(req, "authenticate.getOAuthCallback")
// finish the oauth cycle
err := req.ParseForm()
if err != nil {
return "", httputil.HTTPError{Code: http.StatusInternalServerError, Message: err.Error()}
}
errorString := req.Form.Get("error")
if errorString != "" {
return "", httputil.HTTPError{Code: http.StatusForbidden, Message: errorString}
}
code := req.Form.Get("code")
if code == "" {
return "", httputil.HTTPError{Code: http.StatusBadRequest, Message: "Missing Code"}
}
session, err := p.redeemCode(req.Host, code)
if err != nil {
requestLog.Error().Err(err).Msg("error redeeming authentication code")
return "", err
}
bytes, err := base64.URLEncoding.DecodeString(req.Form.Get("state"))
if err != nil {
return "", httputil.HTTPError{Code: http.StatusInternalServerError, Message: "Invalid State"}
}
s := strings.SplitN(string(bytes), ":", 2)
if len(s) != 2 {
return "", httputil.HTTPError{Code: http.StatusInternalServerError, Message: "Invalid State"}
}
nonce := s[0]
redirect := s[1]
c, err := p.csrfStore.GetCSRF(req)
if err != nil {
return "", httputil.HTTPError{Code: http.StatusForbidden, Message: "Missing CSRF token"}
}
p.csrfStore.ClearCSRF(rw, req)
if c.Value != nonce {
requestLog.Error().Err(err).Msg("csrf token mismatch")
return "", httputil.HTTPError{Code: http.StatusForbidden, Message: "csrf failed"}
}
if !validRedirectURI(redirect, p.ProxyRootDomains) {
return "", httputil.HTTPError{Code: http.StatusForbidden, Message: "Invalid Redirect URI"}
}
// Set cookie, or deny: The authenticator validates the session email and group
// - for p.Validator see validator.go#newValidatorImpl for more info
// - for p.provider.ValidateGroup see providers/google.go#ValidateGroup for more info
if !p.Validator(session.Email) {
requestLog.Error().Err(err).Str("email", session.Email).Msg("invalid email permissions denied")
return "", httputil.HTTPError{Code: http.StatusForbidden, Message: "Invalid Account"}
}
requestLog.Info().Str("email", session.Email).Msg("authentication complete")
err = p.sessionStore.SaveSession(rw, req, session)
if err != nil {
requestLog.Error().Err(err).Msg("internal error")
return "", httputil.HTTPError{Code: http.StatusInternalServerError, Message: "Internal Error"}
}
return redirect, nil
}
// OAuthCallback handles the callback from the provider, and returns an error response if there is an error.
// If there is no error it will redirect to the redirect url.
func (p *Authenticator) OAuthCallback(rw http.ResponseWriter, req *http.Request) {
redirect, err := p.getOAuthCallback(rw, req)
switch h := err.(type) {
case nil:
break
case httputil.HTTPError:
httputil.ErrorResponse(rw, req, h.Message, h.Code)
return
default:
httputil.ErrorResponse(rw, req, "Internal Error", http.StatusInternalServerError)
return
}
http.Redirect(rw, req, redirect, http.StatusFound)
}
// Redeem has a signed access token, and provides the user information associated with the access token.
func (p *Authenticator) Redeem(rw http.ResponseWriter, req *http.Request) {
// The auth code is redeemed by the sso proxy for an access token, refresh token,
// expiration, and email.
requestLog := log.WithRequest(req, "authenticate.Redeem")
err := req.ParseForm()
if err != nil {
http.Error(rw, fmt.Sprintf("Bad Request: %s", err.Error()), http.StatusBadRequest)
return
}
session, err := sessions.UnmarshalSession(req.Form.Get("code"), p.cipher)
if err != nil {
requestLog.Error().Err(err).Int("http-status", http.StatusUnauthorized).Msg("invalid auth code")
http.Error(rw, fmt.Sprintf("invalid auth code: %s", err.Error()), http.StatusUnauthorized)
return
}
if session == nil {
requestLog.Error().Err(err).Int("http-status", http.StatusUnauthorized).Msg("invalid session")
http.Error(rw, fmt.Sprintf("invalid session: %s", err.Error()), http.StatusUnauthorized)
return
}
if session != nil && (session.RefreshPeriodExpired() || session.LifetimePeriodExpired()) {
requestLog.Error().Msg("expired session")
p.sessionStore.ClearSession(rw, req)
http.Error(rw, fmt.Sprintf("expired session"), http.StatusUnauthorized)
return
}
response := struct {
AccessToken string `json:"access_token"`
RefreshToken string `json:"refresh_token"`
IDToken string `json:"id_token"`
ExpiresIn int64 `json:"expires_in"`
Email string `json:"email"`
}{
AccessToken: session.AccessToken,
RefreshToken: session.RefreshToken,
IDToken: session.IDToken,
ExpiresIn: int64(session.RefreshDeadline.Sub(time.Now()).Seconds()),
Email: session.Email,
}
jsonBytes, err := json.Marshal(response)
if err != nil {
rw.WriteHeader(http.StatusInternalServerError)
return
}
rw.Header().Set("GAP-Auth", session.Email)
rw.Header().Set("Content-Type", "application/json")
rw.Write(jsonBytes)
}
// Refresh takes a refresh token and returns a new access token
func (p *Authenticator) Refresh(rw http.ResponseWriter, req *http.Request) {
err := req.ParseForm()
if err != nil {
http.Error(rw, fmt.Sprintf("Bad Request: %s", err.Error()), http.StatusBadRequest)
return
}
refreshToken := req.Form.Get("refresh_token")
if refreshToken == "" {
http.Error(rw, "Bad Request: No Refresh Token", http.StatusBadRequest)
return
}
accessToken, expiresIn, err := p.provider.RefreshAccessToken(refreshToken)
if err != nil {
httputil.ErrorResponse(rw, req, err.Error(), httputil.CodeForError(err))
return
}
response := struct {
AccessToken string `json:"access_token"`
ExpiresIn int64 `json:"expires_in"`
}{
AccessToken: accessToken,
ExpiresIn: int64(expiresIn.Seconds()),
}
bytes, err := json.Marshal(response)
if err != nil {
rw.WriteHeader(http.StatusInternalServerError)
return
}
rw.WriteHeader(http.StatusCreated)
rw.Header().Set("Content-Type", "application/json")
rw.Write(bytes)
}
// GetProfile gets a list of groups of which a user is a member.
func (p *Authenticator) GetProfile(rw http.ResponseWriter, req *http.Request) {
// The sso proxy sends the user's email to this endpoint to get a list of Google groups that
// the email is a member of. The proxy will compare these groups to the list of allowed
// groups for the upstream service the user is trying to access.
email := req.FormValue("email")
if email == "" {
http.Error(rw, "no email address included", http.StatusBadRequest)
return
}
// groupsFormValue := req.FormValue("groups")
// allowedGroups := []string{}
// if groupsFormValue != "" {
// allowedGroups = strings.Split(groupsFormValue, ",")
// }
// groups, err := p.provider.ValidateGroupMembership(email, allowedGroups)
// if err != nil {
// log.Error().Err(err).Msg("authenticate.GetProfile : error retrieving groups")
// httputil.ErrorResponse(rw, req, err.Error(), httputil.CodeForError(err))
// return
// }
response := struct {
Email string `json:"email"`
}{
Email: email,
}
jsonBytes, err := json.Marshal(response)
if err != nil {
http.Error(rw, fmt.Sprintf("error marshaling response: %s", err.Error()), http.StatusInternalServerError)
return
}
rw.Header().Set("GAP-Auth", email)
rw.Header().Set("Content-Type", "application/json")
rw.Write(jsonBytes)
}
// ValidateToken validates the X-Access-Token from the header and returns an error response
// if it's invalid
func (p *Authenticator) ValidateToken(rw http.ResponseWriter, req *http.Request) {
accessToken := req.Header.Get("X-Access-Token")
idToken := req.Header.Get("X-Id-Token")
if accessToken == "" {
rw.WriteHeader(http.StatusBadRequest)
return
}
if idToken == "" {
rw.WriteHeader(http.StatusBadRequest)
return
}
ok := p.provider.ValidateSessionState(&sessions.SessionState{
AccessToken: accessToken,
IDToken: idToken,
})
if !ok {
rw.WriteHeader(http.StatusUnauthorized)
return
}
rw.WriteHeader(http.StatusOK)
return
}

View file

@ -0,0 +1,98 @@
package authenticate // import "github.com/pomerium/pomerium/authenticate"
import (
"crypto/hmac"
"crypto/sha256"
"encoding/base64"
"fmt"
"net/http"
"net/url"
"strconv"
"strings"
"time"
"github.com/pomerium/pomerium/internal/httputil"
)
// validateRedirectURI checks the redirect uri in the query parameters and ensures that
// the url's domain is one in the list of proxy root domains.
func validateRedirectURI(f http.HandlerFunc, proxyRootDomains []string) http.HandlerFunc {
return func(rw http.ResponseWriter, req *http.Request) {
err := req.ParseForm()
if err != nil {
httputil.ErrorResponse(rw, req, err.Error(), http.StatusBadRequest)
return
}
redirectURI := req.Form.Get("redirect_uri")
if !validRedirectURI(redirectURI, proxyRootDomains) {
httputil.ErrorResponse(rw, req, "Invalid redirect parameter", http.StatusBadRequest)
return
}
f(rw, req)
}
}
func validRedirectURI(uri string, rootDomains []string) bool {
redirectURL, err := url.Parse(uri)
if uri == "" || err != nil || redirectURL.Host == "" {
return false
}
for _, domain := range rootDomains {
if strings.HasSuffix(redirectURL.Hostname(), domain) {
return true
}
}
return false
}
func validateSignature(f http.HandlerFunc, proxyClientSecret string) http.HandlerFunc {
return func(rw http.ResponseWriter, req *http.Request) {
err := req.ParseForm()
if err != nil {
httputil.ErrorResponse(rw, req, err.Error(), http.StatusBadRequest)
return
}
redirectURI := req.Form.Get("redirect_uri")
sigVal := req.Form.Get("sig")
timestamp := req.Form.Get("ts")
if !validSignature(redirectURI, sigVal, timestamp, proxyClientSecret) {
httputil.ErrorResponse(rw, req, "Invalid redirect parameter", http.StatusBadRequest)
return
}
f(rw, req)
}
}
func validSignature(redirectURI, sigVal, timestamp, secret string) bool {
if redirectURI == "" || sigVal == "" || timestamp == "" || secret == "" {
return false
}
_, err := url.Parse(redirectURI)
if err != nil {
return false
}
requestSig, err := base64.URLEncoding.DecodeString(sigVal)
if err != nil {
return false
}
i, err := strconv.ParseInt(timestamp, 10, 64)
if err != nil {
return false
}
tm := time.Unix(i, 0)
ttl := 5 * time.Minute
if time.Now().Sub(tm) > ttl {
return false
}
localSig := redirectURLSignature(redirectURI, tm, secret)
return hmac.Equal(requestSig, localSig)
}
func redirectURLSignature(rawRedirect string, timestamp time.Time, secret string) []byte {
h := hmac.New(sha256.New, []byte(secret))
h.Write([]byte(rawRedirect))
h.Write([]byte(fmt.Sprint(timestamp.Unix())))
return h.Sum(nil)
}

View file

@ -0,0 +1,100 @@
package providers // import "github.com/pomerium/pomerium/internal/providers"
import (
"context"
"net/url"
"time"
oidc "github.com/pomerium/go-oidc"
"github.com/pomerium/pomerium/authenticate/circuit"
"github.com/pomerium/pomerium/internal/httputil"
"github.com/pomerium/pomerium/internal/log"
"github.com/pomerium/pomerium/internal/sessions"
"github.com/pomerium/pomerium/internal/version"
"golang.org/x/oauth2"
)
// GoogleProvider is an implementation of the Provider interface.
type GoogleProvider struct {
*ProviderData
cb *circuit.Breaker
// non-standard oidc fields
RevokeURL *url.URL
}
// NewGoogleProvider returns a new GoogleProvider and sets the provider url endpoints.
func NewGoogleProvider(p *ProviderData) (*GoogleProvider, error) {
ctx := context.Background()
provider, err := oidc.NewProvider(ctx, "https://accounts.google.com")
if err != nil {
return nil, err
}
p.verifier = provider.Verifier(&oidc.Config{ClientID: p.ClientID})
p.oauth = &oauth2.Config{
ClientID: p.ClientID,
ClientSecret: p.ClientSecret,
Endpoint: provider.Endpoint(),
RedirectURL: p.RedirectURL.String(),
Scopes: []string{oidc.ScopeOpenID, "profile", "email"},
}
googleProvider := &GoogleProvider{
ProviderData: p,
}
// google supports a revokation endpoint
var claims struct {
RevokeURL string `json:"revocation_endpoint"`
}
if err := provider.Claims(&claims); err != nil {
return nil, err
}
googleProvider.RevokeURL, err = url.Parse(claims.RevokeURL)
if err != nil {
return nil, err
}
googleProvider.cb = circuit.NewBreaker(&circuit.Options{
HalfOpenConcurrentRequests: 2,
OnStateChange: googleProvider.cbStateChange,
OnBackoff: googleProvider.cbBackoff,
ShouldTripFunc: func(c circuit.Counts) bool { return c.ConsecutiveFailures >= 3 },
ShouldResetFunc: func(c circuit.Counts) bool { return c.ConsecutiveSuccesses >= 6 },
BackoffDurationFunc: circuit.ExponentialBackoffDuration(
time.Duration(200)*time.Second,
time.Duration(500)*time.Millisecond),
})
return googleProvider, nil
}
func (p *GoogleProvider) cbBackoff(duration time.Duration, reset time.Time) {
log.Info().Dur("duration", duration).Msg("authenticate/providers/google.cbBackoff")
}
func (p *GoogleProvider) cbStateChange(from, to circuit.State) {
log.Info().Str("from", from.String()).Str("to", to.String()).Msg("authenticate/providers/google.cbStateChange")
}
// Revoke revokes the access token a given session state.
//
// https://developers.google.com/identity/protocols/OAuth2WebServer#tokenrevoke
// https://github.com/googleapis/google-api-dotnet-client/issues/1285
func (p *GoogleProvider) Revoke(s *sessions.SessionState) error {
params := url.Values{}
params.Add("token", s.AccessToken)
err := httputil.Client("POST", p.RevokeURL.String(), version.UserAgent(), params, nil)
if err != nil && err != httputil.ErrTokenRevoked {
return err
}
return nil
}
// GetSignInURL returns the sign in url with typical oauth parameters
// Google requires access type offline
func (p *GoogleProvider) GetSignInURL(state string) string {
return p.oauth.AuthCodeURL(state, oauth2.AccessTypeOffline, oauth2.ApprovalForce)
}

View file

@ -0,0 +1,32 @@
package providers // import "github.com/pomerium/pomerium/internal/providers"
import (
"context"
oidc "github.com/pomerium/go-oidc"
"golang.org/x/oauth2"
)
// OIDCProvider provides a standard, OpenID Connect implementation
// of an authorization identity provider.
type OIDCProvider struct {
*ProviderData
}
// NewOIDCProvider creates a new instance of an OpenID Connect provider.
func NewOIDCProvider(p *ProviderData) (*OIDCProvider, error) {
ctx := context.Background()
provider, err := oidc.NewProvider(ctx, "https://accounts.google.com")
if err != nil {
return nil, err
}
p.verifier = provider.Verifier(&oidc.Config{ClientID: p.ClientID})
p.oauth = &oauth2.Config{
ClientID: p.ClientID,
ClientSecret: p.ClientSecret,
Endpoint: provider.Endpoint(),
RedirectURL: p.RedirectURL.String(),
Scopes: []string{oidc.ScopeOpenID, "profile", "email"},
}
return &OIDCProvider{ProviderData: p}, nil
}

View file

@ -0,0 +1,69 @@
package providers // import "github.com/pomerium/pomerium/internal/providers"
import (
"context"
"net/url"
oidc "github.com/pomerium/go-oidc"
"github.com/pomerium/pomerium/internal/httputil"
"github.com/pomerium/pomerium/internal/sessions"
"github.com/pomerium/pomerium/internal/version"
"golang.org/x/oauth2"
)
// OktaProvider provides a standard, OpenID Connect implementation
// of an authorization identity provider.
type OktaProvider struct {
*ProviderData
// non-standard oidc fields
RevokeURL *url.URL
}
// NewOktaProvider creates a new instance of an OpenID Connect provider.
func NewOktaProvider(p *ProviderData) (*OktaProvider, error) {
ctx := context.Background()
provider, err := oidc.NewProvider(ctx, "https://dev-108295.oktapreview.com/oauth2/default")
if err != nil {
return nil, err
}
p.verifier = provider.Verifier(&oidc.Config{ClientID: p.ClientID})
p.oauth = &oauth2.Config{
ClientID: p.ClientID,
ClientSecret: p.ClientSecret,
Endpoint: provider.Endpoint(),
RedirectURL: p.RedirectURL.String(),
Scopes: []string{oidc.ScopeOpenID, "profile", "email"},
}
oktaProvider := OktaProvider{ProviderData: p}
// okta supports a revokation endpoint
var claims struct {
RevokeURL string `json:"revocation_endpoint"`
}
if err := provider.Claims(&claims); err != nil {
return nil, err
}
oktaProvider.RevokeURL, err = url.Parse(claims.RevokeURL)
if err != nil {
return nil, err
}
return &oktaProvider, nil
}
// Revoke revokes the access token a given session state.
// https://developer.okta.com/docs/api/resources/oidc#revoke
func (p *OktaProvider) Revoke(s *sessions.SessionState) error {
params := url.Values{}
params.Add("client_id", p.ClientID)
params.Add("client_secret", p.ClientSecret)
params.Add("token", s.IDToken)
params.Add("token_type_hint", "refresh_token")
err := httputil.Client("POST", p.RevokeURL.String(), version.UserAgent(), params, nil)
if err != nil && err != httputil.ErrTokenRevoked {
return err
}
return nil
}

View file

@ -0,0 +1,256 @@
package providers // import "github.com/pomerium/pomerium/internal/providers"
import (
"context"
"errors"
"fmt"
"net/url"
"time"
oidc "github.com/pomerium/go-oidc"
"github.com/pomerium/pomerium/internal/log"
"github.com/pomerium/pomerium/internal/sessions"
"golang.org/x/oauth2"
)
const (
// GoogleProviderName identifies the Google provider
GoogleProviderName = "google"
// OIDCProviderName identifes a generic OpenID connect provider
OIDCProviderName = "oidc"
// OktaProviderName identifes the Okta identity provider
OktaProviderName = "okta"
)
// Provider is an interface exposing functions necessary to authenticate with a given provider.
type Provider interface {
Data() *ProviderData
Redeem(string) (*sessions.SessionState, error)
ValidateSessionState(*sessions.SessionState) bool
GetSignInURL(state string) string
RefreshSessionIfNeeded(*sessions.SessionState) (bool, error)
Revoke(*sessions.SessionState) error
RefreshAccessToken(string) (string, time.Duration, error)
// Stop()
}
// New returns a new identity provider based on available name.
// Defaults to google.
func New(provider string, p *ProviderData) (Provider, error) {
switch provider {
case OIDCProviderName:
p, err := NewOIDCProvider(p)
if err != nil {
return nil, err
}
return p, nil
case OktaProviderName:
log.Info().Msg("Okta!")
p, err := NewOktaProvider(p)
if err != nil {
return nil, err
}
return p, nil
default:
p, err := NewGoogleProvider(p)
if err != nil {
return nil, err
}
return p, nil
}
}
// ProviderData holds the fields associated with providers
// necessary to implement the Provider interface.
type ProviderData struct {
RedirectURL *url.URL
ProviderName string
ClientID string
ClientSecret string
ProviderURL *url.URL
Scopes []string
ApprovalPrompt string
SessionLifetimeTTL time.Duration
verifier *oidc.IDTokenVerifier
oauth *oauth2.Config
}
// Data returns a ProviderData.
func (p *ProviderData) Data() *ProviderData { return p }
// GetSignInURL returns the sign in url with typical oauth parameters
func (p *ProviderData) GetSignInURL(state string) string {
return p.oauth.AuthCodeURL(state)
}
// ValidateSessionState validates a given session's from it's JWT token
// The function verifies it's been signed by the provider, preforms
// any additional checks depending on the Config, and returns the payload.
//
// ValidateSessionState does NOT do nonce validation.
func (p *ProviderData) ValidateSessionState(s *sessions.SessionState) bool {
ctx := context.Background()
_, err := p.verifier.Verify(ctx, s.IDToken)
if err != nil {
log.Error().Err(err).Msg("authenticate/providers.ValidateSessionState : failed to verify session state")
return false
}
return true
}
// Redeem creates a session with an identity provider from a authorization code
func (p *ProviderData) Redeem(code string) (*sessions.SessionState, error) {
ctx := context.Background()
// convert authorization code into a token
token, err := p.oauth.Exchange(ctx, code)
if err != nil {
log.Error().Err(err).Msg("authenticate/providers.Redeem : token exchange failed")
return nil, fmt.Errorf("token exchange: %v", err)
}
s, err := p.createSessionState(ctx, token)
if err != nil {
log.Error().Err(err).Msg("authenticate/providers.Redeem : unable to update session")
return nil, fmt.Errorf("unable to update session: %v", err)
}
return s, nil
}
// RefreshSessionIfNeeded will refresh the session state if it's deadline is expired
func (p *ProviderData) RefreshSessionIfNeeded(s *sessions.SessionState) (bool, error) {
if !sessionRefreshRequired(s) {
log.Info().Msg("authenticate/providers.RefreshSessionIfNeeded : session refresh not needed")
return false, nil
}
origExpiration := s.RefreshDeadline
err := p.redeemRefreshToken(s)
if err != nil {
log.Error().Err(err).Msg("authenticate/providers.RefreshSession")
return false, fmt.Errorf("unable to redeem refresh token: %v", err)
}
log.Info().Msgf("authenticate/providers.Redeem refreshed id token %s (expired on %s)", s, origExpiration)
return true, nil
}
func (p *ProviderData) redeemRefreshToken(s *sessions.SessionState) error {
log.Info().Msg("authenticate/providers.oidc.redeemRefreshToken 1")
ctx := context.Background()
t := &oauth2.Token{
RefreshToken: s.RefreshToken,
Expiry: time.Now().Add(-time.Hour),
}
log.Info().Msg("authenticate/providers.oidc.redeemRefreshToken 3")
// returns a TokenSource automatically refreshing it as necessary using the provided context
token, err := p.oauth.TokenSource(ctx, t).Token()
if err != nil {
log.Error().Err(err).Msg("authenticate/providers failed to get token")
return fmt.Errorf("failed to get token: %v", err)
}
log.Info().Msg("authenticate/providers.oidc.redeemRefreshToken 4")
newSession, err := p.createSessionState(ctx, token)
if err != nil {
log.Error().Err(err).Msg("authenticate/providers unable to update session")
return fmt.Errorf("unable to update session: %v", err)
}
s.AccessToken = newSession.AccessToken
s.IDToken = newSession.IDToken
s.RefreshToken = newSession.RefreshToken
s.RefreshDeadline = newSession.RefreshDeadline
s.Email = newSession.Email
log.Info().
Str("AccessToken", s.AccessToken).
Str("IdToken", s.IDToken).
Time("RefreshDeadline", s.RefreshDeadline).
Str("RefreshToken", s.RefreshToken).
Str("Email", s.Email).
Msg("authenticate/providers.redeemRefreshToken")
return nil
}
func (p *ProviderData) createSessionState(ctx context.Context, token *oauth2.Token) (*sessions.SessionState, error) {
rawIDToken, ok := token.Extra("id_token").(string)
if !ok {
return nil, fmt.Errorf("token response did not contain an id_token")
}
log.Info().
Bool("ctx", ctx == nil).
Bool("Verifier", p.verifier == nil).
Str("rawIDToken", rawIDToken).
Msg("authenticate/providers.oidc.createSessionState 2")
// Parse and verify ID Token payload.
idToken, err := p.verifier.Verify(ctx, rawIDToken)
if err != nil {
log.Error().Err(err).Msg("authenticate/providers could not verify id_token")
return nil, fmt.Errorf("could not verify id_token: %v", err)
}
// Extract custom claims.
var claims struct {
Email string `json:"email"`
Verified *bool `json:"email_verified"`
}
// parse claims from the raw, encoded jwt token
if err := idToken.Claims(&claims); err != nil {
return nil, fmt.Errorf("failed to parse id_token claims: %v", err)
}
if claims.Email == "" {
return nil, fmt.Errorf("id_token did not contain an email")
}
if claims.Verified != nil && !*claims.Verified {
return nil, fmt.Errorf("email in id_token (%s) isn't verified", claims.Email)
}
return &sessions.SessionState{
AccessToken: token.AccessToken,
IDToken: rawIDToken,
RefreshToken: token.RefreshToken,
RefreshDeadline: token.Expiry,
LifetimeDeadline: token.Expiry,
Email: claims.Email,
}, nil
}
// RefreshAccessToken allows the service to refresh an access token without
// prompting the user for permission.
func (p *ProviderData) RefreshAccessToken(refreshToken string) (string, time.Duration, error) {
if refreshToken == "" {
return "", 0, errors.New("missing refresh token")
}
ctx := context.Background()
c := oauth2.Config{
ClientID: p.ClientID,
ClientSecret: p.ClientSecret,
Endpoint: oauth2.Endpoint{TokenURL: p.ProviderURL.String()},
}
t := oauth2.Token{RefreshToken: refreshToken}
ts := c.TokenSource(ctx, &t)
log.Info().
Str("RefreshToken", refreshToken).
Msg("authenticate/providers.RefreshAccessToken")
newToken, err := ts.Token()
if err != nil {
log.Error().Err(err).Msg("authenticate/providers.RefreshAccessToken")
return "", 0, err
}
return newToken.AccessToken, newToken.Expiry.Sub(time.Now()), nil
}
// Revoke enables a user to revoke her tokenn. Though many providers such as
// google and okta provide revoke endpoints, since it's not officially supported
// as part of OpenID Connect, the default implementation throws an error.
func (p *ProviderData) Revoke(s *sessions.SessionState) error {
return errors.New("revoke not implemented")
}
func sessionRefreshRequired(s *sessions.SessionState) bool {
return s == nil || s.RefreshDeadline.After(time.Now()) || s.RefreshToken == ""
}

View file

@ -0,0 +1,142 @@
package providers // import "github.com/pomerium/pomerium/internal/providers"
import (
"errors"
"fmt"
"time"
"github.com/pomerium/pomerium/internal/sessions"
"github.com/pomerium/pomerium/internal/singleflight"
)
var (
_ Provider = &SingleFlightProvider{}
)
// ErrUnexpectedReturnType is an error for an unexpected return type
var (
ErrUnexpectedReturnType = errors.New("received unexpected return type from single flight func call")
)
// SingleFlightProvider middleware provider that multiple requests for the same object
// to be processed as a single request. This is often called request collpasing or coalesce.
// This middleware leverages the golang singlelflight provider, with modifications for metrics.
//
// It's common among HTTP reverse proxy cache servers such as nginx, Squid or Varnish - they all call it something else but works similarly.
//
// * https://www.varnish-cache.org/docs/3.0/tutorial/handling_misbehaving_servers.html
// * http://nginx.org/en/docs/http/ngx_http_proxy_module.html#proxy_cache_lock
// * http://wiki.squid-cache.org/Features/CollapsedForwarding
type SingleFlightProvider struct {
provider Provider
single *singleflight.Group
}
// NewSingleFlightProvider returns a new SingleFlightProvider
func NewSingleFlightProvider(provider Provider) *SingleFlightProvider {
return &SingleFlightProvider{
provider: provider,
single: &singleflight.Group{},
}
}
func (p *SingleFlightProvider) do(endpoint, key string, fn func() (interface{}, error)) (interface{}, error) {
compositeKey := fmt.Sprintf("%s/%s", endpoint, key)
resp, _, err := p.single.Do(compositeKey, fn)
return resp, err
}
// Data returns the provider data
func (p *SingleFlightProvider) Data() *ProviderData {
return p.provider.Data()
}
// Redeem wraps the provider's Redeem function.
func (p *SingleFlightProvider) Redeem(code string) (*sessions.SessionState, error) {
return p.provider.Redeem(code)
}
// ValidateSessionState wraps the provider's ValidateSessionState in a single flight call.
func (p *SingleFlightProvider) ValidateSessionState(s *sessions.SessionState) bool {
response, err := p.do("ValidateSessionState", s.AccessToken, func() (interface{}, error) {
valid := p.provider.ValidateSessionState(s)
return valid, nil
})
if err != nil {
return false
}
valid, ok := response.(bool)
if !ok {
return false
}
return valid
}
// GetSignInURL calls the provider's GetSignInURL function.
func (p *SingleFlightProvider) GetSignInURL(finalRedirect string) string {
return p.provider.GetSignInURL(finalRedirect)
}
// RefreshSessionIfNeeded wraps the provider's RefreshSessionIfNeeded function in a single flight
// call.
func (p *SingleFlightProvider) RefreshSessionIfNeeded(s *sessions.SessionState) (bool, error) {
response, err := p.do("RefreshSessionIfNeeded", s.RefreshToken, func() (interface{}, error) {
return p.provider.RefreshSessionIfNeeded(s)
})
if err != nil {
return false, err
}
r, ok := response.(bool)
if !ok {
return false, ErrUnexpectedReturnType
}
return r, nil
}
// Revoke wraps the provider's Revoke function in a single flight call.
func (p *SingleFlightProvider) Revoke(s *sessions.SessionState) error {
_, err := p.do("Revoke", s.AccessToken, func() (interface{}, error) {
err := p.provider.Revoke(s)
return nil, err
})
return err
}
// RefreshAccessToken wraps the provider's RefreshAccessToken function in a single flight call.
func (p *SingleFlightProvider) RefreshAccessToken(refreshToken string) (string, time.Duration, error) {
type Response struct {
AccessToken string
ExpiresIn time.Duration
}
response, err := p.do("RefreshAccessToken", refreshToken, func() (interface{}, error) {
accessToken, expiresIn, err := p.provider.RefreshAccessToken(refreshToken)
if err != nil {
return nil, err
}
return &Response{
AccessToken: accessToken,
ExpiresIn: expiresIn,
}, nil
})
if err != nil {
return "", 0, err
}
r, ok := response.(*Response)
if !ok {
return "", 0, ErrUnexpectedReturnType
}
return r.AccessToken, r.ExpiresIn, nil
}
// // Stop calls the provider's stop function
// func (p *SingleFlightProvider) Stop() {
// p.provider.Stop()
// }

View file

@ -0,0 +1,81 @@
package providers // import "github.com/pomerium/pomerium/internal/providers"
import (
"net/url"
"time"
"github.com/pomerium/pomerium/internal/sessions"
)
// TestProvider is a test implementation of the Provider interface.
type TestProvider struct {
*ProviderData
ValidToken bool
ValidGroup bool
SignInURL string
Refresh bool
RefreshFunc func(string) (string, time.Duration, error)
RefreshError error
Session *sessions.SessionState
RedeemError error
RevokeError error
Groups []string
GroupsError error
GroupsCall int
}
// NewTestProvider creates a new mock test provider.
func NewTestProvider(providerURL *url.URL) *TestProvider {
return &TestProvider{
ProviderData: &ProviderData{
ProviderName: "Test Provider",
ProviderURL: &url.URL{
Scheme: "http",
Host: providerURL.Host,
Path: "/authorize",
},
},
}
}
// ValidateSessionState returns the mock provider's ValidToken field value.
func (tp *TestProvider) ValidateSessionState(*sessions.SessionState) bool {
return tp.ValidToken
}
// GetSignInURL returns the mock provider's SignInURL field value.
func (tp *TestProvider) GetSignInURL(finalRedirect string) string {
return tp.SignInURL
}
// RefreshSessionIfNeeded returns the mock provider's Refresh value, or an error.
func (tp *TestProvider) RefreshSessionIfNeeded(*sessions.SessionState) (bool, error) {
return tp.Refresh, tp.RefreshError
}
// RefreshAccessToken returns the mock provider's refresh access token information
func (tp *TestProvider) RefreshAccessToken(s string) (string, time.Duration, error) {
return tp.RefreshFunc(s)
}
// Revoke returns nil
func (tp *TestProvider) Revoke(*sessions.SessionState) error {
return tp.RevokeError
}
// ValidateGroupMembership returns the mock provider's GroupsError if not nil, or the Groups field value.
func (tp *TestProvider) ValidateGroupMembership(string, []string) ([]string, error) {
return tp.Groups, tp.GroupsError
}
// Redeem returns the mock provider's Session and RedeemError field value.
func (tp *TestProvider) Redeem(code string) (*sessions.SessionState, error) {
return tp.Session, tp.RedeemError
}
// Stop fulfills the Provider interface
func (tp *TestProvider) Stop() {
return
}

12
authorize/README.md Normal file
View file

@ -0,0 +1,12 @@
# Authorize
## What's this package do?
The authorize packages makes a binary determination of access.
Authorization is on trust from:
- Device state (vulnerability scanned?, MDM?, BYOD? Encrypted?)
- User standing (HR status, Groups, etc)
- Context (time, location, role)
Driven by:
- Dynamic "policy as code", fine grained policy
- Machine Learning & anomaly detection based on multiple input sources

71
cmd/pomerium/main.go Normal file
View file

@ -0,0 +1,71 @@
package main
import (
"flag"
"fmt"
"net/http"
"os"
"github.com/rs/zerolog"
"github.com/pomerium/pomerium/authenticate"
"github.com/pomerium/pomerium/internal/https"
"github.com/pomerium/pomerium/internal/log"
"github.com/pomerium/pomerium/internal/options"
"github.com/pomerium/pomerium/internal/version"
"github.com/pomerium/pomerium/proxy"
)
var (
debugFlag = flag.Bool("debug", false, "run server in debug mode, changes log output to STDOUT and level to info")
versionFlag = flag.Bool("version", false, "prints the version")
)
func main() {
flag.Parse()
if *debugFlag {
log.Logger = log.Output(zerolog.ConsoleWriter{Out: os.Stdout})
}
if *versionFlag {
fmt.Printf("%s", version.FullVersion())
os.Exit(0)
}
log.Info().Str("version", version.FullVersion()).Str("user-agent", version.UserAgent()).Msg("cmd/pomerium")
authOpts, err := authenticate.OptionsFromEnvConfig()
if err != nil {
log.Fatal().Err(err).Msg("cmd/pomerium : failed to parse authenticator settings")
}
emailValidator := func(p *authenticate.Authenticator) error {
p.Validator = options.NewEmailValidator(authOpts.EmailDomains)
return nil
}
authenticator, err := authenticate.NewAuthenticator(authOpts, emailValidator)
if err != nil {
log.Fatal().Err(err).Msg("cmd/pomerium : failed to create authenticator")
}
proxyOpts, err := proxy.OptionsFromEnvConfig()
if err != nil {
log.Fatal().Err(err).Msg("cmd/pomerium : failed to parse proxy settings")
}
validator := func(p *proxy.Proxy) error {
p.EmailValidator = options.NewEmailValidator(proxyOpts.EmailDomains)
return nil
}
p, err := proxy.NewProxy(proxyOpts, validator)
if err != nil {
log.Fatal().Err(err).Msg("cmd/pomerium : failed to create proxy")
}
// proxyHandler := log.NewLoggingHandler(p.Handler())
authHandler := http.TimeoutHandler(authenticator.Handler(), authOpts.RequestTimeout, "")
topMux := http.NewServeMux()
topMux.Handle(authOpts.Host+"/", authHandler)
topMux.Handle("/", p.Handler())
log.Fatal().Err(https.ListenAndServeTLS(nil, topMux))
}

BIN
docs/logo.png Normal file

Binary file not shown.

After

Width:  |  Height:  |  Size: 1.8 KiB

26
env.example Normal file
View file

@ -0,0 +1,26 @@
#!/bin/bash
export HOST="sso-auth.corp.beyondperimeter.com"
export REDIRECT_URL="https://sso-auth.corp.beyondperimeter.com/oauth2/callback"
export PROXY_ROOT_DOMAIN=beyondperimeter.com
export PROXY_CLIENT_ID=WLgwUNIJW6DtsnAM2ck510znU2T3l+WufPg67e50oVM=
export PROXY_CLIENT_SECRET=gFB0qsSxxPqCtoNMuF7Q1VupJSNEq0BguxlUfT0PE+Y=
# Generate 256 bitrandom key to encrypt the cookie `head -c32 /dev/urandom | base64`
export AUTH_CODE_SECRET=9wiTZq4qvmS/plYQyvzGKWPlH/UBy0DMYMA2x/zngrM=
export COOKIE_SECRET=uPGHo1ujND/k3B9V6yr52Gweq3RRYfFho98jxDG5Br8=
export COOKIE_SECURE=true
# Valid email domains
export EMAIL_DOMAIN=*
export SSO_EMAIL_DOMAIN=*
# IdP configuration
export IDP_PROVIDER="google"
export IDP_PROVIDER_URL="https://sso-auth.corp.beyondperimeter.com"
export IDP_CLIENT_ID="xxx.apps.googleusercontent.com"
export IDP_CLIENT_SECRET="xxx"
export IDP_REDIRECT_URL="https://sso-auth.corp.beyondperimeter.com/oauth2/callback"
# proxy'd routes
export ROUTES='news.corp.beyondperimeter.com':'news.ycombinator.com','github.corp.beyondperimeter.com':'github.com'

19
go.mod Normal file
View file

@ -0,0 +1,19 @@
module github.com/pomerium/pomerium
require (
github.com/benbjohnson/clock v0.0.0-20161215174838-7dc76406b6d3
github.com/davecgh/go-spew v1.1.1 // indirect
github.com/miscreant/miscreant-go v0.0.0-20181010193435-325cbd69228b
github.com/pmezard/go-difflib v1.0.0 // indirect
github.com/pomerium/envconfig v1.3.1-0.20180517194557-dd1402a4d99d
github.com/pomerium/go-oidc v2.0.0+incompatible
github.com/pquerna/cachecontrol v0.0.0-20180517163645-1555304b9b35 // indirect
github.com/rs/zerolog v1.11.0
github.com/stretchr/testify v1.2.2 // indirect
golang.org/x/crypto v0.0.0-20181203042331-505ab145d0a9
golang.org/x/net v0.0.0-20181220203305-927f97764cc3 // indirect
golang.org/x/oauth2 v0.0.0-20181203162652-d668ce993890
golang.org/x/sync v0.0.0-20181221193216-37e7f081c4d4 // indirect
google.golang.org/appengine v1.4.0 // indirect
gopkg.in/square/go-jose.v2 v2.2.1 // indirect
)

34
go.sum Normal file
View file

@ -0,0 +1,34 @@
github.com/benbjohnson/clock v0.0.0-20161215174838-7dc76406b6d3 h1:wOysYcIdqv3WnvwqFFzrYCFALPED7qkUGaLXu359GSc=
github.com/benbjohnson/clock v0.0.0-20161215174838-7dc76406b6d3/go.mod h1:UMqtWQTnOe4byzwe7Zhwh8f8s+36uszN51sJrSIZlTE=
github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c=
github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
github.com/golang/protobuf v1.2.0 h1:P3YflyNX/ehuJFLhxviNdFxQPkGK5cDcApsge1SqnvM=
github.com/golang/protobuf v1.2.0/go.mod h1:6lQm79b+lXiMfvg/cZm0SGofjICqVBUtrP5yJMmIC1U=
github.com/miscreant/miscreant-go v0.0.0-20181010193435-325cbd69228b h1:VPhrxAgvd0d0xSZP4P8zzWZntso2NE0m69dmDEXM53I=
github.com/miscreant/miscreant-go v0.0.0-20181010193435-325cbd69228b/go.mod h1:Vj6lPE3LxPymcFxg7hm9aDIJWCyhJMnxSNC/y9ZHtN8=
github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM=
github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4=
github.com/pomerium/envconfig v1.3.1-0.20180517194557-dd1402a4d99d h1:iAavGhspmmDa8LCcQLpHV/JBPWKL7EfvMN7YrGHBN3o=
github.com/pomerium/envconfig v1.3.1-0.20180517194557-dd1402a4d99d/go.mod h1:1Kz8Ca8PhJDtLYqgvbDZGn6GsJCvrT52SxQ3sPNJkDc=
github.com/pomerium/go-oidc v2.0.0+incompatible h1:gVvG/ExWsHQqatV+uceROnGmbVYF44mDNx5nayBhC0o=
github.com/pomerium/go-oidc v2.0.0+incompatible/go.mod h1:DRsGVw6MOgxbfq4Y57jKOE8lbEfayxeiY0A8/4vxjBM=
github.com/pquerna/cachecontrol v0.0.0-20180517163645-1555304b9b35 h1:J9b7z+QKAmPf4YLrFg6oQUotqHQeUNWwkvo7jZp1GLU=
github.com/pquerna/cachecontrol v0.0.0-20180517163645-1555304b9b35/go.mod h1:prYjPmNq4d1NPVmpShWobRqXY3q7Vp+80DqgxxUrUIA=
github.com/rs/zerolog v1.11.0 h1:DRuq/S+4k52uJzBQciUcofXx45GrMC6yrEbb/CoK6+M=
github.com/rs/zerolog v1.11.0/go.mod h1:YbFCdg8HfsridGWAh22vktObvhZbQsZXe4/zB0OKkWU=
github.com/stretchr/testify v1.2.2 h1:bSDNvY7ZPG5RlJ8otE/7V6gMiyenm9RtJ7IUVIAoJ1w=
github.com/stretchr/testify v1.2.2/go.mod h1:a8OnRcib4nhh0OaRAV+Yts87kKdq0PP7pXfy6kDkUVs=
golang.org/x/crypto v0.0.0-20181203042331-505ab145d0a9 h1:mKdxBk7AujPs8kU4m80U72y/zjbZ3UcXC7dClwKbUI0=
golang.org/x/crypto v0.0.0-20181203042331-505ab145d0a9/go.mod h1:6SG95UA2DQfeDnfUPMdvaQW0Q7yPrPDi9nlGo2tz2b4=
golang.org/x/net v0.0.0-20180724234803-3673e40ba225/go.mod h1:mL1N/T3taQHkDXs73rZJwtUhF3w3ftmwwsq0BUmARs4=
golang.org/x/net v0.0.0-20181220203305-927f97764cc3 h1:eH6Eip3UpmR+yM/qI9Ijluzb1bNv/cAU/n+6l8tRSis=
golang.org/x/net v0.0.0-20181220203305-927f97764cc3/go.mod h1:mL1N/T3taQHkDXs73rZJwtUhF3w3ftmwwsq0BUmARs4=
golang.org/x/oauth2 v0.0.0-20181203162652-d668ce993890 h1:uESlIz09WIHT2I+pasSXcpLYqYK8wHcdCetU3VuMBJE=
golang.org/x/oauth2 v0.0.0-20181203162652-d668ce993890/go.mod h1:N/0e6XlmueqKjAGxoOufVs8QHGRruUQn6yWY3a++T0U=
golang.org/x/sync v0.0.0-20181221193216-37e7f081c4d4 h1:YUO/7uOKsKeq9UokNS62b8FYywz3ker1l1vDZRCRefw=
golang.org/x/sync v0.0.0-20181221193216-37e7f081c4d4/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM=
golang.org/x/text v0.3.0/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ=
google.golang.org/appengine v1.4.0 h1:/wp5JvzpHIxhs/dumFmF7BXTf3Z+dd4uXta4kVyO508=
google.golang.org/appengine v1.4.0/go.mod h1:xpcJRLb0r/rnEns0DIKYYv+WjYCduHsrkT7/EB5XEv4=
gopkg.in/square/go-jose.v2 v2.2.1 h1:uRIz/V7RfMsMgGnCp+YybIdstDIz8wc0H283wHQfwic=
gopkg.in/square/go-jose.v2 v2.2.1/go.mod h1:M9dMgbHiYLoDGQrXy7OpJDJWiKiU//h+vD76mk0e1AI=

172
internal/aead/aead.go Normal file
View file

@ -0,0 +1,172 @@
package aead // import "github.com/pomerium/pomerium/internal/aead"
import (
"bytes"
"compress/gzip"
"crypto/cipher"
"encoding/base64"
"encoding/json"
"fmt"
"io"
"sync"
miscreant "github.com/miscreant/miscreant-go"
)
const miscreantNonceSize = 16
var algorithmType = "AES-CMAC-SIV"
// Cipher provides methods to encrypt and decrypt values.
type Cipher interface {
Encrypt([]byte) ([]byte, error)
Decrypt([]byte) ([]byte, error)
Marshal(interface{}) (string, error)
Unmarshal(string, interface{}) error
}
// MiscreantCipher provides methods to encrypt and decrypt values.
// Using an AEAD is a cipher providing authenticated encryption with associated data.
// For a description of the methodology, see https://en.wikipedia.org/wiki/Authenticated_encryption
type MiscreantCipher struct {
aead cipher.AEAD
mux sync.Mutex
}
// NewMiscreantCipher returns a new AES Cipher for encrypting values
func NewMiscreantCipher(secret []byte) (*MiscreantCipher, error) {
aead, err := miscreant.NewAEAD(algorithmType, secret, miscreantNonceSize)
if err != nil {
return nil, err
}
return &MiscreantCipher{
aead: aead,
}, nil
}
// GenerateKey wraps miscreant's GenerateKey function
func GenerateKey() []byte {
return miscreant.GenerateKey(32)
}
// Encrypt a value using AES-CMAC-SIV
func (c *MiscreantCipher) Encrypt(plaintext []byte) (joined []byte, err error) {
c.mux.Lock()
defer c.mux.Unlock()
defer func() {
if r := recover(); r != nil {
err = fmt.Errorf("miscreant error encrypting bytes: %v", r)
}
}()
nonce := miscreant.GenerateNonce(c.aead)
ciphertext := c.aead.Seal(nil, nonce, plaintext, nil)
// we return the nonce as part of the returned value
joined = append(ciphertext[:], nonce[:]...)
return joined, nil
}
// Decrypt a value using AES-CMAC-SIV
func (c *MiscreantCipher) Decrypt(joined []byte) ([]byte, error) {
c.mux.Lock()
defer c.mux.Unlock()
if len(joined) <= miscreantNonceSize {
return nil, fmt.Errorf("invalid input size: %d", len(joined))
}
// grab out the nonce
pivot := len(joined) - miscreantNonceSize
ciphertext := joined[:pivot]
nonce := joined[pivot:]
plaintext, err := c.aead.Open(nil, nonce, ciphertext, nil)
if err != nil {
return nil, err
}
return plaintext, nil
}
// Marshal marshals the interface state as JSON, encrypts the JSON using the cipher
// and base64 encodes the binary value as a string and returns the result
func (c *MiscreantCipher) Marshal(s interface{}) (string, error) {
// encode json value
plaintext, err := json.Marshal(s)
if err != nil {
return "", err
}
compressed, err := compress(plaintext)
if err != nil {
return "", err
}
// encrypt the JSON
ciphertext, err := c.Encrypt(compressed)
if err != nil {
return "", err
}
// base64-encode the result
encoded := base64.RawURLEncoding.EncodeToString(ciphertext)
return encoded, nil
}
// Unmarshal takes the marshaled string, base64-decodes into a byte slice, decrypts the
// byte slice the pased cipher, and unmarshals the resulting JSON into the struct pointer passed
func (c *MiscreantCipher) Unmarshal(value string, s interface{}) error {
// convert base64 string value to bytes
ciphertext, err := base64.RawURLEncoding.DecodeString(value)
if err != nil {
return err
}
// decrypt the bytes
compressed, err := c.Decrypt(ciphertext)
if err != nil {
return err
}
// decompress
plaintext, err := decompress(compressed)
if err != nil {
return err
}
// unmarshal bytes
err = json.Unmarshal(plaintext, s)
if err != nil {
return err
}
return nil
}
func compress(data []byte) ([]byte, error) {
var buf bytes.Buffer
writer, err := gzip.NewWriterLevel(&buf, gzip.DefaultCompression)
if err != nil {
return nil, fmt.Errorf("aead/compress: failed to create a gzip writer: %q", err)
}
if writer == nil {
return nil, fmt.Errorf("aead/compress: failed to create a gzip writer")
}
if _, err = writer.Write(data); err != nil {
return nil, fmt.Errorf("aead/compress: failed to compress data with err: %q", err)
}
if err = writer.Close(); err != nil {
return nil, err
}
return buf.Bytes(), nil
}
func decompress(data []byte) ([]byte, error) {
reader, err := gzip.NewReader(bytes.NewReader(data))
if err != nil {
return nil, fmt.Errorf("aead/compress: failed to create a gzip reader: %q", err)
}
defer reader.Close()
var buf bytes.Buffer
if _, err = io.Copy(&buf, reader); err != nil {
return nil, err
}
return buf.Bytes(), nil
}

173
internal/aead/aead_test.go Normal file
View file

@ -0,0 +1,173 @@
package aead // import "github.com/pomerium/pomerium/internal/aead"
import (
"crypto/rand"
"crypto/sha1"
"fmt"
"reflect"
"sync"
"testing"
)
func TestEncodeAndDecodeAccessToken(t *testing.T) {
plaintext := []byte("my plain text value")
key := GenerateKey()
c, err := NewMiscreantCipher([]byte(key))
if err != nil {
t.Fatalf("unexpected err: %v", err)
}
ciphertext, err := c.Encrypt(plaintext)
if err != nil {
t.Fatalf("unexpected err: %v", err)
}
if reflect.DeepEqual(plaintext, ciphertext) {
t.Fatalf("plaintext is not encrypted plaintext:%v ciphertext:%x", plaintext, ciphertext)
}
got, err := c.Decrypt(ciphertext)
if err != nil {
t.Fatalf("unexpected err decrypting: %v", err)
}
if !reflect.DeepEqual(got, plaintext) {
t.Logf(" got: %v", got)
t.Logf("want: %v", plaintext)
t.Fatal("got unexpected decrypted value")
}
}
func TestMarshalAndUnmarshalStruct(t *testing.T) {
key := GenerateKey()
c, err := NewMiscreantCipher([]byte(key))
if err != nil {
t.Fatalf("unexpected err: %v", err)
}
type TC struct {
Field string `json:"field"`
}
tc := &TC{
Field: "my plain text value",
}
value1, err := c.Marshal(tc)
if err != nil {
t.Fatalf("unexpected err: %v", err)
}
value2, err := c.Marshal(tc)
if err != nil {
t.Fatalf("unexpected err: %v", err)
}
if value1 == value2 {
t.Fatalf("expected marshaled values to not be equal %v != %v", value1, value2)
}
got1 := &TC{}
err = c.Unmarshal(value1, got1)
if err != nil {
t.Fatalf("unexpected err unmarshalling struct: %v", err)
}
if !reflect.DeepEqual(got1, tc) {
t.Logf("want: %#v", tc)
t.Logf(" got: %#v", got1)
t.Fatalf("expected structs to be equal")
}
got2 := &TC{}
err = c.Unmarshal(value2, got2)
if err != nil {
t.Fatalf("unexpected err unmarshalling struct: %v", err)
}
if !reflect.DeepEqual(got1, got2) {
t.Logf("got2: %#v", got2)
t.Logf("got1: %#v", got1)
t.Fatalf("expected structs to be equal")
}
}
// TestCipherDataRace exercises a simple concurrency test for the MicscreantCipher.
// In https://github.com/pomerium/pomerium/pull/75 we investigated why, on random occasion,
// unmarshalling session states would fail, triggering users to get kicked out of
// authenticated states. We narrowed our investigation to a data race we uncovered
// from our misuse of the underlying miscreant library which makes no attempt
// at thread-safety.
//
// In https://github.com/pomerium/pomerium/pull/77 we added this test to exercise the
// data race condition and resolved said race by introducing a simple mutex.
func TestCipherDataRace(t *testing.T) {
miscreantCipher, err := NewMiscreantCipher(GenerateKey())
if err != nil {
t.Fatalf("unexpected generating cipher err: %v", err)
}
type TC struct {
Field string `json:"field"`
}
wg := &sync.WaitGroup{}
for i := 0; i < 100; i++ {
wg.Add(1)
go func(c *MiscreantCipher, wg *sync.WaitGroup) {
defer wg.Done()
b := make([]byte, 32)
_, err := rand.Read(b)
if err != nil {
t.Fatalf("unexecpted error reading random bytes: %v", err)
}
sha := fmt.Sprintf("%x", sha1.New().Sum(b))
tc := &TC{
Field: sha,
}
value1, err := c.Marshal(tc)
if err != nil {
t.Fatalf("unexpected err: %v", err)
}
value2, err := c.Marshal(tc)
if err != nil {
t.Fatalf("unexpected err: %v", err)
}
if value1 == value2 {
t.Fatalf("expected marshaled values to not be equal %v != %v", value1, value2)
}
got1 := &TC{}
err = c.Unmarshal(value1, got1)
if err != nil {
t.Fatalf("unexpected err unmarshalling struct: %v", err)
}
if !reflect.DeepEqual(got1, tc) {
t.Logf("want: %#v", tc)
t.Logf(" got: %#v", got1)
t.Fatalf("expected structs to be equal")
}
got2 := &TC{}
err = c.Unmarshal(value2, got2)
if err != nil {
t.Fatalf("unexpected err unmarshalling struct: %v", err)
}
if !reflect.DeepEqual(got1, got2) {
t.Logf("got2: %#v", got2)
t.Logf("got1: %#v", got1)
t.Fatalf("expected structs to be equal")
}
}(miscreantCipher, wg)
}
wg.Wait()
}

View file

@ -0,0 +1,34 @@
package aead // import "github.com/pomerium/pomerium/internal/aead"
import (
"encoding/json"
)
// MockCipher is a mock of the cipher interface
type MockCipher struct {
MarshalError error
MarshalString string
UnmarshalError error
UnmarshalBytes []byte
}
// Encrypt returns an empty byte array and nil
func (mc *MockCipher) Encrypt([]byte) ([]byte, error) {
return []byte{}, nil
}
// Decrypt returns an empty byte array and nil
func (mc *MockCipher) Decrypt([]byte) ([]byte, error) {
return []byte{}, nil
}
// Marshal returns the marshal string and marsha error
func (mc *MockCipher) Marshal(interface{}) (string, error) {
return mc.MarshalString, mc.MarshalError
}
// Unmarshal unmarshals the unmarshal bytes to be set in s and returns the unmarshal error
func (mc *MockCipher) Unmarshal(b string, s interface{}) error {
json.Unmarshal(mc.UnmarshalBytes, s)
return mc.UnmarshalError
}

View file

@ -0,0 +1,36 @@
## Generating random seeds
In order of preference:
- `head -c32 /dev/urandom | base64`
- `openssl rand -base64 32 | head -c 32 | base64`
## Encrypting data
TL;DR -- Nonce reuse is a problem. AEAD isn't a clear choice right now.
[Miscreant](https://github.com/miscreant/miscreant.go)
+ AES-GCM-SIV seems to have ideal properties
+ random nonces
- ~30% slower encryption
- [not maintained by a BigCo](https://github.com/miscreant/miscreant.go/graphs/contributors)
[nacl/secretbot](https://godoc.org/golang.org/x/crypto/nacl/secretbox)
+ Fast
+ XSalsa20 wutg Poly1305 MAC provides encryption and authentication together
+ A newer standard and may not be considered acceptable in environments that require high levels of review.
-/+ maintained as an [/x/ package](https://godoc.org/golang.org/x/crypto/nacl/secretbox)
- doesn't use the underlying cipher.AEAD api.
GCM with random nonces
+ Fastest
+ Go standard library, supported by google $
- Easy to get wrong
- IV reuse is a known weakness so keys must be rotated before birthday attack. [NIST SP 800-38D](http://csrc.nist.gov/publications/nistpubs/800-38D/SP-800-38D.pdf) recommends using the same key with random 96-bit nonces (the default nonce length) no more than 2^32 times
Further reading on tradeoffs:
- [Introducing Miscreant](https://tonyarcieri.com/introducing-miscreant-a-multi-language-misuse-resistant-encryption-library)
- [agl's post AES-GCM-SIV](https://www.imperialviolet.org/2017/05/14/aesgcmsiv.html)
- [x/crypto: add chacha20, xchacha20](https://github.com/golang/go/issues/24485s)
- [GCM cannot be used with random nonces](https://github.com/gtank/cryptopasta/issues/14s)
- [proposal: x/crypto/chacha20poly1305: add support for XChaCha20](https://github.com/golang/go/issues/23885)
- [kubernetes](https://kubernetes.io/docs/tasks/administer-cluster/encrypt-data/#providers)

View file

@ -0,0 +1,30 @@
package cryptutil // import "github.com/pomerium/pomerium/internal/cryptutil"
import (
"crypto/hmac"
"crypto/sha512"
"golang.org/x/crypto/bcrypt"
)
// Hash generates a hash of data using HMAC-SHA-512/256. The tag is intended to
// be a natural-language string describing the purpose of the hash, such as
// "hash file for lookup key" or "master secret to client secret". It serves
// as an HMAC "key" and ensures that different purposes will have different
// hash output. This function is NOT suitable for hashing passwords.
func Hash(tag string, data []byte) []byte {
h := hmac.New(sha512.New512_256, []byte(tag))
h.Write(data)
return h.Sum(nil)
}
// HashPassword generates a bcrypt hash of the password using work factor 14.
func HashPassword(password []byte) ([]byte, error) {
return bcrypt.GenerateFromPassword(password, 14)
}
// CheckPasswordHash securely compares a bcrypt hashed password with its possible
// plaintext equivalent. Returns nil on success, or an error on failure.
func CheckPasswordHash(hash, password []byte) error {
return bcrypt.CompareHashAndPassword(hash, password)
}

View file

@ -0,0 +1,80 @@
package cryptutil // import "github.com/pomerium/pomerium/internal/cryptutil"
import (
"crypto/sha256"
"crypto/sha512"
"encoding/hex"
"fmt"
"io/ioutil"
"os"
"testing"
)
func TestPasswordHashing(t *testing.T) {
bcryptTests := []struct {
plaintext []byte
hash []byte
}{
{
plaintext: []byte("password"),
hash: []byte("$2a$14$uALAQb/Lwl59oHVbuUa5m.xEFmQBc9ME/IiSgJK/VHtNJJXASCDoS"),
},
}
for _, tt := range bcryptTests {
hashed, err := HashPassword(tt.plaintext)
if err != nil {
t.Error(err)
}
if err = CheckPasswordHash(hashed, tt.plaintext); err != nil {
t.Error(err)
}
}
}
// Benchmarks SHA256 on 16K of random data.
func BenchmarkSHA256(b *testing.B) {
data, err := ioutil.ReadFile("testdata/random")
if err != nil {
b.Fatal(err)
}
b.SetBytes(int64(len(data)))
for i := 0; i < b.N; i++ {
_ = sha256.Sum256(data)
}
}
// Benchmarks SHA512/256 on 16K of random data.
func BenchmarkSHA512_256(b *testing.B) {
data, err := ioutil.ReadFile("testdata/random")
if err != nil {
b.Fatal(err)
}
b.SetBytes(int64(len(data)))
for i := 0; i < b.N; i++ {
_ = sha512.Sum512_256(data)
}
}
func BenchmarkBcrypt(b *testing.B) {
for i := 0; i < b.N; i++ {
_, err := HashPassword([]byte("thisisareallybadpassword"))
if err != nil {
b.Error(err)
break
}
}
}
func ExampleHash() {
tag := "hashing file for lookup key"
contents, err := ioutil.ReadFile("testdata/random")
if err != nil {
fmt.Printf("could not read file: %v\n", err)
os.Exit(1)
}
digest := Hash(tag, contents)
fmt.Println(hex.EncodeToString(digest))
// Output: 9f4c795d8ae5c207f19184ccebee6a606c1fdfe509c793614066d613580f03e1
}

BIN
internal/cryptutil/testdata/random vendored Normal file

Binary file not shown.

View file

@ -0,0 +1,32 @@
package fileutil // import "github.com/pomerium/pomerium/internal/fileutil"
import (
"errors"
"os"
)
// IsReadableFile reports whether the file exists and is readable.
// If the error is non-nil, it means there might be a file or directory
// with that name but we cannot read it.
//
// Adapted from the upspin.io source code.
func IsReadableFile(path string) (bool, error) {
// Is it stattable and is it a plain file?
info, err := os.Stat(path)
if err != nil {
if os.IsNotExist(err) {
return false, nil // Item does not exist.
}
return false, err // Item is problematic.
}
if info.IsDir() {
return false, errors.New("is directory")
}
// Is it readable?
fd, err := os.Open(path)
if err != nil {
return false, errors.New("permission denied")
}
fd.Close()
return true, nil // Item exists and is readable.
}

View file

@ -0,0 +1,29 @@
package fileutil // import "github.com/pomerium/pomerium/internal/fileutil"
import "testing"
func TestIsReadableFile(t *testing.T) {
tests := []struct {
name string
args string
want bool
wantErr bool
}{
{"good file", "fileutil.go", true, false},
{"file doesn't exist", "file-no-exist/nope", false, false},
{"can't read dir", "./", false, true},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
got, err := IsReadableFile(tt.args)
if (err != nil) != tt.wantErr {
t.Errorf("IsReadableFile() error = %+v, wantErr %v", err, tt.wantErr)
return
}
if got != tt.want {
t.Errorf("IsReadableFile() = %v, want %v", got, tt.want)
}
})
}
}

125
internal/https/https.go Normal file
View file

@ -0,0 +1,125 @@
package https // import "github.com/pomerium/pomerium/internal/https"
import (
"crypto/tls"
"fmt"
"net"
"net/http"
"os"
"path/filepath"
"time"
"github.com/pomerium/pomerium/internal/fileutil"
)
// Options contains the configurations settings for a TLS http server
type Options struct {
// Addr specifies the host and port on which the server should serve
// HTTPS requests. If empty, ":https" is used.
Addr string
// CertFile and KeyFile specifies the TLS certificates to use.
CertFile string
KeyFile string
}
var defaultOptions = &Options{
Addr: ":https",
CertFile: filepath.Join(findKeyDir(), "cert.pem"),
KeyFile: filepath.Join(findKeyDir(), "privkey.pem"),
}
func findKeyDir() string {
p, err := os.Getwd()
if err != nil {
return "."
}
return p
}
func (opt *Options) applyDefaults() {
if opt.Addr == "" {
opt.Addr = defaultOptions.Addr
}
if opt.CertFile == "" {
opt.CertFile = defaultOptions.CertFile
}
if opt.KeyFile == "" {
opt.KeyFile = defaultOptions.KeyFile
}
}
// ListenAndServeTLS serves the provided handlers by HTTPS
// using the provided options.
func ListenAndServeTLS(opt *Options, handler http.Handler) error {
if opt == nil {
opt = defaultOptions
} else {
opt.applyDefaults()
}
config, err := newDefaultTLSConfig(opt.CertFile, opt.KeyFile)
if err != nil {
return fmt.Errorf("https: setting up TLS config: %v", err)
}
ln, err := net.Listen("tcp", opt.Addr)
if err != nil {
return err
}
ln = tls.NewListener(ln, config)
// Set up the main server.
server := &http.Server{
ReadHeaderTimeout: 5 * time.Second,
ReadTimeout: 15 * time.Second,
// WriteTimeout is set to 0 because it also pertains to
// streaming replies, e.g., the DirServer.Watch interface.
WriteTimeout: 0,
IdleTimeout: 60 * time.Second,
TLSConfig: config,
Handler: handler,
}
return server.Serve(ln)
}
// newDefaultTLSConfig creates a new TLS config based on the certificate files given.
func newDefaultTLSConfig(certFile string, certKeyFile string) (*tls.Config, error) {
certReadable, err := fileutil.IsReadableFile(certFile)
if err != nil {
return nil, fmt.Errorf("TLS certificate in %q: %q", certFile, err)
}
if !certReadable {
return nil, fmt.Errorf("certificate file %q not readable", certFile)
}
keyReadable, err := fileutil.IsReadableFile(certKeyFile)
if err != nil {
return nil, fmt.Errorf("TLS key in %q: %v", certKeyFile, err)
}
if !keyReadable {
return nil, fmt.Errorf("certificate key file %q not readable", certKeyFile)
}
cert, err := tls.LoadX509KeyPair(certFile, certKeyFile)
if err != nil {
return nil, err
}
tlsConfig := &tls.Config{
CipherSuites: []uint16{
tls.TLS_ECDHE_ECDSA_WITH_AES_256_GCM_SHA384,
tls.TLS_ECDHE_ECDSA_WITH_AES_128_GCM_SHA256,
tls.TLS_ECDHE_ECDSA_WITH_CHACHA20_POLY1305,
tls.TLS_ECDHE_RSA_WITH_AES_256_GCM_SHA384,
tls.TLS_ECDHE_RSA_WITH_AES_128_GCM_SHA256,
tls.TLS_ECDHE_RSA_WITH_CHACHA20_POLY1305,
},
MinVersion: tls.VersionTLS12,
PreferServerCipherSuites: true,
Certificates: []tls.Certificate{cert},
}
tlsConfig.BuildNameToCertificate()
return tlsConfig, nil
}

View file

@ -0,0 +1,87 @@
package httputil // import "github.com/pomerium/pomerium/internal/httputil"
import (
"bytes"
"encoding/json"
"errors"
"fmt"
"io"
"io/ioutil"
"net"
"net/http"
"net/url"
"time"
)
// ErrTokenRevoked signifies a token revokation or expiration error
var ErrTokenRevoked = errors.New("Token expired or revoked")
var httpClient = &http.Client{
Timeout: time.Second * 5,
Transport: &http.Transport{
Dial: (&net.Dialer{
Timeout: 2 * time.Second,
}).Dial,
TLSHandshakeTimeout: 2 * time.Second,
},
}
// Client provides a simple helper interface to make HTTP requests
func Client(method, endpoint, userAgent string, params url.Values, response interface{}) error {
var body io.Reader
switch method {
case "POST":
body = bytes.NewBufferString(params.Encode())
case "GET":
// error checking skipped because we are just parsing in
// order to make a copy of an existing URL
u, _ := url.Parse(endpoint)
u.RawQuery = params.Encode()
endpoint = u.String()
default:
return fmt.Errorf(http.StatusText(http.StatusBadRequest))
}
req, err := http.NewRequest(method, endpoint, body)
if err != nil {
return err
}
req.Header.Set("Content-Type", "application/x-www-form-urlencoded")
req.Header.Set("User-Agent", userAgent)
resp, err := httpClient.Do(req)
if err != nil {
return err
}
var respBody []byte
respBody, err = ioutil.ReadAll(resp.Body)
resp.Body.Close()
if err != nil {
return err
}
if resp.StatusCode != http.StatusOK {
switch resp.StatusCode {
case http.StatusBadRequest:
var response struct {
Error string `json:"error"`
ErrorDescription string `json:"error_description"`
}
e := json.Unmarshal(respBody, &response)
if e == nil && response.ErrorDescription == "Token expired or revoked" {
return ErrTokenRevoked
}
return fmt.Errorf(http.StatusText(http.StatusBadRequest))
default:
return fmt.Errorf(http.StatusText(resp.StatusCode))
}
}
if response != nil {
err := json.Unmarshal(respBody, &response)
if err != nil {
return err
}
}
return nil
}

View file

@ -0,0 +1,82 @@
package httputil // import "github.com/pomerium/pomerium/internal/httputil"
import (
"encoding/json"
"errors"
"fmt"
"io"
"net/http"
"github.com/pomerium/pomerium/internal/log"
"github.com/pomerium/pomerium/internal/templates"
"github.com/pomerium/pomerium/internal/version"
)
var (
// ErrUserNotAuthorized is an error for unauthorized users.
ErrUserNotAuthorized = errors.New("user not authorized")
)
// HTTPError stores the status code and a message for a given HTTP error.
type HTTPError struct {
Code int
Message string
}
// Error fulfills the error interface, returning a string representation of the error.
func (h HTTPError) Error() string {
return fmt.Sprintf("%d %s: %s", h.Code, http.StatusText(h.Code), h.Message)
}
// CodeForError maps an error type and returns a corresponding http.Status
func CodeForError(err error) int {
switch err {
case ErrTokenRevoked:
return http.StatusUnauthorized
}
return http.StatusInternalServerError
}
// ErrorResponse renders an error page for errors given a message and a status code.
func ErrorResponse(rw http.ResponseWriter, req *http.Request, message string, code int) {
if req.Header.Get("Accept") == "application/json" {
var response struct {
Error string `json:"error"`
}
response.Error = message
writeJSONResponse(rw, code, response)
} else {
title := http.StatusText(code)
log.Error().
Int("http-status", code).
Str("page-title", title).
Str("page-message", message).
Msg("authenticate/errors.ErrorResponse")
rw.WriteHeader(code)
t := struct {
Code int
Title string
Message string
Version string
}{
Code: code,
Title: title,
Message: message,
Version: version.FullVersion(),
}
templates.New().ExecuteTemplate(rw, "error.html", t)
}
}
// writeJSONResponse is a helper that sets the application/json header and writes a response.
func writeJSONResponse(rw http.ResponseWriter, code int, response interface{}) {
rw.Header().Set("Content-Type", "application/json")
rw.WriteHeader(code)
err := json.NewEncoder(rw).Encode(response)
if err != nil {
io.WriteString(rw, err.Error())
}
}

129
internal/log/log.go Normal file
View file

@ -0,0 +1,129 @@
// Package log provides a global logger for zerolog.
package log // import "github.com/pomerium/pomerium/internal/log"
import (
"context"
"io"
"net/http"
"os"
"github.com/rs/zerolog"
)
// Logger is the global logger.
var Logger = zerolog.New(os.Stderr).With().Timestamp().Logger()
// Output duplicates the global logger and sets w as its output.
func Output(w io.Writer) zerolog.Logger {
return Logger.Output(w)
}
// With creates a child logger with the field added to its context.
func With() zerolog.Context {
return Logger.With()
}
// WithRequest creates a child logger with the remote user added to its context.
func WithRequest(req *http.Request, function string) zerolog.Logger {
remoteUser := getRemoteAddr(req)
return Logger.With().
Str("function", function).
Str("req-remote-user", remoteUser).
Str("req-http-method", req.Method).
Str("req-host", req.Host).
Str("req-url", req.URL.String()).
Str("req-user-agent", req.Header.Get("User-Agent")).
Logger()
}
// Level creates a child logger with the minimum accepted level set to level.
func Level(level zerolog.Level) zerolog.Logger {
return Logger.Level(level)
}
// Sample returns a logger with the s sampler.
func Sample(s zerolog.Sampler) zerolog.Logger {
return Logger.Sample(s)
}
// Hook returns a logger with the h Hook.
func Hook(h zerolog.Hook) zerolog.Logger {
return Logger.Hook(h)
}
// Debug starts a new message with debug level.
//
// You must call Msg on the returned event in order to send the event.
func Debug() *zerolog.Event {
return Logger.Debug()
}
// Info starts a new message with info level.
//
// You must call Msg on the returned event in order to send the event.
func Info() *zerolog.Event {
return Logger.Info()
}
// Warn starts a new message with warn level.
//
// You must call Msg on the returned event in order to send the event.
func Warn() *zerolog.Event {
return Logger.Warn()
}
// Error starts a new message with error level.
//
// You must call Msg on the returned event in order to send the event.
func Error() *zerolog.Event {
return Logger.Error()
}
// Fatal starts a new message with fatal level. The os.Exit(1) function
// is called by the Msg method.
//
// You must call Msg on the returned event in order to send the event.
func Fatal() *zerolog.Event {
return Logger.Fatal()
}
// Panic starts a new message with panic level. The message is also sent
// to the panic function.
//
// You must call Msg on the returned event in order to send the event.
func Panic() *zerolog.Event {
return Logger.Panic()
}
// WithLevel starts a new message with level.
//
// You must call Msg on the returned event in order to send the event.
func WithLevel(level zerolog.Level) *zerolog.Event {
return Logger.WithLevel(level)
}
// Log starts a new message with no level. Setting zerolog.GlobalLevel to
// zerolog.Disabled will still disable events produced by this method.
//
// You must call Msg on the returned event in order to send the event.
func Log() *zerolog.Event {
return Logger.Log()
}
// Print sends a log event using debug level and no extra field.
// Arguments are handled in the manner of fmt.Print.
func Print(v ...interface{}) {
Logger.Print(v...)
}
// Printf sends a log event using debug level and no extra field.
// Arguments are handled in the manner of fmt.Printf.
func Printf(format string, v ...interface{}) {
Logger.Printf(format, v...)
}
// Ctx returns the Logger associated with the ctx. If no logger
// is associated, a disabled logger is returned.
func Ctx(ctx context.Context) *zerolog.Logger {
return zerolog.Ctx(ctx)
}

View file

@ -0,0 +1,145 @@
package log // import "github.com/pomerium/pomerium/internal/log"
import (
"net/http"
"net/url"
"strings"
"time"
)
// Used to stash the authenticated user in the response for access when logging requests.
const loggingUserHeader = "SSO-Authenticated-User"
const gapMetaDataHeader = "GAP-Auth"
// responseLogger is wrapper of http.ResponseWriter that keeps track of its HTTP status
// code and body size
type responseLogger struct {
w http.ResponseWriter
status int
size int
proxyHost string
authInfo string
}
func (l *responseLogger) Header() http.Header {
return l.w.Header()
}
func (l *responseLogger) extractUser() {
authInfo := l.w.Header().Get(loggingUserHeader)
if authInfo != "" {
l.authInfo = authInfo
l.w.Header().Del(loggingUserHeader)
}
}
func (l *responseLogger) ExtractGAPMetadata() {
authInfo := l.w.Header().Get(gapMetaDataHeader)
if authInfo != "" {
l.authInfo = authInfo
l.w.Header().Del(gapMetaDataHeader)
}
}
func (l *responseLogger) Write(b []byte) (int, error) {
if l.status == 0 {
// The status will be StatusOK if WriteHeader has not been called yet
l.status = http.StatusOK
}
l.extractUser()
l.ExtractGAPMetadata()
size, err := l.w.Write(b)
l.size += size
return size, err
}
func (l *responseLogger) WriteHeader(s int) {
l.extractUser()
l.ExtractGAPMetadata()
l.w.WriteHeader(s)
l.status = s
}
func (l *responseLogger) Status() int {
return l.status
}
func (l *responseLogger) Size() int {
return l.size
}
func (l *responseLogger) Flush() {
f := l.w.(http.Flusher)
f.Flush()
}
// loggingHandler is the http.Handler implementation for LoggingHandlerTo and its friends
type loggingHandler struct {
handler http.Handler
}
// NewLoggingHandler returns a new loggingHandler that wraps a handler, and writer.
func NewLoggingHandler(h http.Handler) http.Handler {
return loggingHandler{
handler: h,
}
}
func (h loggingHandler) ServeHTTP(w http.ResponseWriter, req *http.Request) {
t := time.Now()
url := *req.URL
logger := &responseLogger{w: w, proxyHost: getProxyHost(req)}
h.handler.ServeHTTP(logger, req)
requestDuration := time.Since(t)
logRequest(logger.proxyHost, logger.authInfo, req, url, requestDuration, logger.Status())
}
// logRequest logs information about a request
func logRequest(proxyHost, username string, req *http.Request, url url.URL, requestDuration time.Duration, status int) {
uri := req.Host + url.RequestURI()
Info().
Int("http-status", status).
Str("request-method", req.Method).
Str("request-uri", uri).
Str("proxy-host", proxyHost).
Str("user-agent", req.Header.Get("User-Agent")).
Str("remote-address", getRemoteAddr(req)).
Dur("duration", requestDuration).
Str("user", username).
Msg("request")
}
// getRemoteAddr returns the client IP address from a request. If present, the
// X-Forwarded-For header is assumed to be set by a load balancer, and its
// rightmost entry (the client IP that connected to the LB) is returned.
func getRemoteAddr(req *http.Request) string {
addr := req.RemoteAddr
forwardedHeader := req.Header.Get("X-Forwarded-For")
if forwardedHeader != "" {
forwardedList := strings.Split(forwardedHeader, ",")
forwardedAddr := strings.TrimSpace(forwardedList[len(forwardedList)-1])
if forwardedAddr != "" {
addr = forwardedAddr
}
}
return addr
}
// getProxyHost attempts to get the proxy host from the redirect_uri parameter
func getProxyHost(req *http.Request) string {
err := req.ParseForm()
if err != nil {
return ""
}
redirect := req.Form.Get("redirect_uri")
redirectURL, err := url.Parse(redirect)
if err != nil {
return ""
}
return redirectURL.Host
}

View file

@ -0,0 +1,72 @@
package log // import "github.com/pomerium/pomerium/internal/log"
import (
"net/http/httptest"
"testing"
)
func TestGetRemoteAddr(t *testing.T) {
testCases := []struct {
name string
remoteAddr string
forwardedHeader string
expectedAddr string
}{
{
name: "RemoteAddr used when no X-Forwarded-For header is given",
remoteAddr: "1.1.1.1",
expectedAddr: "1.1.1.1",
},
{
name: "RemoteAddr used when no X-Forwarded-For header is only whitespace",
remoteAddr: "1.1.1.1",
forwardedHeader: " ",
expectedAddr: "1.1.1.1",
},
{
name: "RemoteAddr used when no X-Forwarded-For header is only comma-separated whitespace",
remoteAddr: "1.1.1.1",
forwardedHeader: " , , ",
expectedAddr: "1.1.1.1",
},
{
name: "X-Forwarded-For header is preferred to RemoteAddr",
remoteAddr: "1.1.1.1",
forwardedHeader: "9.9.9.9",
expectedAddr: "9.9.9.9",
},
{
name: "rightmost entry in X-Forwarded-For header is used",
remoteAddr: "1.1.1.1",
forwardedHeader: "2.2.2.2, 3.3.3.3, 4.4.4.4.4, 5.5.5.5",
expectedAddr: "5.5.5.5",
},
{
name: "RemoteAddr is used if rightmost entry in X-Forwarded-For header is empty",
remoteAddr: "1.1.1.1",
forwardedHeader: "2.2.2.2, 3.3.3.3, ",
expectedAddr: "1.1.1.1",
},
{
name: "X-Forwaded-For header entries are stripped",
remoteAddr: "1.1.1.1",
forwardedHeader: " 2.2.2.2, 3.3.3.3, 4.4.4.4, 5.5.5.5 ",
expectedAddr: "5.5.5.5",
},
}
for _, tc := range testCases {
t.Run(tc.name, func(t *testing.T) {
req := httptest.NewRequest("GET", "/", nil)
req.RemoteAddr = tc.remoteAddr
if tc.forwardedHeader != "" {
req.Header.Set("X-Forwarded-For", tc.forwardedHeader)
}
addr := getRemoteAddr(req)
if addr != tc.expectedAddr {
t.Errorf("expected remote addr = %q, got %q", tc.expectedAddr, addr)
}
})
}
}

View file

@ -0,0 +1,205 @@
package middleware // import "github.com/pomerium/pomerium/internal/middleware"
import (
"crypto/hmac"
"crypto/sha256"
"encoding/base64"
"fmt"
"net/http"
"net/url"
"strconv"
"strings"
"time"
"github.com/pomerium/pomerium/internal/httputil"
"github.com/pomerium/pomerium/internal/log"
)
// SetHeaders ensures that every response includes some basic security headers
func SetHeaders(h http.Handler, securityHeaders map[string]string) http.Handler {
return http.HandlerFunc(func(rw http.ResponseWriter, req *http.Request) {
for key, val := range securityHeaders {
rw.Header().Set(key, val)
}
h.ServeHTTP(rw, req)
})
}
// WithMethods writes an error response if the method of the request is not included.
func WithMethods(f http.HandlerFunc, methods ...string) http.HandlerFunc {
methodMap := make(map[string]struct{}, len(methods))
for _, m := range methods {
methodMap[m] = struct{}{}
}
return func(rw http.ResponseWriter, req *http.Request) {
if _, ok := methodMap[req.Method]; !ok {
httputil.ErrorResponse(rw, req, fmt.Sprintf("method %s not allowed", req.Method), http.StatusMethodNotAllowed)
return
}
f(rw, req)
}
}
// ValidateClientID checks the request body or url for the client id and returns an error
// if it does not match the proxy client id
func ValidateClientID(f http.HandlerFunc, proxyClientID string) http.HandlerFunc {
return func(rw http.ResponseWriter, req *http.Request) {
// try to get the client id from the request body
err := req.ParseForm()
if err != nil {
httputil.ErrorResponse(rw, req, err.Error(), http.StatusInternalServerError)
return
}
clientID := req.FormValue("client_id")
if clientID == "" {
// try to get the clientID from the query parameters
clientID = req.URL.Query().Get("client_id")
}
if clientID != proxyClientID {
httputil.ErrorResponse(rw, req, "Invalid client_id parameter", http.StatusUnauthorized)
return
}
f(rw, req)
}
}
// ValidateClientSecret checks the request header for the client secret and returns
// an error if it does not match the proxy client secret
func ValidateClientSecret(f http.HandlerFunc, proxyClientSecret string) http.HandlerFunc {
return func(rw http.ResponseWriter, req *http.Request) {
err := req.ParseForm()
if err != nil {
httputil.ErrorResponse(rw, req, err.Error(), http.StatusInternalServerError)
return
}
clientSecret := req.Form.Get("client_secret")
// check the request header for the client secret
if clientSecret == "" {
clientSecret = req.Header.Get("X-Client-Secret")
}
if clientSecret != proxyClientSecret {
log.Error().Str("clientSecret", clientSecret).Str("proxyClientSecret", proxyClientSecret).Msg("oh")
httputil.ErrorResponse(rw, req, "Invalid client secret", http.StatusUnauthorized)
return
}
f(rw, req)
}
}
// ValidateRedirectURI checks the redirect uri in the query parameters and ensures that
// the url's domain is one in the list of proxy root domains.
func ValidateRedirectURI(f http.HandlerFunc, proxyRootDomains []string) http.HandlerFunc {
return func(rw http.ResponseWriter, req *http.Request) {
err := req.ParseForm()
if err != nil {
httputil.ErrorResponse(rw, req, err.Error(), http.StatusBadRequest)
return
}
redirectURI := req.Form.Get("redirect_uri")
if !validRedirectURI(redirectURI, proxyRootDomains) {
httputil.ErrorResponse(rw, req, "Invalid redirect parameter", http.StatusBadRequest)
return
}
f(rw, req)
}
}
func validRedirectURI(uri string, rootDomains []string) bool {
redirectURL, err := url.Parse(uri)
if uri == "" || err != nil || redirectURL.Host == "" {
return false
}
for _, domain := range rootDomains {
if strings.HasSuffix(redirectURL.Hostname(), domain) {
return true
}
}
return false
}
// ValidateSignature ensures the request is valid and has been signed with
// the correspdoning client secret key
func ValidateSignature(f http.HandlerFunc, proxyClientSecret string) http.HandlerFunc {
return func(rw http.ResponseWriter, req *http.Request) {
err := req.ParseForm()
if err != nil {
httputil.ErrorResponse(rw, req, err.Error(), http.StatusBadRequest)
return
}
redirectURI := req.Form.Get("redirect_uri")
sigVal := req.Form.Get("sig")
timestamp := req.Form.Get("ts")
if !validSignature(redirectURI, sigVal, timestamp, proxyClientSecret) {
httputil.ErrorResponse(rw, req, "Invalid redirect parameter", http.StatusBadRequest)
return
}
f(rw, req)
}
}
// ValidateHost ensures that each request's host is valid
func ValidateHost(h http.Handler, mux map[string]*http.Handler) http.Handler {
return http.HandlerFunc(func(rw http.ResponseWriter, req *http.Request) {
if _, ok := mux[req.Host]; !ok {
httputil.ErrorResponse(rw, req, "Unknown host to route", http.StatusNotFound)
return
}
h.ServeHTTP(rw, req)
})
}
// RequireHTTPS reroutes a HTTP request to HTTPS
// todo(bdd) : this is unreliable unless behind another reverser proxy
func RequireHTTPS(h http.Handler) http.Handler {
return http.HandlerFunc(func(rw http.ResponseWriter, req *http.Request) {
rw.Header().Set("Strict-Transport-Security", "max-age=31536000")
// todo(bdd) : scheme and x-forwarded-proto cannot be trusted if not behind another load balancer
if (req.URL.Scheme == "http" && req.Header.Get("X-Forwarded-Proto") == "http") || &req.TLS == nil {
dest := &url.URL{
Scheme: "https",
Host: req.Host,
Path: req.URL.Path,
RawQuery: req.URL.RawQuery,
}
http.Redirect(rw, req, dest.String(), http.StatusMovedPermanently)
return
}
h.ServeHTTP(rw, req)
})
}
func validSignature(redirectURI, sigVal, timestamp, secret string) bool {
if redirectURI == "" || sigVal == "" || timestamp == "" || secret == "" {
return false
}
_, err := url.Parse(redirectURI)
if err != nil {
return false
}
requestSig, err := base64.URLEncoding.DecodeString(sigVal)
if err != nil {
return false
}
i, err := strconv.ParseInt(timestamp, 10, 64)
if err != nil {
return false
}
tm := time.Unix(i, 0)
ttl := 5 * time.Minute
if time.Since(tm) > ttl {
return false
}
localSig := redirectURLSignature(redirectURI, tm, secret)
return hmac.Equal(requestSig, localSig)
}
func redirectURLSignature(rawRedirect string, timestamp time.Time, secret string) []byte {
h := hmac.New(sha256.New, []byte(secret))
h.Write([]byte(rawRedirect))
h.Write([]byte(fmt.Sprint(timestamp.Unix())))
return h.Sum(nil)
}

View file

@ -0,0 +1,35 @@
package options // import "github.com/pomerium/pomerium/internal/options"
import (
"fmt"
"strings"
)
// NewEmailValidator returns a function that checks whether a given email is valid based on a list
// of domains. The domain "*" is a wild card that matches any non-empty email.
func NewEmailValidator(domains []string) func(string) bool {
allowAll := false
for i, domain := range domains {
if domain == "*" {
allowAll = true
}
domains[i] = fmt.Sprintf("@%s", strings.ToLower(domain))
}
if allowAll {
return func(email string) bool { return email != "" }
}
return func(email string) bool {
if email == "" {
return false
}
email = strings.ToLower(email)
for _, domain := range domains {
if strings.HasSuffix(email, domain) {
return true
}
}
return false
}
}

View file

@ -0,0 +1,121 @@
package options // import "github.com/pomerium/pomerium/internal/options"
import (
"testing"
)
func TestEmailValidatorValidator(t *testing.T) {
testCases := []struct {
name string
domains []string
email string
expectValid bool
}{
{
name: "nothing should validate when domain list is empty",
domains: []string(nil),
email: "foo@example.com",
expectValid: false,
},
{
name: "single domain validation",
domains: []string{"example.com"},
email: "foo@example.com",
expectValid: true,
},
{
name: "substring matches are rejected",
domains: []string{"example.com"},
email: "foo@hackerexample.com",
expectValid: false,
},
{
name: "no subdomain rollup happens",
domains: []string{"example.com"},
email: "foo@bar.example.com",
expectValid: false,
},
{
name: "multiple domain validation still rejects other domains",
domains: []string{"abc.com", "xyz.com"},
email: "foo@example.com",
expectValid: false,
},
{
name: "multiple domain validation still accepts emails from either domain",
domains: []string{"abc.com", "xyz.com"},
email: "foo@abc.com",
expectValid: true,
},
{
name: "multiple domain validation still rejects other domains",
domains: []string{"abc.com", "xyz.com"},
email: "bar@xyz.com",
expectValid: true,
},
{
name: "comparisons are case insensitive",
domains: []string{"Example.Com"},
email: "foo@example.com",
expectValid: true,
},
{
name: "comparisons are case insensitive",
domains: []string{"Example.Com"},
email: "foo@EXAMPLE.COM",
expectValid: true,
},
{
name: "comparisons are case insensitive",
domains: []string{"example.com"},
email: "foo@ExAmPlE.CoM",
expectValid: true,
},
{
name: "single wildcard allows all",
domains: []string{"*"},
email: "foo@example.com",
expectValid: true,
},
{
name: "single wildcard allows all",
domains: []string{"*"},
email: "bar@gmail.com",
expectValid: true,
},
{
name: "wildcard in list allows all",
domains: []string{"example.com", "*"},
email: "foo@example.com",
expectValid: true,
},
{
name: "wildcard in list allows all",
domains: []string{"example.com", "*"},
email: "foo@gmail.com",
expectValid: true,
},
{
name: "empty email rejected",
domains: []string{"example.com"},
email: "",
expectValid: false,
},
{
name: "wildcard still rejects empty emails",
domains: []string{"*"},
email: "",
expectValid: false,
},
}
for _, tc := range testCases {
t.Run(tc.name, func(t *testing.T) {
emailValidator := NewEmailValidator(tc.domains)
valid := emailValidator(tc.email)
if valid != tc.expectValid {
t.Fatalf("expected %v, got %v", tc.expectValid, valid)
}
})
}
}

View file

@ -0,0 +1,163 @@
package sessions // import "github.com/pomerium/pomerium/internal/sessions"
import (
"errors"
"fmt"
"net"
"net/http"
"strings"
"time"
"github.com/pomerium/pomerium/internal/aead"
"github.com/pomerium/pomerium/internal/log"
)
// ErrInvalidSession is an error for invalid sessions.
var ErrInvalidSession = errors.New("invalid session")
// CSRFStore has the functions for setting, getting, and clearing the CSRF cookie
type CSRFStore interface {
SetCSRF(http.ResponseWriter, *http.Request, string)
GetCSRF(*http.Request) (*http.Cookie, error)
ClearCSRF(http.ResponseWriter, *http.Request)
}
// SessionStore has the functions for setting, getting, and clearing the Session cookie
type SessionStore interface {
ClearSession(http.ResponseWriter, *http.Request)
LoadSession(*http.Request) (*SessionState, error)
SaveSession(http.ResponseWriter, *http.Request, *SessionState) error
}
// CookieStore represents all the cookie related configurations
type CookieStore struct {
Name string
CSRFCookieName string
CookieExpire time.Duration
CookieRefresh time.Duration
CookieSecure bool
CookieHTTPOnly bool
CookieDomain string
CookieCipher aead.Cipher
SessionLifetimeTTL time.Duration
}
// CreateMiscreantCookieCipher creates a new miscreant cipher with the cookie secret
func CreateMiscreantCookieCipher(cookieSecret []byte) func(s *CookieStore) error {
return func(s *CookieStore) error {
cipher, err := aead.NewMiscreantCipher(cookieSecret)
if err != nil {
return fmt.Errorf("miscreant cookie-secret error: %s", err.Error())
}
s.CookieCipher = cipher
return nil
}
}
// NewCookieStore returns a new session with ciphers for each of the cookie secrets
func NewCookieStore(cookieName string, optFuncs ...func(*CookieStore) error) (*CookieStore, error) {
c := &CookieStore{
Name: cookieName,
CookieSecure: true,
CookieHTTPOnly: true,
CookieExpire: 168 * time.Hour,
CSRFCookieName: fmt.Sprintf("%v_%v", cookieName, "csrf"),
}
for _, f := range optFuncs {
err := f(c)
if err != nil {
return nil, err
}
}
domain := c.CookieDomain
if domain == "" {
domain = "<default>"
}
return c, nil
}
func (s *CookieStore) makeCookie(req *http.Request, name string, value string, expiration time.Duration, now time.Time) *http.Cookie {
domain := req.Host
if h, _, err := net.SplitHostPort(domain); err == nil {
domain = h
}
if s.CookieDomain != "" {
if !strings.HasSuffix(domain, s.CookieDomain) {
log.Warn().Str("cookie-domain", s.CookieDomain).Msg("using configured cookie domain")
}
domain = s.CookieDomain
}
return &http.Cookie{
Name: name,
Value: value,
Path: "/",
Domain: domain,
HttpOnly: s.CookieHTTPOnly,
Secure: s.CookieSecure,
Expires: now.Add(expiration),
}
}
// makeSessionCookie constructs a session cookie given the request, an expiration time and the current time.
func (s *CookieStore) makeSessionCookie(req *http.Request, value string, expiration time.Duration, now time.Time) *http.Cookie {
return s.makeCookie(req, s.Name, value, expiration, now)
}
// makeCSRFCookie creates a CSRF cookie given the request, an expiration time, and the current time.
func (s *CookieStore) makeCSRFCookie(req *http.Request, value string, expiration time.Duration, now time.Time) *http.Cookie {
return s.makeCookie(req, s.CSRFCookieName, value, expiration, now)
}
// ClearCSRF clears the CSRF cookie from the request
func (s *CookieStore) ClearCSRF(rw http.ResponseWriter, req *http.Request) {
http.SetCookie(rw, s.makeCSRFCookie(req, "", time.Hour*-1, time.Now()))
}
// SetCSRF sets the CSRFCookie creates a CSRF cookie in a given request
func (s *CookieStore) SetCSRF(rw http.ResponseWriter, req *http.Request, val string) {
http.SetCookie(rw, s.makeCSRFCookie(req, val, s.CookieExpire, time.Now()))
}
// GetCSRF gets the CSRFCookie creates a CSRF cookie in a given request
func (s *CookieStore) GetCSRF(req *http.Request) (*http.Cookie, error) {
return req.Cookie(s.CSRFCookieName)
}
// ClearSession clears the session cookie from a request
func (s *CookieStore) ClearSession(rw http.ResponseWriter, req *http.Request) {
http.SetCookie(rw, s.makeSessionCookie(req, "", time.Hour*-1, time.Now()))
}
func (s *CookieStore) setSessionCookie(rw http.ResponseWriter, req *http.Request, val string) {
http.SetCookie(rw, s.makeSessionCookie(req, val, s.CookieExpire, time.Now()))
}
// LoadSession returns a SessionState from the cookie in the request.
func (s *CookieStore) LoadSession(req *http.Request) (*SessionState, error) {
c, err := req.Cookie(s.Name)
if err != nil {
// always http.ErrNoCookie
return nil, err
}
session, err := UnmarshalSession(c.Value, s.CookieCipher)
if err != nil {
log.Error().Err(err).Str("remote-host", req.Host).Msg("error unmarshaling session")
return nil, ErrInvalidSession
}
return session, nil
}
// SaveSession saves a session state to a request sessions.
func (s *CookieStore) SaveSession(rw http.ResponseWriter, req *http.Request, sessionState *SessionState) error {
value, err := MarshalSession(sessionState, s.CookieCipher)
if err != nil {
return err
}
s.setSessionCookie(rw, req, value)
return nil
}

View file

@ -0,0 +1,348 @@
package sessions // import "github.com/pomerium/pomerium/internal/sessions"
import (
"encoding/base64"
"fmt"
"net/http"
"net/http/httptest"
"testing"
"time"
"github.com/pomerium/pomerium/internal/testutil"
)
var testEncodedCookieSecret, _ = base64.StdEncoding.DecodeString("qICChm3wdjbjcWymm7PefwtPP6/PZv+udkFEubTeE38=")
func TestCreateMiscreantCookieCipher(t *testing.T) {
testCases := []struct {
name string
cookieSecret []byte
expectedError bool
}{
{
name: "normal case with base64 encoded secret",
cookieSecret: testEncodedCookieSecret,
},
{
name: "error when not base64 encoded",
cookieSecret: []byte("abcd"),
expectedError: true,
},
}
for _, tc := range testCases {
t.Run(tc.name, func(t *testing.T) {
_, err := NewCookieStore("cookieName", CreateMiscreantCookieCipher(tc.cookieSecret))
if !tc.expectedError {
testutil.Ok(t, err)
} else {
testutil.NotEqual(t, err, nil)
}
})
}
}
func TestNewSession(t *testing.T) {
testCases := []struct {
name string
optFuncs []func(*CookieStore) error
expectedError bool
expectedSession *CookieStore
}{
{
name: "default with no opt funcs set",
expectedSession: &CookieStore{
Name: "cookieName",
CookieSecure: true,
CookieHTTPOnly: true,
CookieExpire: 168 * time.Hour,
CSRFCookieName: "cookieName_csrf",
},
},
{
name: "opt func with an error returns an error",
optFuncs: []func(*CookieStore) error{func(*CookieStore) error { return fmt.Errorf("error") }},
expectedError: true,
},
{
name: "opt func overrides default values",
optFuncs: []func(*CookieStore) error{func(s *CookieStore) error {
s.CookieExpire = time.Hour
return nil
}},
expectedSession: &CookieStore{
Name: "cookieName",
CookieSecure: true,
CookieHTTPOnly: true,
CookieExpire: time.Hour,
CSRFCookieName: "cookieName_csrf",
},
},
}
for _, tc := range testCases {
t.Run(tc.name, func(t *testing.T) {
session, err := NewCookieStore("cookieName", tc.optFuncs...)
if tc.expectedError {
testutil.NotEqual(t, err, nil)
} else {
testutil.Ok(t, err)
}
testutil.Equal(t, tc.expectedSession, session)
})
}
}
func TestMakeSessionCookie(t *testing.T) {
now := time.Now()
cookieValue := "cookieValue"
expiration := time.Hour
cookieName := "cookieName"
testCases := []struct {
name string
optFuncs []func(*CookieStore) error
expectedCookie *http.Cookie
}{
{
name: "default cookie domain",
expectedCookie: &http.Cookie{
Name: cookieName,
Value: cookieValue,
Path: "/",
Domain: "www.example.com",
HttpOnly: true,
Secure: true,
Expires: now.Add(expiration),
},
},
{
name: "custom cookie domain set",
optFuncs: []func(*CookieStore) error{
func(s *CookieStore) error {
s.CookieDomain = "buzzfeed.com"
return nil
},
},
expectedCookie: &http.Cookie{
Name: cookieName,
Value: cookieValue,
Path: "/",
Domain: "buzzfeed.com",
HttpOnly: true,
Secure: true,
Expires: now.Add(expiration),
},
},
}
for _, tc := range testCases {
t.Run(tc.name, func(t *testing.T) {
session, err := NewCookieStore(cookieName, tc.optFuncs...)
testutil.Ok(t, err)
req := httptest.NewRequest("GET", "http://www.example.com", nil)
cookie := session.makeSessionCookie(req, cookieValue, expiration, now)
testutil.Equal(t, cookie, tc.expectedCookie)
})
}
}
func TestMakeSessionCSRFCookie(t *testing.T) {
now := time.Now()
cookieValue := "cookieValue"
expiration := time.Hour
cookieName := "cookieName"
csrfName := "cookieName_csrf"
testCases := []struct {
name string
optFuncs []func(*CookieStore) error
expectedCookie *http.Cookie
}{
{
name: "default cookie domain",
expectedCookie: &http.Cookie{
Name: csrfName,
Value: cookieValue,
Path: "/",
Domain: "www.example.com",
HttpOnly: true,
Secure: true,
Expires: now.Add(expiration),
},
},
{
name: "custom cookie domain set",
optFuncs: []func(*CookieStore) error{
func(s *CookieStore) error {
s.CookieDomain = "buzzfeed.com"
return nil
},
},
expectedCookie: &http.Cookie{
Name: csrfName,
Value: cookieValue,
Path: "/",
Domain: "buzzfeed.com",
HttpOnly: true,
Secure: true,
Expires: now.Add(expiration),
},
},
}
for _, tc := range testCases {
t.Run(tc.name, func(t *testing.T) {
session, err := NewCookieStore(cookieName, tc.optFuncs...)
testutil.Ok(t, err)
req := httptest.NewRequest("GET", "http://www.example.com", nil)
cookie := session.makeCSRFCookie(req, cookieValue, expiration, now)
testutil.Equal(t, tc.expectedCookie, cookie)
})
}
}
func TestSetSessionCookie(t *testing.T) {
cookieValue := "cookieValue"
cookieName := "cookieName"
t.Run("set session cookie test", func(t *testing.T) {
session, err := NewCookieStore(cookieName)
testutil.Ok(t, err)
req := httptest.NewRequest("GET", "http://www.example.com", nil)
rw := httptest.NewRecorder()
session.setSessionCookie(rw, req, cookieValue)
var found bool
for _, cookie := range rw.Result().Cookies() {
if cookie.Name == cookieName {
found = true
testutil.Equal(t, cookieValue, cookie.Value)
testutil.Assert(t, cookie.Expires.After(time.Now()), "cookie expires after now")
}
}
testutil.Assert(t, found, "cookie in header")
})
}
func TestSetCSRFSessionCookie(t *testing.T) {
cookieValue := "cookieValue"
cookieName := "cookieName"
t.Run("set csrf cookie test", func(t *testing.T) {
session, err := NewCookieStore(cookieName)
testutil.Ok(t, err)
req := httptest.NewRequest("GET", "http://www.example.com", nil)
rw := httptest.NewRecorder()
session.SetCSRF(rw, req, cookieValue)
var found bool
for _, cookie := range rw.Result().Cookies() {
if cookie.Name == fmt.Sprintf("%s_csrf", cookieName) {
found = true
testutil.Equal(t, cookieValue, cookie.Value)
testutil.Assert(t, cookie.Expires.After(time.Now()), "cookie expires after now")
}
}
testutil.Assert(t, found, "cookie in header")
})
}
func TestClearSessionCookie(t *testing.T) {
cookieValue := "cookieValue"
cookieName := "cookieName"
t.Run("set session cookie test", func(t *testing.T) {
session, err := NewCookieStore(cookieName)
testutil.Ok(t, err)
req := httptest.NewRequest("GET", "http://www.example.com", nil)
req.AddCookie(session.makeSessionCookie(req, cookieValue, time.Hour, time.Now()))
rw := httptest.NewRecorder()
session.ClearSession(rw, req)
var found bool
for _, cookie := range rw.Result().Cookies() {
if cookie.Name == cookieName {
found = true
testutil.Equal(t, "", cookie.Value)
testutil.Assert(t, cookie.Expires.Before(time.Now()), "cookie expires before now")
}
}
testutil.Assert(t, found, "cookie in header")
})
}
func TestClearCSRFSessionCookie(t *testing.T) {
cookieValue := "cookieValue"
cookieName := "cookieName"
t.Run("clear csrf cookie test", func(t *testing.T) {
session, err := NewCookieStore(cookieName)
testutil.Ok(t, err)
req := httptest.NewRequest("GET", "http://www.example.com", nil)
req.AddCookie(session.makeCSRFCookie(req, cookieValue, time.Hour, time.Now()))
rw := httptest.NewRecorder()
session.ClearCSRF(rw, req)
var found bool
for _, cookie := range rw.Result().Cookies() {
if cookie.Name == fmt.Sprintf("%s_csrf", cookieName) {
found = true
testutil.Equal(t, "", cookie.Value)
testutil.Assert(t, cookie.Expires.Before(time.Now()), "cookie expires before now")
}
}
testutil.Assert(t, found, "cookie in header")
})
}
func TestLoadCookiedSession(t *testing.T) {
cookieName := "cookieName"
testCases := []struct {
name string
optFuncs []func(*CookieStore) error
setupCookies func(*testing.T, *http.Request, *CookieStore, *SessionState)
expectedError error
sessionState *SessionState
}{
{
name: "no cookie set returns an error",
setupCookies: func(*testing.T, *http.Request, *CookieStore, *SessionState) {},
expectedError: http.ErrNoCookie,
},
{
name: "cookie set with cipher set",
optFuncs: []func(*CookieStore) error{CreateMiscreantCookieCipher(testEncodedCookieSecret)},
setupCookies: func(t *testing.T, req *http.Request, s *CookieStore, sessionState *SessionState) {
value, err := MarshalSession(sessionState, s.CookieCipher)
testutil.Ok(t, err)
req.AddCookie(s.makeSessionCookie(req, value, time.Hour, time.Now()))
},
sessionState: &SessionState{
Email: "example@email.com",
RefreshToken: "abccdddd",
AccessToken: "access",
},
},
{
name: "cookie set with invalid value cipher set",
optFuncs: []func(*CookieStore) error{CreateMiscreantCookieCipher(testEncodedCookieSecret)},
setupCookies: func(t *testing.T, req *http.Request, s *CookieStore, sessionState *SessionState) {
value := "574b776a7c934d6b9fc42ec63a389f79"
req.AddCookie(s.makeSessionCookie(req, value, time.Hour, time.Now()))
},
expectedError: ErrInvalidSession,
},
}
for _, tc := range testCases {
t.Run(tc.name, func(t *testing.T) {
session, err := NewCookieStore(cookieName, tc.optFuncs...)
testutil.Ok(t, err)
req := httptest.NewRequest("GET", "https://www.example.com", nil)
tc.setupCookies(t, req, session, tc.sessionState)
s, err := session.LoadSession(req)
testutil.Equal(t, tc.expectedError, err)
testutil.Equal(t, tc.sessionState, s)
})
}
}

View file

@ -0,0 +1,50 @@
package sessions // import "github.com/pomerium/pomerium/internal/sessions"
import (
"net/http"
)
// MockCSRFStore is a mock implementation of the CSRF store interface
type MockCSRFStore struct {
ResponseCSRF string
Cookie *http.Cookie
GetError error
}
// SetCSRF sets the ResponseCSRF string to a val
func (ms *MockCSRFStore) SetCSRF(rw http.ResponseWriter, req *http.Request, val string) {
ms.ResponseCSRF = val
}
// ClearCSRF clears the ResponseCSRF string
func (ms *MockCSRFStore) ClearCSRF(http.ResponseWriter, *http.Request) {
ms.ResponseCSRF = ""
}
// GetCSRF returns the cookie and error
func (ms *MockCSRFStore) GetCSRF(*http.Request) (*http.Cookie, error) {
return ms.Cookie, ms.GetError
}
// MockSessionStore is a mock implementation of the SessionStore interface
type MockSessionStore struct {
ResponseSession string
Session *SessionState
SaveError error
LoadError error
}
// ClearSession clears the ResponseSession
func (ms *MockSessionStore) ClearSession(http.ResponseWriter, *http.Request) {
ms.ResponseSession = ""
}
// LoadSession returns the session and a error
func (ms *MockSessionStore) LoadSession(*http.Request) (*SessionState, error) {
return ms.Session, ms.LoadError
}
// SaveSession returns a save error.
func (ms *MockSessionStore) SaveSession(http.ResponseWriter, *http.Request, *SessionState) error {
return ms.SaveError
}

View file

@ -0,0 +1,70 @@
package sessions // import "github.com/pomerium/pomerium/internal/sessions"
import (
"errors"
"time"
"github.com/pomerium/pomerium/internal/aead"
)
var (
// ErrLifetimeExpired is an error for the lifetime deadline expiring
ErrLifetimeExpired = errors.New("user lifetime expired")
)
// SessionState is our object that keeps track of a user's session state
type SessionState struct {
AccessToken string `json:"access_token"`
RefreshToken string `json:"refresh_token"`
IDToken string `json:"id_token"` // https://openid.net/specs/openid-connect-core-1_0.html#TokenResponse
RefreshDeadline time.Time `json:"refresh_deadline"`
LifetimeDeadline time.Time `json:"lifetime_deadline"`
ValidDeadline time.Time `json:"valid_deadline"`
GracePeriodStart time.Time `json:"grace_period_start"`
Email string `json:"email"`
User string `json:"user"`
Groups []string `json:"groups"`
}
// LifetimePeriodExpired returns true if the lifetime has expired
func (s *SessionState) LifetimePeriodExpired() bool {
return isExpired(s.LifetimeDeadline)
}
// RefreshPeriodExpired returns true if the refresh period has expired
func (s *SessionState) RefreshPeriodExpired() bool {
return isExpired(s.RefreshDeadline)
}
// ValidationPeriodExpired returns true if the validation period has expired
func (s *SessionState) ValidationPeriodExpired() bool {
return isExpired(s.ValidDeadline)
}
func isExpired(t time.Time) bool {
return t.Before(time.Now())
}
// MarshalSession marshals the session state as JSON, encrypts the JSON using the
// given cipher, and base64-encodes the result
func MarshalSession(s *SessionState, c aead.Cipher) (string, error) {
return c.Marshal(s)
}
// UnmarshalSession takes the marshaled string, base64-decodes into a byte slice, decrypts the
// byte slice using the pased cipher, and unmarshals the resulting JSON into a session state struct
func UnmarshalSession(value string, c aead.Cipher) (*SessionState, error) {
s := &SessionState{}
err := c.Unmarshal(value, s)
if err != nil {
return nil, err
}
return s, nil
}
// ExtendDeadline returns the time extended by a given duration
func ExtendDeadline(ttl time.Duration) time.Time {
return time.Now().Add(ttl).Truncate(time.Second)
}

View file

@ -0,0 +1,71 @@
package sessions // import "github.com/pomerium/pomerium/internal/sessions"
import (
"reflect"
"testing"
"time"
"github.com/pomerium/pomerium/internal/aead"
)
func TestSessionStateSerialization(t *testing.T) {
secret := aead.GenerateKey()
c, err := aead.NewMiscreantCipher([]byte(secret))
if err != nil {
t.Fatalf("expected to be able to create cipher: %v", err)
}
want := &SessionState{
AccessToken: "token1234",
RefreshToken: "refresh4321",
LifetimeDeadline: time.Now().Add(1 * time.Hour).Truncate(time.Second).UTC(),
RefreshDeadline: time.Now().Add(1 * time.Hour).Truncate(time.Second).UTC(),
ValidDeadline: time.Now().Add(1 * time.Minute).Truncate(time.Second).UTC(),
Email: "user@domain.com",
User: "user",
}
ciphertext, err := MarshalSession(want, c)
if err != nil {
t.Fatalf("expected to be encode session: %v", err)
}
got, err := UnmarshalSession(ciphertext, c)
if err != nil {
t.Fatalf("expected to be decode session: %v", err)
}
if !reflect.DeepEqual(want, got) {
t.Logf("want: %#v", want)
t.Logf(" got: %#v", got)
t.Errorf("encoding and decoding session resulted in unexpected output")
}
}
func TestSessionStateExpirations(t *testing.T) {
session := &SessionState{
AccessToken: "token1234",
RefreshToken: "refresh4321",
LifetimeDeadline: time.Now().Add(-1 * time.Hour),
RefreshDeadline: time.Now().Add(-1 * time.Hour),
ValidDeadline: time.Now().Add(-1 * time.Minute),
Email: "user@domain.com",
User: "user",
}
if !session.LifetimePeriodExpired() {
t.Errorf("expcted lifetime period to be expired")
}
if !session.RefreshPeriodExpired() {
t.Errorf("expcted lifetime period to be expired")
}
if !session.ValidationPeriodExpired() {
t.Errorf("expcted lifetime period to be expired")
}
}

View file

@ -0,0 +1,75 @@
// Original Copyright 2013 The Go Authors. All rights reserved.
//
// Modified by BuzzFeed to return duplicate counts.
//
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
// Package singleflight provides a duplicate function call suppression mechanism.
package singleflight // import "github.com/pomerium/pomerium/internal/singleflight"
import "sync"
// call is an in-flight or completed singleflight.Do call
type call struct {
wg sync.WaitGroup
// These fields are written once before the WaitGroup is done
// and are only read after the WaitGroup is done.
val interface{}
err error
// These fields are read and written with the singleflight
// mutex held before the WaitGroup is done, and are read but
// not written after the WaitGroup is done.
dups int
}
// Group represents a class of work and forms a namespace in
// which units of work can be executed with duplicate suppression.
type Group struct {
mu sync.Mutex // protects m
m map[string]*call // lazily initialized
}
// Result holds the results of Do, so they can be passed
// on a channel.
type Result struct {
Val interface{}
Err error
Count bool
}
// Do executes and returns the results of the given function, making
// sure that only one execution is in-flight for a given key at a
// time. If a duplicate comes in, the duplicate caller waits for the
// original to complete and receives the same results.
// The return value of Count indicates how many tiems v was given to multiple callers.
// Count will be zero for requests are shared and only be non-zero for the originating request.
func (g *Group) Do(key string, fn func() (interface{}, error)) (v interface{}, count int, err error) {
g.mu.Lock()
if g.m == nil {
g.m = make(map[string]*call)
}
if c, ok := g.m[key]; ok {
c.dups++
g.mu.Unlock()
c.wg.Wait()
return c.val, 0, c.err
}
c := new(call)
c.wg.Add(1)
g.m[key] = c
g.mu.Unlock()
c.val, c.err = fn()
c.wg.Done()
g.mu.Lock()
delete(g.m, key)
g.mu.Unlock()
return c.val, c.dups, c.err
}

View file

@ -0,0 +1,87 @@
// Copyright 2013 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package singleflight // import "github.com/pomerium/pomerium/internal/singleflight"
import (
"errors"
"fmt"
"sync"
"sync/atomic"
"testing"
"time"
)
func TestDo(t *testing.T) {
var g Group
v, _, err := g.Do("key", func() (interface{}, error) {
return "bar", nil
})
if got, want := fmt.Sprintf("%v (%T)", v, v), "bar (string)"; got != want {
t.Errorf("Do = %v; want %v", got, want)
}
if err != nil {
t.Errorf("Do error = %v", err)
}
}
func TestDoErr(t *testing.T) {
var g Group
someErr := errors.New("Some error")
v, _, err := g.Do("key", func() (interface{}, error) {
return nil, someErr
})
if err != someErr {
t.Errorf("Do error = %v; want someErr %v", err, someErr)
}
if v != nil {
t.Errorf("unexpected non-nil value %#v", v)
}
}
func TestDoDupSuppress(t *testing.T) {
var g Group
var wg1, wg2 sync.WaitGroup
c := make(chan string, 1)
var calls int32
fn := func() (interface{}, error) {
if atomic.AddInt32(&calls, 1) == 1 {
// First invocation.
wg1.Done()
}
v := <-c
c <- v // pump; make available for any future calls
time.Sleep(10 * time.Millisecond) // let more goroutines enter Do
return v, nil
}
const n = 10
wg1.Add(1)
for i := 0; i < n; i++ {
wg1.Add(1)
wg2.Add(1)
go func() {
defer wg2.Done()
wg1.Done()
v, _, err := g.Do("key", fn)
if err != nil {
t.Errorf("Do error: %v", err)
return
}
if s, _ := v.(string); s != "bar" {
t.Errorf("Do = %T %v; want %q", v, v, "bar")
}
}()
}
wg1.Wait()
// At least one goroutine is in fn now and all of them have at
// least reached the line before the Do.
c <- "bar"
wg2.Wait()
if got := atomic.LoadInt32(&calls); got <= 0 || got >= n {
t.Errorf("number of calls = %d; want over 0 and less than %d", got, n)
}
}

View file

@ -0,0 +1,199 @@
package templates // import "github.com/pomerium/pomerium/internal/templates"
import (
"html/template"
)
// New loads html and style resources directly. Panics on failure.
func New() *template.Template {
t := template.New("authenticate-templates")
template.Must(t.Parse(`
{{define "header.html"}}
<meta name="viewport" content="width=device-width, initial-scale=1, maximum-scale=1, user-scalable=no">
<style>
* {
margin: 0;
padding: 0;
}
body {
font-family: "Helvetica Neue",Helvetica,Arial,sans-serif;
font-size: 1em;
line-height: 1.42857143;
color: #333;
background: #f0f0f0;
}
p {
margin: 1.5em 0;
}
p:first-child {
margin-top: 0;
}
p:last-child {
margin-bottom: 0;
}
.container {
max-width: 40em;
display: block;
margin: 10% auto;
text-align: center;
}
.content, .message, button {
border: 1px solid rgba(0,0,0,.125);
border-bottom-width: 4px;
border-radius: 4px;
}
.content, .message {
background-color: #fff;
padding: 2rem;
margin: 1rem 0;
}
.error, .message {
border-bottom-color: #c00;
}
.message {
padding: 1.5rem 2rem 1.3rem;
}
header {
border-bottom: 1px solid rgba(0,0,0,.075);
margin: -2rem 0 2rem;
padding: 2rem 0 1.8rem;
}
header h1 {
font-size: 1.5em;
font-weight: normal;
}
.error header {
color: #c00;
}
.details {
font-size: .85rem;
color: #999;
}
button {
color: #fff;
background-color: #3B8686;
cursor: pointer;
font-size: 1.5rem;
font-weight: bold;
padding: 1rem 2.5rem;
text-shadow: 0 3px 1px rgba(0,0,0,.2);
outline: none;
}
button:active {
border-top-width: 4px;
border-bottom-width: 1px;
text-shadow: none;
}
footer {
font-size: 0.75em;
color: #999;
text-align: right;
margin: 1rem;
}
</style>
{{end}}`))
t = template.Must(t.Parse(`{{define "footer.html"}}Secured by <b>pomerium</b> {{end}}`))
t = template.Must(t.Parse(`
{{define "sign_in_message.html"}}
{{if eq (len .EmailDomains) 1}}
{{if eq (index .EmailDomains 0) "@*"}}
<p>You may sign in with any {{.ProviderName}} account.</p>
{{else}}
<p>You may sign in with your <b>{{index .EmailDomains 0}}</b> {{.ProviderName}} account.</p>
{{end}}
{{else if gt (len .EmailDomains) 1}}
<p>
You may sign in with any of these {{.ProviderName}} accounts:<br>
{{range $i, $e := .EmailDomains}}{{if $i}}, {{end}}<b>{{$e}}</b>{{end}}
</p>
{{end}}
{{end}}`))
t = template.Must(t.Parse(`
{{define "sign_in.html"}}
<!DOCTYPE html>
<html lang="en" charset="utf-8">
<head>
<title>Sign In</title>
{{template "header.html"}}
</head>
<body>
<div class="container">
<div class="content">
<header>
<h1>Sign in to <b>{{.Destination}}</b></h1>
</header>
{{template "sign_in_message.html" .}}
<form method="GET" action="/start">
<input type="hidden" name="redirect_uri" value="{{.Redirect}}">
<button type="submit" class="btn">Sign in with {{.ProviderName}}</button>
</form>
</div>
<footer>{{template "footer.html"}} </br> {{.Version}} </footer>
</div>
</body>
</html>
{{end}}`))
template.Must(t.Parse(`
{{define "error.html"}}
<!DOCTYPE html>
<html lang="en" charset="utf-8">
<head>
<title>Error</title>
{{template "header.html"}}
</head>
<body>
<div class="container">
<div class="content error">
<header>
<h1>{{.Title}}</h1>
</header>
<p>
{{.Message}}<br>
<span class="details">HTTP {{.Code}}</span>
</p>
</div>
<footer>{{template "footer.html"}} </br> {{.Version}} </footer>
</div>
</body>
</html>{{end}}`))
t = template.Must(t.Parse(`
{{define "sign_out.html"}}
<!DOCTYPE html>
<html lang="en" charset="utf-8">
<head>
<title>Sign Out</title>
{{template "header.html"}}
</head>
<body>
<div class="container">
{{ if .Message }}
<div class="message">{{.Message}}</div>
{{ end}}
<div class="content">
<header>
<h1>Sign out of <b>{{.Destination}}</b></h1>
</header>
<p>You're currently signed in as <b>{{.Email}}</b>. This will also sign you out of other internal apps.</p>
<form method="POST" action="/sign_out">
<input type="hidden" name="redirect_uri" value="{{.Redirect}}">
<input type="hidden" name="sig" value="{{.Signature}}">
<input type="hidden" name="ts" value="{{.Timestamp}}">
<button type="submit">Sign out</button>
</form>
</div>
<footer>{{template "footer.html"}} </br> {{.Version}}</footer>
</div>
</body>
</html>
{{end}}`))
return t
}

View file

@ -0,0 +1,12 @@
package templates // import "github.com/pomerium/pomerium/internal/templates"
import (
"testing"
"github.com/pomerium/pomerium/internal/testutil"
)
func TestTemplatesCompile(t *testing.T) {
templates := New()
testutil.NotEqual(t, templates, nil)
}

View file

@ -0,0 +1,46 @@
package testutil // import "github.com/pomerium/pomerium/internal/testutil"
// testing util functions copied from https://github.com/benbjohnson/testing
import (
"fmt"
"path/filepath"
"reflect"
"runtime"
"testing"
)
// Assert fails the test if the condition is false.
func Assert(tb testing.TB, condition bool, msg string, v ...interface{}) {
if !condition {
_, file, line, _ := runtime.Caller(1)
fmt.Printf("\033[31m%s:%d: "+msg+"\033[39m\n\n", append([]interface{}{filepath.Base(file), line}, v...)...)
tb.FailNow()
}
}
// Ok fails the test if an err is not nil.
func Ok(tb testing.TB, err error) {
if err != nil {
_, file, line, _ := runtime.Caller(1)
fmt.Printf("\033[31m%s:%d: unexpected error: %s\033[39m\n\n", filepath.Base(file), line, err.Error())
tb.FailNow()
}
}
// Equal fails the test if exp is not equal to act.
func Equal(tb testing.TB, exp, act interface{}) {
if !reflect.DeepEqual(exp, act) {
_, file, line, _ := runtime.Caller(1)
fmt.Printf("\033[31m%s:%d:\n\n\texp: %#v\n\n\tgot: %#v\033[39m\n\n", filepath.Base(file), line, exp, act)
tb.FailNow()
}
}
// NotEqual fails the test if exp is equal to act.
func NotEqual(tb testing.TB, exp, act interface{}) {
if reflect.DeepEqual(exp, act) {
_, file, line, _ := runtime.Caller(1)
fmt.Printf("\033[31m%s:%d:\n\n\texp: %#v\n\n\tgot: %#v\033[39m\n\n", filepath.Base(file), line, exp, act)
tb.FailNow()
}
}

View file

@ -0,0 +1,42 @@
package version // import "github.com/pomerium/pomerium/internal/version"
import (
"fmt"
"runtime"
"strings"
)
var (
// ProjectName is the canonical project name set by ldl flags
ProjectName = ""
// ProjectURL is the canonical project url set by ldl flags
ProjectURL = ""
// Version specifies Semantic versioning increment (MAJOR.MINOR.PATCH).
Version = "v0.0.0"
// GitCommit specifies the git commit sha, set by the compiler.
GitCommit = ""
// BuildMeta specifies release type (dev,rc1,beta,etc)
BuildMeta = ""
runtimeVersion = runtime.Version()
)
// FullVersion returns a version string.
func FullVersion() string {
var sb strings.Builder
sb.Grow(len(Version) + len(GitCommit) + len(BuildMeta) + len("-") + len("+"))
sb.WriteString(Version)
if BuildMeta != "" {
sb.WriteString("-" + BuildMeta)
}
if GitCommit != "" {
sb.WriteString("+" + GitCommit)
}
return sb.String()
}
// UserAgent returns a user-agent string as specified in RFC 2616:14.43
// https://tools.ietf.org/html/rfc2616
func UserAgent() string {
return fmt.Sprintf("%s/%s (+%s; %s; %s)", ProjectName, Version, ProjectURL, GitCommit, runtimeVersion)
}

View file

@ -0,0 +1,71 @@
package version // import "github.com/pomerium/pomerium/internal/version"
import (
"fmt"
"runtime"
"testing"
)
func TestFullVersionVersion(t *testing.T) {
tests := []struct {
Version string
GitCommit string
BuildMeta string
expected string
}{
{"", "", "", ""},
{"1.0.0", "", "", "1.0.0"},
{"1.0.0", "314501b", "", "1.0.0+314501b"},
{"1.0.0", "314501b", "dev", "1.0.0-dev+314501b"},
}
for _, tt := range tests {
Version = tt.Version
GitCommit = tt.GitCommit
BuildMeta = tt.BuildMeta
if got := FullVersion(); got != tt.expected {
t.Errorf("expected (%s) got (%s) for Version(%s), GitCommit(%s) BuildMeta(%s)",
tt.expected,
got,
tt.Version,
tt.GitCommit,
tt.BuildMeta)
}
}
}
func BenchmarkFullVersion(b *testing.B) {
Version = "1.0.0"
GitCommit = "314501b"
BuildMeta = "dev"
for i := 0; i < b.N; i++ {
FullVersion()
}
}
func TestUserAgent(t *testing.T) {
tests := []struct {
name string
Version string
GitCommit string
BuildMeta string
ProjectName string
ProjectURL string
want string
}{
{"good user agent", "1.0.0", "314501b", "dev", "pomerium", "github.com/pomerium", fmt.Sprintf("pomerium/1.0.0 (+github.com/pomerium; 314501b; %s)", runtime.Version())},
}
for _, tt := range tests {
Version = tt.Version
GitCommit = tt.GitCommit
BuildMeta = tt.BuildMeta
ProjectName = tt.ProjectName
ProjectURL = tt.ProjectURL
t.Run(tt.name, func(t *testing.T) {
if got := UserAgent(); got != tt.want {
t.Errorf("UserAgent() = %v, want %v", got, tt.want)
}
})
}
}

View file

@ -0,0 +1,317 @@
package authenticator // import "github.com/pomerium/pomerium/proxy/authenticator"
import (
"bytes"
"crypto/hmac"
"crypto/sha256"
"encoding/base64"
"encoding/json"
"errors"
"fmt"
"io"
"io/ioutil"
"net"
"net/http"
"net/url"
"strings"
"time"
"github.com/pomerium/pomerium/internal/log"
"github.com/pomerium/pomerium/internal/sessions"
"github.com/pomerium/pomerium/internal/version"
)
var defaultHTTPClient = &http.Client{
Timeout: time.Second * 5,
Transport: &http.Transport{
Dial: (&net.Dialer{
Timeout: 2 * time.Second,
}).Dial,
TLSHandshakeTimeout: 2 * time.Second,
},
}
// Errors
var (
ErrMissingRefreshToken = errors.New("missing refresh token")
ErrAuthProviderUnavailable = errors.New("auth provider unavailable")
)
// AuthenticateClient holds the data associated with the AuthenticateClients
// necessary to implement a AuthenticateClient interface.
type AuthenticateClient struct {
AuthenticateServiceURL *url.URL
//
ClientID string
ClientSecret string
SignInURL *url.URL
SignOutURL *url.URL
RedeemURL *url.URL
RefreshURL *url.URL
ProfileURL *url.URL
ValidateURL *url.URL
SessionValidTTL time.Duration
SessionLifetimeTTL time.Duration
GracePeriodTTL time.Duration
}
// NewAuthenticateClient instantiates a new AuthenticateClient with provider data
func NewAuthenticateClient(uri *url.URL, clientID, clientSecret string, sessionValid, sessionLifetime, gracePeriod time.Duration) *AuthenticateClient {
return &AuthenticateClient{
AuthenticateServiceURL: uri,
ClientID: clientID,
ClientSecret: clientSecret,
SignInURL: uri.ResolveReference(&url.URL{Path: "/sign_in"}),
SignOutURL: uri.ResolveReference(&url.URL{Path: "/sign_out"}),
RedeemURL: uri.ResolveReference(&url.URL{Path: "/redeem"}),
RefreshURL: uri.ResolveReference(&url.URL{Path: "/refresh"}),
ValidateURL: uri.ResolveReference(&url.URL{Path: "/validate"}),
ProfileURL: uri.ResolveReference(&url.URL{Path: "/profile"}),
SessionValidTTL: sessionValid,
SessionLifetimeTTL: sessionLifetime,
GracePeriodTTL: gracePeriod,
}
}
func (p *AuthenticateClient) newRequest(method, url string, body io.Reader) (*http.Request, error) {
req, err := http.NewRequest(method, url, body)
if err != nil {
return nil, err
}
req.Header.Set("User-Agent", version.UserAgent())
req.Header.Set("Accept", "application/json")
req.Host = p.AuthenticateServiceURL.Host
return req, nil
}
func isProviderUnavailable(statusCode int) bool {
return statusCode == http.StatusTooManyRequests || statusCode == http.StatusServiceUnavailable
}
func extendDeadline(ttl time.Duration) time.Time {
return time.Now().Add(ttl).Truncate(time.Second)
}
func (p *AuthenticateClient) withinGracePeriod(s *sessions.SessionState) bool {
if s.GracePeriodStart.IsZero() {
s.GracePeriodStart = time.Now()
}
return s.GracePeriodStart.Add(p.GracePeriodTTL).After(time.Now())
}
// Redeem takes a redirectURL and code and redeems the SessionState
func (p *AuthenticateClient) Redeem(redirectURL, code string) (*sessions.SessionState, error) {
if code == "" {
return nil, errors.New("missing code")
}
params := url.Values{}
// params that are validates by the authenticate service middleware
params.Add("client_id", p.ClientID)
params.Add("client_secret", p.ClientSecret)
params.Add("code", code)
req, err := p.newRequest("POST", p.RedeemURL.String(), bytes.NewBufferString(params.Encode()))
if err != nil {
return nil, err
}
req.Header.Set("Content-Type", "application/x-www-form-urlencoded")
resp, err := defaultHTTPClient.Do(req)
if err != nil {
return nil, err
}
body, err := ioutil.ReadAll(resp.Body)
resp.Body.Close()
if err != nil {
return nil, err
}
if resp.StatusCode != 200 {
if isProviderUnavailable(resp.StatusCode) {
return nil, ErrAuthProviderUnavailable
}
return nil, fmt.Errorf("got %d from %q %s", resp.StatusCode, p.RedeemURL.String(), body)
}
var jsonResponse struct {
AccessToken string `json:"access_token"`
IDToken string `json:"id_token"`
RefreshToken string `json:"refresh_token"`
ExpiresIn int64 `json:"expires_in"`
Email string `json:"email"`
}
err = json.Unmarshal(body, &jsonResponse)
if err != nil {
return nil, err
}
user := strings.Split(jsonResponse.Email, "@")[0]
return &sessions.SessionState{
AccessToken: jsonResponse.AccessToken,
RefreshToken: jsonResponse.RefreshToken,
IDToken: jsonResponse.IDToken,
RefreshDeadline: extendDeadline(time.Duration(jsonResponse.ExpiresIn) * time.Second),
LifetimeDeadline: extendDeadline(p.SessionLifetimeTTL),
ValidDeadline: extendDeadline(p.SessionValidTTL),
Email: jsonResponse.Email,
User: user,
}, nil
}
// RefreshSession refreshes the current session
func (p *AuthenticateClient) RefreshSession(s *sessions.SessionState) (bool, error) {
if s.RefreshToken == "" {
return false, ErrMissingRefreshToken
}
newToken, duration, err := p.redeemRefreshToken(s.RefreshToken)
if err != nil {
// When we detect that the auth provider is not explicitly denying
// authentication, and is merely unavailable, we refresh and continue
// as normal during the "grace period"
if err == ErrAuthProviderUnavailable && p.withinGracePeriod(s) {
s.RefreshDeadline = extendDeadline(p.SessionValidTTL)
return true, nil
}
return false, err
}
s.AccessToken = newToken
s.RefreshDeadline = extendDeadline(duration)
s.GracePeriodStart = time.Time{}
log.Info().Str("user", s.Email).Msg("proxy/authenticator.RefreshSession")
return true, nil
}
func (p *AuthenticateClient) redeemRefreshToken(refreshToken string) (token string, expires time.Duration, err error) {
params := url.Values{}
params.Add("client_id", p.ClientID)
params.Add("client_secret", p.ClientSecret)
params.Add("refresh_token", refreshToken)
var req *http.Request
req, err = p.newRequest("POST", p.RefreshURL.String(), bytes.NewBufferString(params.Encode()))
if err != nil {
return
}
req.Header.Set("Content-Type", "application/x-www-form-urlencoded")
resp, err := defaultHTTPClient.Do(req)
if err != nil {
return
}
var body []byte
body, err = ioutil.ReadAll(resp.Body)
resp.Body.Close()
if err != nil {
return
}
if resp.StatusCode != http.StatusCreated {
if isProviderUnavailable(resp.StatusCode) {
err = ErrAuthProviderUnavailable
} else {
err = fmt.Errorf("got %d from %q %s", resp.StatusCode, p.RefreshURL.String(), body)
}
return
}
var data struct {
AccessToken string `json:"access_token"`
ExpiresIn int64 `json:"expires_in"`
}
err = json.Unmarshal(body, &data)
if err != nil {
return
}
token = data.AccessToken
expires = time.Duration(data.ExpiresIn) * time.Second
return
}
// ValidateSessionState validates the current sessions state
func (p *AuthenticateClient) ValidateSessionState(s *sessions.SessionState) bool {
// we validate the user's access token is valid
params := url.Values{}
params.Add("client_id", p.ClientID)
req, err := p.newRequest("GET", fmt.Sprintf("%s?%s", p.ValidateURL.String(), params.Encode()), nil)
if err != nil {
log.Error().Err(err).Str("user", s.Email).Msg("proxy/authenticator.ValidateSessionState : error validating session state")
return false
}
req.Header.Set("X-Client-Secret", p.ClientSecret)
req.Header.Set("X-Access-Token", s.AccessToken)
req.Header.Set("X-Id-Token", s.IDToken)
resp, err := defaultHTTPClient.Do(req)
if err != nil {
log.Error().Err(err).Str("user", s.Email).Msg("proxy/authenticator.ValidateSessionState : error making request to validate access token")
return false
}
if resp.StatusCode != http.StatusOK {
// When we detect that the auth provider is not explicitly denying
// authentication, and is merely unavailable, we validate and continue
// as normal during the "grace period"
if isProviderUnavailable(resp.StatusCode) && p.withinGracePeriod(s) {
//tags := []string{"action:validate_session", "error:validation_failed"}
s.ValidDeadline = extendDeadline(p.SessionValidTTL)
return true
}
log.Info().Str("user", s.Email).Int("status-code", resp.StatusCode).Msg("proxy/authenticator.ValidateSessionState : could not validate user access token")
return false
}
s.ValidDeadline = extendDeadline(p.SessionValidTTL)
s.GracePeriodStart = time.Time{}
log.Info().Str("user", s.Email).Msg("proxy/authenticator.ValidateSessionState : validated session")
return true
}
// signRedirectURL signs the redirect url string, given a timestamp, and returns it
func (p *AuthenticateClient) signRedirectURL(rawRedirect string, timestamp time.Time) string {
h := hmac.New(sha256.New, []byte(p.ClientSecret))
h.Write([]byte(rawRedirect))
h.Write([]byte(fmt.Sprint(timestamp.Unix())))
return base64.URLEncoding.EncodeToString(h.Sum(nil))
}
// GetSignInURL with typical oauth parameters
func (p *AuthenticateClient) GetSignInURL(redirectURL *url.URL, state string) *url.URL {
a := *p.SignInURL
now := time.Now()
rawRedirect := redirectURL.String()
params, _ := url.ParseQuery(a.RawQuery)
params.Set("redirect_uri", rawRedirect)
params.Set("client_id", p.ClientID)
params.Set("response_type", "code")
params.Add("state", state)
params.Set("ts", fmt.Sprint(now.Unix()))
params.Set("sig", p.signRedirectURL(rawRedirect, now))
a.RawQuery = params.Encode()
return &a
}
// GetSignOutURL creates and returns the sign out URL, given a redirectURL
func (p *AuthenticateClient) GetSignOutURL(redirectURL *url.URL) *url.URL {
a := *p.SignOutURL
now := time.Now()
rawRedirect := redirectURL.String()
params, _ := url.ParseQuery(a.RawQuery)
params.Add("redirect_uri", rawRedirect)
params.Set("ts", fmt.Sprint(now.Unix()))
params.Set("sig", p.signRedirectURL(rawRedirect, now))
a.RawQuery = params.Encode()
return &a
}

520
proxy/handlers.go Normal file
View file

@ -0,0 +1,520 @@
package proxy // import "github.com/pomerium/pomerium/proxy"
import (
"encoding/json"
"errors"
"fmt"
"net/http"
"net/url"
"reflect"
"strings"
"github.com/pomerium/pomerium/internal/aead"
"github.com/pomerium/pomerium/internal/httputil"
"github.com/pomerium/pomerium/internal/log"
"github.com/pomerium/pomerium/internal/middleware"
"github.com/pomerium/pomerium/internal/sessions"
"github.com/pomerium/pomerium/internal/version"
)
const loggingUserHeader = "SSO-Authenticated-User"
var (
//ErrUserNotAuthorized is set when user is not authorized to access a resource
ErrUserNotAuthorized = errors.New("user not authorized")
)
var securityHeaders = map[string]string{
"X-Content-Type-Options": "nosniff",
"X-Frame-Options": "SAMEORIGIN",
"X-XSS-Protection": "1; mode=block",
}
// Handler returns a http handler for an Proxy
func (p *Proxy) Handler() http.Handler {
mux := http.NewServeMux()
mux.HandleFunc("/favicon.ico", p.Favicon)
mux.HandleFunc("/robots.txt", p.RobotsTxt)
mux.HandleFunc("/.pomerium/sign_out", p.SignOut)
mux.HandleFunc("/.pomerium/callback", p.OAuthCallback)
mux.HandleFunc("/.pomerium/auth", p.AuthenticateOnly)
mux.HandleFunc("/", p.Proxy)
// Global middleware, which will be applied to each request in reverse
// order as applied here (i.e., we want to validate the host _first_ when
// processing a request)
var handler http.Handler = mux
// todo(bdd) : investigate if setting non-overridable headers makes sense
// handler = p.setResponseHeaderOverrides(handler)
handler = middleware.SetHeaders(handler, securityHeaders)
handler = middleware.ValidateHost(handler, p.mux)
handler = middleware.RequireHTTPS(handler)
handler = log.NewLoggingHandler(handler)
return http.HandlerFunc(func(rw http.ResponseWriter, req *http.Request) {
// Skip host validation for /ping requests because they hit the LB directly.
if req.URL.Path == "/ping" {
p.PingPage(rw, req)
return
}
handler.ServeHTTP(rw, req)
})
}
// RobotsTxt sets the User-Agent header in the response to be "Disallow"
func (p *Proxy) RobotsTxt(rw http.ResponseWriter, _ *http.Request) {
rw.WriteHeader(http.StatusOK)
fmt.Fprintf(rw, "User-agent: *\nDisallow: /")
}
// Favicon will proxy the request as usual if the user is already authenticated
// but responds with a 404 otherwise, to avoid spurious and confusing
// authentication attempts when a browser automatically requests the favicon on
// an error page.
func (p *Proxy) Favicon(rw http.ResponseWriter, req *http.Request) {
err := p.Authenticate(rw, req)
if err != nil {
rw.WriteHeader(http.StatusNotFound)
return
}
p.Proxy(rw, req)
}
// PingPage send back a 200 OK response.
func (p *Proxy) PingPage(rw http.ResponseWriter, _ *http.Request) {
rw.WriteHeader(http.StatusOK)
fmt.Fprintf(rw, "OK")
}
// SignOut redirects the request to the sign out url.
func (p *Proxy) SignOut(rw http.ResponseWriter, req *http.Request) {
p.sessionStore.ClearSession(rw, req)
var scheme string
// Build redirect URI from request host
if req.URL.Scheme == "" {
scheme = "https"
}
redirectURL := &url.URL{
Scheme: scheme,
Host: req.Host,
Path: "/",
}
fullURL := p.authenticateClient.GetSignOutURL(redirectURL)
http.Redirect(rw, req, fullURL.String(), http.StatusFound)
}
// XHRError returns a simple error response with an error message to the application if the request is an XML request
func (p *Proxy) XHRError(rw http.ResponseWriter, req *http.Request, code int, err error) {
jsonError := struct {
Error error `json:"error"`
}{
Error: err,
}
jsonBytes, err := json.Marshal(jsonError)
if err != nil {
rw.WriteHeader(http.StatusInternalServerError)
return
}
requestLog := log.WithRequest(req, "proxy.ErrorPage")
requestLog.Error().Err(err).Int("http-status", code).Msg("proxy.XHRError")
rw.Header().Set("Content-Type", "application/json")
rw.WriteHeader(code)
rw.Write(jsonBytes)
}
// ErrorPage renders an error page with a given status code, title, and message.
func (p *Proxy) ErrorPage(rw http.ResponseWriter, req *http.Request, code int, title string, message string) {
if p.isXHR(req) {
p.XHRError(rw, req, code, errors.New(message))
return
}
requestLog := log.WithRequest(req, "proxy.ErrorPage")
requestLog.Info().
Str("page-title", title).
Str("page-message", message).
Msg("proxy.ErrorPage")
rw.WriteHeader(code)
t := struct {
Code int
Title string
Message string
Version string
}{
Code: code,
Title: title,
Message: message,
Version: version.FullVersion(),
}
p.templates.ExecuteTemplate(rw, "error.html", t)
}
func (p *Proxy) isXHR(req *http.Request) bool {
return req.Header.Get("X-Requested-With") == "XMLHttpRequest"
}
// OAuthStart begins the authentication flow, encrypting the redirect url
// in a request to the provider's sign in endpoint.
func (p *Proxy) OAuthStart(rw http.ResponseWriter, req *http.Request) {
// The proxy redirects to the authenticator, and provides it with redirectURI (which points
// back to the sso proxy).
requestLog := log.WithRequest(req, "proxy.OAuthStart")
if p.isXHR(req) {
e := errors.New("cannot continue oauth flow on xhr")
requestLog.Error().Err(e).Msg("isXHR")
p.XHRError(rw, req, http.StatusUnauthorized, e)
return
}
requestURI := req.URL.String()
callbackURL := p.GetRedirectURL(req.Host)
// generate nonce
key := aead.GenerateKey()
// state prevents cross site forgery and maintain state across the client and server
state := &StateParameter{
SessionID: fmt.Sprintf("%x", key), // nonce
RedirectURI: requestURI, // where to redirect the user back to
}
// we encrypt this value to be opaque the browser cookie
// this value will be unique since we always use a randomized nonce as part of marshaling
encryptedCSRF, err := p.CookieCipher.Marshal(state)
if err != nil {
requestLog.Error().Err(err).Msg("failed to marshal csrf")
p.ErrorPage(rw, req, http.StatusInternalServerError, "Internal Error", err.Error())
return
}
p.csrfStore.SetCSRF(rw, req, encryptedCSRF)
// we encrypt this value to be opaque the uri query value
// this value will be unique since we always use a randomized nonce as part of marshaling
encryptedState, err := p.CookieCipher.Marshal(state)
if err != nil {
requestLog.Error().Err(err).Msg("failed to encrypt cookie")
p.ErrorPage(rw, req, http.StatusInternalServerError, "Internal Error", err.Error())
return
}
signinURL := p.authenticateClient.GetSignInURL(callbackURL, encryptedState)
requestLog.Info().Msg("redirecting to begin auth flow")
http.Redirect(rw, req, signinURL.String(), http.StatusFound)
}
// OAuthCallback validates the cookie sent back from the provider, then validates
// the user information, and if authorized, redirects the user back to the original
// application.
func (p *Proxy) OAuthCallback(rw http.ResponseWriter, req *http.Request) {
// We receive the callback from the SSO Authenticator. This request will either contain an
// error, or it will contain a `code`; the code can be used to fetch an access token, and
// other metadata, from the authenticator.
requestLog := log.WithRequest(req, "proxy.OAuthCallback")
// finish the oauth cycle
err := req.ParseForm()
if err != nil {
requestLog.Error().Err(err).Msg("failed parsing request form")
p.ErrorPage(rw, req, http.StatusInternalServerError, "Internal Error", err.Error())
return
}
errorString := req.Form.Get("error")
if errorString != "" {
p.ErrorPage(rw, req, http.StatusForbidden, "Permission Denied", errorString)
return
}
// We begin the process of redeeming the code for an access token.
session, err := p.redeemCode(req.Host, req.Form.Get("code"))
if err != nil {
requestLog.Error().Err(err).Msg("error redeeming authorization code")
p.ErrorPage(rw, req, http.StatusInternalServerError, "Internal Error", "Internal Error")
return
}
encryptedState := req.Form.Get("state")
stateParameter := &StateParameter{}
err = p.CookieCipher.Unmarshal(encryptedState, stateParameter)
if err != nil {
requestLog.Error().Err(err).Msg("could not unmarshal state")
p.ErrorPage(rw, req, http.StatusInternalServerError, "Internal Error", "Internal Error")
return
}
c, err := req.Cookie(p.CSRFCookieName)
if err != nil {
requestLog.Error().Err(err).Msg("failed parsing csrf cookie")
p.ErrorPage(rw, req, http.StatusBadRequest, "Bad Request", err.Error())
return
}
p.csrfStore.ClearCSRF(rw, req)
encryptedCSRF := c.Value
csrfParameter := &StateParameter{}
err = p.CookieCipher.Unmarshal(encryptedCSRF, csrfParameter)
if err != nil {
requestLog.Error().Err(err).Msg("couldn't unmarshal CSRF")
p.ErrorPage(rw, req, http.StatusInternalServerError, "Internal Error", "Internal Error")
return
}
if encryptedState == encryptedCSRF {
requestLog.Error().Msg("encrypted state and CSRF should not be equal")
p.ErrorPage(rw, req, http.StatusBadRequest, "Bad Request", "Bad Request")
return
}
if !reflect.DeepEqual(stateParameter, csrfParameter) {
requestLog.Error().Msg("state and CSRF should be equal")
p.ErrorPage(rw, req, http.StatusBadRequest, "Bad Request", "Bad Request")
return
}
// We validate the user information, and check that this user has proper authorization
// for the resources requested. This can be set via the email address or any groups.
//
// set cookie, or deny
if !p.EmailValidator(session.Email) {
requestLog.Error().Str("user", session.Email).Msg("permission denied: unauthorized")
p.ErrorPage(rw, req, http.StatusForbidden, "Permission Denied", "Invalid Account")
return
}
// We store the session in a cookie and redirect the user back to the application
err = p.sessionStore.SaveSession(rw, req, session)
if err != nil {
requestLog.Error().Msg("error saving session")
p.ErrorPage(rw, req, http.StatusInternalServerError, "Internal Error", "Internal Error")
return
}
// This is the redirect back to the original requested application
http.Redirect(rw, req, stateParameter.RedirectURI, http.StatusFound)
}
// AuthenticateOnly calls the Authenticate handler.
func (p *Proxy) AuthenticateOnly(rw http.ResponseWriter, req *http.Request) {
err := p.Authenticate(rw, req)
if err != nil {
http.Error(rw, "unauthorized request", http.StatusUnauthorized)
}
rw.WriteHeader(http.StatusAccepted)
}
// Proxy authenticates a request, either proxying the request if it is authenticated, or starting the authentication process if not.
func (p *Proxy) Proxy(rw http.ResponseWriter, req *http.Request) {
// Attempts to validate the user and their cookie.
// start := time.Now()
var err error
err = p.Authenticate(rw, req)
// If the authentication is not successful we proceed to start the OAuth Flow with
// OAuthStart. If authentication is successful, we proceed to proxy to the configured
// upstream.
requestLog := log.WithRequest(req, "proxy.Proxy")
if err != nil {
switch err {
case http.ErrNoCookie:
// No cookie is set, start the oauth flow
p.OAuthStart(rw, req)
return
case ErrUserNotAuthorized:
// We know the user is not authorized for the request, we show them a forbidden page
p.ErrorPage(rw, req, http.StatusForbidden, "Forbidden", "You're not authorized to view this page")
return
case sessions.ErrLifetimeExpired:
// User's lifetime expired, we trigger the start of the oauth flow
p.OAuthStart(rw, req)
return
case sessions.ErrInvalidSession:
// The user session is invalid and we can't decode it.
// This can happen for a variety of reasons but the most common non-malicious
// case occurs when the session encoding schema changes. We manage this ux
// by triggering the start of the oauth flow.
p.OAuthStart(rw, req)
return
default:
requestLog.Error().Err(err).Msg("unknown error")
// We don't know exactly what happened, but authenticating the user failed, show an error
p.ErrorPage(rw, req, http.StatusInternalServerError, "Internal Error", "An unexpected error occurred")
return
}
}
// We have validated the users request and now proxy their request to the provided upstream.
route, ok := p.router(req)
if !ok {
httputil.ErrorResponse(rw, req, "Unknown host to route", http.StatusNotFound)
return
}
// overhead := time.Now().Sub(start)
route.ServeHTTP(rw, req)
}
// Authenticate authenticates a request by checking for a session cookie, and validating its expiration,
// clearing the session cookie if it's invalid and returning an error if necessary..
func (p *Proxy) Authenticate(rw http.ResponseWriter, req *http.Request) (err error) {
// Clear the session cookie if anything goes wrong.
defer func() {
if err != nil {
p.sessionStore.ClearSession(rw, req)
}
}()
requestLog := log.WithRequest(req, "proxy.Authenticate")
session, err := p.sessionStore.LoadSession(req)
if err != nil {
// We loaded a cookie but it wasn't valid, clear it, and reject the request
requestLog.Error().Err(err).Msg("error authenticating user")
return err
}
// Lifetime period is the entire duration in which the session is valid.
// This should be set to something like 14 to 30 days.
if session.LifetimePeriodExpired() {
requestLog.Warn().Str("user", session.Email).Msg("session lifetime has expired")
return sessions.ErrLifetimeExpired
} else if session.RefreshPeriodExpired() {
// Refresh period is the period in which the access token is valid. This is ultimately
// controlled by the upstream provider and tends to be around 1 hour.
ok, err := p.authenticateClient.RefreshSession(session)
// We failed to refresh the session successfully
// clear the cookie and reject the request
if err != nil {
requestLog.Error().Err(err).Str("user", session.Email).Msg("refreshing session failed")
return err
}
if !ok {
// User is not authorized after refresh
// clear the cookie and reject the request
requestLog.Error().Str("user", session.Email).Msg("not authorized after refreshing session")
return ErrUserNotAuthorized
}
err = p.sessionStore.SaveSession(rw, req, session)
if err != nil {
// We refreshed the session successfully, but failed to save it.
//
// This could be from failing to encode the session properly.
// But, we clear the session cookie and reject the request!
requestLog.Error().Err(err).Str("user", session.Email).Msg("could not save refresh session")
return err
}
} else if session.ValidationPeriodExpired() {
// Validation period has expired, this is the shortest interval we use to
// check for valid requests. This should be set to something like a minute.
// This calls up the provider chain to validate this user is still active
// and hasn't been de-authorized.
ok := p.authenticateClient.ValidateSessionState(session)
if !ok {
// This user is now no longer authorized, or we failed to
// validate the user.
// Clear the cookie and reject the request
requestLog.Error().Str("user", session.Email).Msg("no longer authorized after validation period")
return ErrUserNotAuthorized
}
err = p.sessionStore.SaveSession(rw, req, session)
if err != nil {
// We validated the session successfully, but failed to save it.
// This could be from failing to encode the session properly.
// But, we clear the session cookie and reject the request!
requestLog.Error().Err(err).Str("user", session.Email).Msg("could not save validated session")
return err
}
}
if !p.EmailValidator(session.Email) {
requestLog.Error().Str("user", session.Email).Msg("email failed to validate, unauthorized")
return ErrUserNotAuthorized
}
req.Header.Set("X-Forwarded-User", session.User)
if p.PassAccessToken && session.AccessToken != "" {
req.Header.Set("X-Forwarded-Access-Token", session.AccessToken)
}
req.Header.Set("X-Forwarded-Email", session.Email)
req.Header.Set("X-Forwarded-Groups", strings.Join(session.Groups, ","))
// stash authenticated user so that it can be logged later (see func logRequest)
rw.Header().Set(loggingUserHeader, session.Email)
// This user has been OK'd. Allow the request!
return nil
}
// upstreamTransport is used to ensure that upstreams cannot override the
// security headers applied by sso_proxy
type upstreamTransport struct {
transport *http.Transport
}
// RoundTrip round trips the request and deletes security headers before returning the response.
func (t *upstreamTransport) RoundTrip(req *http.Request) (*http.Response, error) {
resp, err := t.transport.RoundTrip(req)
if err != nil {
log.Error().Err(err).Msg("proxy.RoundTrip")
return nil, err
}
for key := range securityHeaders {
resp.Header.Del(key)
}
return resp, err
}
// Handle constructs a route from the given host string and matches it to the provided http.Handler and UpstreamConfig
func (p *Proxy) Handle(host string, handler http.Handler) {
p.mux[host] = &handler
}
// router attempts to find a route for a request. If a route is successfully matched,
// it returns the route information and a bool value of `true`. If a route can not be matched,
//a nil value for the route and false bool value is returned.
func (p *Proxy) router(req *http.Request) (http.Handler, bool) {
route, ok := p.mux[req.Host]
if ok {
return *route, true
}
return nil, false
}
// GetRedirectURL returns the redirect url for a given Proxy,
// setting the scheme to be https if CookieSecure is true.
func (p *Proxy) GetRedirectURL(host string) *url.URL {
// TODO: Ensure that we only allow valid upstream hosts in redirect URIs
u := p.redirectURL
// Build redirect URI from request host
if u.Scheme == "" {
u.Scheme = "https"
}
u.Host = host
return u
}
func (p *Proxy) redeemCode(host, code string) (*sessions.SessionState, error) {
if code == "" {
return nil, errors.New("missing code")
}
redirectURL := p.GetRedirectURL(host)
s, err := p.authenticateClient.Redeem(redirectURL.String(), code)
if err != nil {
return s, err
}
if s.Email == "" {
return s, errors.New("invalid email address")
}
return s, nil
}

331
proxy/proxy.go Executable file
View file

@ -0,0 +1,331 @@
package proxy // import "github.com/pomerium/pomerium/proxy"
import (
"encoding/base64"
"errors"
"fmt"
"html/template"
"net"
"net/http"
"net/http/httputil"
"net/url"
"strings"
"time"
"github.com/pomerium/envconfig"
"github.com/pomerium/pomerium/internal/aead"
"github.com/pomerium/pomerium/internal/log"
"github.com/pomerium/pomerium/internal/sessions"
"github.com/pomerium/pomerium/internal/templates"
"github.com/pomerium/pomerium/proxy/authenticator"
)
// Options represents the configuration options for the proxy service.
type Options struct {
// AuthenticateServiceURL specifies the url to the pomerium authenticate http service.
AuthenticateServiceURL *url.URL `envconfig:"PROVIDER_URL"`
// EmailDomains is a string slice of valid domains to proxy
EmailDomains []string `envconfig:"EMAIL_DOMAIN"`
// todo(bdd): ClientID and ClientSecret are used are a hacky pre shared key
// prefer certificates and mutual tls
ClientID string `envconfig:"PROXY_CLIENT_ID"`
ClientSecret string `envconfig:"PROXY_CLIENT_SECRET"`
DefaultUpstreamTimeout time.Duration `envconfig:"DEFAULT_UPSTREAM_TIMEOUT"`
CookieName string `envconfig:"COOKIE_NAME"`
CookieSecret string `envconfig:"COOKIE_SECRET"`
CookieDomain string `envconfig:"COOKIE_DOMAIN"`
CookieExpire time.Duration `envconfig:"COOKIE_EXPIRE"`
CookieSecure bool `envconfig:"COOKIE_SECURE" `
CookieHTTPOnly bool `envconfig:"COOKIE_HTTP_ONLY"`
PassAccessToken bool `envconfig:"PASS_ACCESS_TOKEN"`
// session details
SessionValidTTL time.Duration `envconfig:"SESSION_VALID_TTL"`
SessionLifetimeTTL time.Duration `envconfig:"SESSION_LIFETIME_TTL"`
GracePeriodTTL time.Duration `envconfig:"GRACE_PERIOD_TTL"`
Routes map[string]string `envconfig:"ROUTES"`
}
// NewOptions returns a new options struct
var defaultOptions = &Options{
CookieName: "_pomerium_proxy",
CookieSecure: true,
CookieHTTPOnly: true,
CookieExpire: time.Duration(168) * time.Hour,
DefaultUpstreamTimeout: time.Duration(10) * time.Second,
SessionLifetimeTTL: time.Duration(720) * time.Hour,
SessionValidTTL: time.Duration(1) * time.Minute,
GracePeriodTTL: time.Duration(3) * time.Hour,
PassAccessToken: false,
}
// OptionsFromEnvConfig builds the authentication service's configuration
// options from provided environmental variables
func OptionsFromEnvConfig() (*Options, error) {
o := defaultOptions
if err := envconfig.Process("", o); err != nil {
return nil, err
}
return o, nil
}
// Validate checks that proper configuration settings are set to create
// a proper Proxy instance
func (o *Options) Validate() error {
if len(o.Routes) == 0 {
return errors.New("missing setting: routes")
}
for to, from := range o.Routes {
if _, err := urlParse(to); err != nil {
return fmt.Errorf("could not parse origin %s as url : %q", to, err)
}
if _, err := urlParse(from); err != nil {
return fmt.Errorf("could not parse destination %s as url : %q", to, err)
}
}
if o.AuthenticateServiceURL == nil {
return errors.New("missing setting: provider-url")
}
if o.CookieSecret == "" {
return errors.New("missing setting: cookie-secret")
}
if o.ClientID == "" {
return errors.New("missing setting: client-id")
}
if o.ClientSecret == "" {
return errors.New("missing setting: client-secret")
}
if len(o.EmailDomains) == 0 {
return errors.New("missing setting: email-domain")
}
decodedCookieSecret, err := base64.StdEncoding.DecodeString(o.CookieSecret)
if err != nil {
return errors.New("cookie secret is invalid (e.g. `head -c33 /dev/urandom | base64`) ")
}
validCookieSecretLength := false
for _, i := range []int{32, 64} {
if len(decodedCookieSecret) == i {
validCookieSecretLength = true
}
}
if !validCookieSecretLength {
return fmt.Errorf("cookie secret is invalid, must be 32 or 64 bytes but got %d bytes (e.g. `head -c33 /dev/urandom | base64`) ", len(decodedCookieSecret))
}
return nil
}
// Proxy stores all the information associated with proxying a request.
type Proxy struct {
CookieCipher aead.Cipher
CookieDomain string
CookieExpire time.Duration
CookieHTTPOnly bool
CookieName string
CookieSecure bool
CookieSeed string
CSRFCookieName string
EmailValidator func(string) bool
PassAccessToken bool
// services
authenticateClient *authenticator.AuthenticateClient
// session
csrfStore sessions.CSRFStore
sessionStore sessions.SessionStore
cipher aead.Cipher
redirectURL *url.URL // the url to receive requests at
templates *template.Template
mux map[string]*http.Handler
}
// StateParameter holds the redirect id along with the session id.
type StateParameter struct {
SessionID string `json:"session_id"`
RedirectURI string `json:"redirect_uri"`
}
// NewProxy takes a Proxy service from options and a validation function.
// Function returns an error if options fail to validate.
func NewProxy(opts *Options, optFuncs ...func(*Proxy) error) (*Proxy, error) {
if opts == nil {
return nil, errors.New("options cannot be nil")
}
if err := opts.Validate(); err != nil {
return nil, err
}
// error explicitly handled by validate
decodedSecret, _ := base64.StdEncoding.DecodeString(opts.CookieSecret)
cipher, err := aead.NewMiscreantCipher(decodedSecret)
if err != nil {
return nil, fmt.Errorf("cookie-secret error: %s", err.Error())
}
cookieStore, err := sessions.NewCookieStore(opts.CookieName,
sessions.CreateMiscreantCookieCipher(decodedSecret),
func(c *sessions.CookieStore) error {
c.CookieDomain = opts.CookieDomain
c.CookieHTTPOnly = opts.CookieHTTPOnly
c.CookieExpire = opts.CookieExpire
c.CookieSecure = opts.CookieSecure
return nil
})
if err != nil {
return nil, err
}
authClient := authenticator.NewAuthenticateClient(
opts.AuthenticateServiceURL,
// todo(bdd): fields below can be dropped as Client data provides redudent auth
opts.ClientID,
opts.ClientSecret,
// todo(bdd): fields below should be passed as function args
opts.SessionLifetimeTTL,
opts.SessionValidTTL,
opts.GracePeriodTTL,
)
p := &Proxy{
CookieCipher: cipher,
CookieDomain: opts.CookieDomain,
CookieExpire: opts.CookieExpire,
CookieHTTPOnly: opts.CookieHTTPOnly,
CookieName: opts.CookieName,
CookieSecure: opts.CookieSecure,
CookieSeed: string(decodedSecret),
CSRFCookieName: fmt.Sprintf("%v_%v", opts.CookieName, "csrf"),
// these fields make up the routing mechanism
mux: make(map[string]*http.Handler),
// session state
csrfStore: cookieStore,
sessionStore: cookieStore,
cipher: cipher,
authenticateClient: authClient,
redirectURL: &url.URL{Path: "/.pomerium/callback"},
templates: templates.New(),
PassAccessToken: opts.PassAccessToken,
}
for _, optFunc := range optFuncs {
err := optFunc(p)
if err != nil {
return nil, err
}
}
for from, to := range opts.Routes {
fromURL, _ := urlParse(from)
toURL, _ := urlParse(to)
reverseProxy := NewReverseProxy(toURL)
handler := NewReverseProxyHandler(opts, reverseProxy, toURL.String())
p.Handle(fromURL.Host, handler)
log.Info().Str("from", fromURL.Host).Str("to", toURL.String()).Msg("proxy.NewProxy : route")
}
log.Info().
Str("CookieName", p.CookieName).
Str("redirectURL", p.redirectURL.String()).
Str("CSRFCookieName", p.CSRFCookieName).
Bool("CookieSecure", p.CookieSecure).
Str("CookieDomain", p.CookieDomain).
Bool("CookieHTTPOnly", p.CookieHTTPOnly).
Dur("CookieExpire", opts.CookieExpire).
Msg("proxy.NewProxy")
return p, nil
}
// UpstreamProxy stores information necessary for proxying the request back to the upstream.
type UpstreamProxy struct {
name string
cookieName string
handler http.Handler
}
var defaultUpstreamTransport = &http.Transport{
Proxy: http.ProxyFromEnvironment,
DialContext: (&net.Dialer{
Timeout: 30 * time.Second,
KeepAlive: 30 * time.Second,
DualStack: true,
}).DialContext,
MaxIdleConns: 100,
IdleConnTimeout: 90 * time.Second,
TLSHandshakeTimeout: 10 * time.Second,
ExpectContinueTimeout: 1 * time.Second,
}
// deleteSSOCookieHeader deletes the session cookie from the request header string.
func deleteSSOCookieHeader(req *http.Request, cookieName string) {
headers := []string{}
for _, cookie := range req.Cookies() {
if cookie.Name != cookieName {
headers = append(headers, cookie.String())
}
}
req.Header.Set("Cookie", strings.Join(headers, ";"))
}
// ServeHTTP signs the http request and deletes cookie headers
// before calling the upstream's ServeHTTP function.
func (u *UpstreamProxy) ServeHTTP(w http.ResponseWriter, r *http.Request) {
requestLog := log.WithRequest(r, "proxy.ServeHTTP")
deleteSSOCookieHeader(r, u.cookieName)
start := time.Now()
u.handler.ServeHTTP(w, r)
duration := time.Since(start)
requestLog.Debug().Dur("duration", duration).Msg("proxy-request")
}
// NewReverseProxy creates a reverse proxy to a specified url.
// It adds an X-Forwarded-Host header that is the request's host.
//
// todo(bdd): when would we ever want to preserve host?
func NewReverseProxy(to *url.URL) *httputil.ReverseProxy {
proxy := httputil.NewSingleHostReverseProxy(to)
proxy.Transport = defaultUpstreamTransport
director := proxy.Director
proxy.Director = func(req *http.Request) {
req.Header.Add("X-Forwarded-Host", req.Host)
director(req)
req.Host = to.Host
}
return proxy
}
// NewReverseProxyHandler applies handler specific options to a given
// route.
func NewReverseProxyHandler(opts *Options, reverseProxy *httputil.ReverseProxy, serviceName string) http.Handler {
upstreamProxy := &UpstreamProxy{
name: serviceName,
handler: reverseProxy,
cookieName: opts.CookieName,
}
timeout := opts.DefaultUpstreamTimeout
timeoutMsg := fmt.Sprintf("%s failed to respond within the %s timeout period", serviceName, timeout)
return http.TimeoutHandler(upstreamProxy, timeout, timeoutMsg)
}
// urlParse adds a scheme if none-exists, addressesing a quirk in how
// one may expect url.Parse to function when a "naked" domain is sent.
//
// see: https://github.com/golang/go/issues/12585
// see: https://golang.org/pkg/net/url/#Parse
func urlParse(uri string) (*url.URL, error) {
if !strings.Contains(uri, "://") {
uri = fmt.Sprintf("https://%s", uri)
}
return url.Parse(uri)
}

220
proxy/proxy_test.go Normal file
View file

@ -0,0 +1,220 @@
package proxy // import "github.com/pomerium/pomerium/proxy"
import (
"io/ioutil"
"net"
"net/http"
"net/http/httptest"
"net/url"
"os"
"reflect"
"testing"
)
func TestOptionsFromEnvConfig(t *testing.T) {
tests := []struct {
name string
want *Options
envKey string
envValue string
wantErr bool
}{
{"good default, no env settings", defaultOptions, "", "", false},
{"bad url", nil, "PROVIDER_URL", "%.rjlw", true},
{"good duration", defaultOptions, "SESSION_VALID_TTL", "1m", false},
{"bad duration", nil, "SESSION_VALID_TTL", "1sm", true},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
if tt.envKey != "" {
os.Setenv(tt.envKey, tt.envValue)
defer os.Unsetenv(tt.envKey)
}
got, err := OptionsFromEnvConfig()
if (err != nil) != tt.wantErr {
t.Errorf("OptionsFromEnvConfig() error = %v, wantErr %v", err, tt.wantErr)
return
}
if !reflect.DeepEqual(got, tt.want) {
t.Errorf("OptionsFromEnvConfig() = %v, want %v", got, tt.want)
}
})
}
}
func Test_urlParse(t *testing.T) {
tests := []struct {
name string
uri string
want *url.URL
wantErr bool
}{
{"good url without schema", "accounts.google.com", &url.URL{Scheme: "https", Host: "accounts.google.com"}, false},
{"good url with schema", "https://accounts.google.com", &url.URL{Scheme: "https", Host: "accounts.google.com"}, false},
{"bad url, malformed", "https://accounts.google.^", nil, true},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
got, err := urlParse(tt.uri)
if (err != nil) != tt.wantErr {
t.Errorf("urlParse() error = %v, wantErr %v", err, tt.wantErr)
return
}
if !reflect.DeepEqual(got, tt.want) {
t.Errorf("urlParse() = %v, want %v", got, tt.want)
}
})
}
}
func TestNewReverseProxy(t *testing.T) {
backend := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.WriteHeader(200)
hostname, _, _ := net.SplitHostPort(r.Host)
w.Write([]byte(hostname))
}))
defer backend.Close()
backendURL, _ := url.Parse(backend.URL)
backendHostname, backendPort, _ := net.SplitHostPort(backendURL.Host)
backendHost := net.JoinHostPort(backendHostname, backendPort)
proxyURL, _ := url.Parse(backendURL.Scheme + "://" + backendHost + "/")
proxyHandler := NewReverseProxy(proxyURL)
frontend := httptest.NewServer(proxyHandler)
defer frontend.Close()
getReq, _ := http.NewRequest("GET", frontend.URL, nil)
res, _ := http.DefaultClient.Do(getReq)
bodyBytes, _ := ioutil.ReadAll(res.Body)
if g, e := string(bodyBytes), backendHostname; g != e {
t.Errorf("got body %q; expected %q", g, e)
}
}
func TestNewReverseProxyHandler(t *testing.T) {
backend := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.WriteHeader(200)
hostname, _, _ := net.SplitHostPort(r.Host)
w.Write([]byte(hostname))
}))
defer backend.Close()
backendURL, _ := url.Parse(backend.URL)
backendHostname, backendPort, _ := net.SplitHostPort(backendURL.Host)
backendHost := net.JoinHostPort(backendHostname, backendPort)
proxyURL, _ := url.Parse(backendURL.Scheme + "://" + backendHost + "/")
proxyHandler := NewReverseProxy(proxyURL)
opts := defaultOptions
handle := NewReverseProxyHandler(opts, proxyHandler, "name")
frontend := httptest.NewServer(handle)
defer frontend.Close()
getReq, _ := http.NewRequest("GET", frontend.URL, nil)
res, _ := http.DefaultClient.Do(getReq)
bodyBytes, _ := ioutil.ReadAll(res.Body)
if g, e := string(bodyBytes), backendHostname; g != e {
t.Errorf("got body %q; expected %q", g, e)
}
}
func testOptions() *Options {
authurl, _ := url.Parse("https://sso-auth.corp.beyondperimeter.com")
return &Options{
Routes: map[string]string{"corp.example.com": "example.com"},
AuthenticateServiceURL: authurl,
ClientID: "yksYDhIM7PZTvdFP3Mi3sYt2JXooTi7y0oIClBR46fs=",
ClientSecret: "80ldlrU2d7w+wVpKNfevk6fmb8otEx6CqOfshj2LwhQ=",
EmailDomains: []string{"*"},
CookieSecret: "OromP1gurwGWjQPYb1nNgSxtbVB5NnLzX6z5WOKr0Yw=",
}
}
func TestOptions_Validate(t *testing.T) {
good := testOptions()
badFromRoute := testOptions()
badFromRoute.Routes = map[string]string{"example.com": "^"}
badToRoute := testOptions()
badToRoute.Routes = map[string]string{"^": "example.com"}
badAuthURL := testOptions()
badAuthURL.AuthenticateServiceURL = nil
emptyCookieSecret := testOptions()
emptyCookieSecret.CookieSecret = ""
invalidCookieSecret := testOptions()
invalidCookieSecret.CookieSecret = "OromP1gurwGWjQPYb1nNgSxtbVB5NnLzX6z5WOKr0Yw^"
shortCookieLength := testOptions()
shortCookieLength.CookieSecret = "gN3xnvfsAwfCXxnJorGLKUG4l2wC8sS8nfLMhcStPg==" //head -c31 /dev/urandom | base64
badClientID := testOptions()
badClientID.ClientID = ""
badClientSecret := testOptions()
badClientSecret.ClientSecret = ""
badEmailDomain := testOptions()
badEmailDomain.EmailDomains = nil
tests := []struct {
name string
o *Options
wantErr bool
}{
{"good - minimum options", good, false},
{"bad - nil options", &Options{}, true},
{"bad - from route", badFromRoute, true},
{"bad - to route", badToRoute, true},
{"bad - auth service url", badAuthURL, true},
{"bad - no cookie secret", emptyCookieSecret, true},
{"bad - invalid cookie secret", invalidCookieSecret, true},
{"bad - short cookie secret", shortCookieLength, true},
{"bad - no client id", badClientID, true},
{"bad - no client secret", badClientSecret, true},
{"bad - no email domain", badEmailDomain, true},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
o := tt.o
if err := o.Validate(); (err != nil) != tt.wantErr {
t.Errorf("Options.Validate() error = %v, wantErr %v", err, tt.wantErr)
}
})
}
}
func TestNewProxy(t *testing.T) {
good := testOptions()
shortCookieLength := testOptions()
shortCookieLength.CookieSecret = "gN3xnvfsAwfCXxnJorGLKUG4l2wC8sS8nfLMhcStPg==" //head -c31 /dev/urandom | base64
tests := []struct {
name string
opts *Options
optFuncs []func(*Proxy) error
wantProxy bool
numMuxes int
wantErr bool
}{
{"good - minimum options", good, nil, true, 1, false},
{"bad - empty options", &Options{}, nil, false, 0, true},
{"bad - nil options", nil, nil, false, 0, true},
{"bad - short secret/validate sanity check", shortCookieLength, nil, false, 0, true},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
got, err := NewProxy(tt.opts, tt.optFuncs...)
if (err != nil) != tt.wantErr {
t.Errorf("NewProxy() error = %v, wantErr %v", err, tt.wantErr)
return
}
if got == nil && tt.wantProxy == true {
t.Errorf("NewProxy() expected valid proxy struct")
}
if got != nil && len(got.mux) != tt.numMuxes {
t.Errorf("NewProxy() = num muxes %v, want %v", got, tt.numMuxes)
}
})
}
}

8
proxy/testdata/public_key.pub vendored Normal file
View file

@ -0,0 +1,8 @@
-----BEGIN RSA PUBLIC KEY-----
MIIBCgKCAQEAst/CEAh/EMnjRbNcwNF7iMqp03En2GYNJz3wfiv/6Rcu7SDgMJke
rYfDcpK8RYREAxyjQpi17eI/FRQx0GbRo1AR0ZgF2VvDTkNBCNb3Pw6bdPbFONCU
JV2WXi/vf+4gMRH0hN00K9ZOz18MaY5va7C0p+xaC5713KNJnOvndo48X+HDICSG
kCyjne/NylEMy1RLwUCdOSZ6SNsTI0tKt95bTEzBhd0GUDfYuG2SoJyLaJisUyW3
8X7TtdRUzSwe6IPeLFppU4QGOf1DI2WlmCdYPPfllCfiqVWMibBzwQZGkBvjWGs3
Cw8iKMKcydVlZCJ8rLIaU6sE/lD1eGrfowIDAQAB
-----END RSA PUBLIC KEY-----

10
proxy/testdata/upstream_configs.yml vendored Normal file
View file

@ -0,0 +1,10 @@
- service: foo
default:
from: foo.{{cluster}}.{{root_domain}}
to: foo-internal.{{cluster}}.{{root_domain}}
options:
allowed_groups:
- dev
dev:
from: foo.{{root_domain}}