mirror of
https://github.com/pomerium/pomerium.git
synced 2025-05-10 15:47:36 +02:00
initial release
This commit is contained in:
commit
d56c889224
62 changed files with 8229 additions and 0 deletions
96
.gitignore
vendored
Normal file
96
.gitignore
vendored
Normal file
|
@ -0,0 +1,96 @@
|
|||
|
||||
pem
|
||||
env
|
||||
coverage.txt
|
||||
*.pem
|
||||
# Compiled Object files, Static and Dynamic libs (Shared Objects)
|
||||
*.o
|
||||
*.a
|
||||
*.so
|
||||
|
||||
# Folders
|
||||
_obj
|
||||
_test
|
||||
.cover
|
||||
|
||||
# Architecture specific extensions/prefixes
|
||||
*.[568vq]
|
||||
[568vq].out
|
||||
|
||||
*.cgo1.go
|
||||
*.cgo2.c
|
||||
_cgo_defun.c
|
||||
_cgo_gotypes.go
|
||||
_cgo_export.*
|
||||
|
||||
_testmain.go
|
||||
|
||||
*.exe
|
||||
*.test
|
||||
*.prof
|
||||
|
||||
# Other dirs
|
||||
/bin/
|
||||
/pkg/
|
||||
|
||||
# Without this, the *.[568vq] above ignores this folder.
|
||||
!**/graphrbac/1.6
|
||||
|
||||
# Ruby
|
||||
website/vendor
|
||||
website/.bundle
|
||||
website/build
|
||||
website/tmp
|
||||
|
||||
# Vagrant
|
||||
.vagrant/
|
||||
Vagrantfile
|
||||
|
||||
# Configs
|
||||
*.hcl
|
||||
!command/agent/config/test-fixtures/config.hcl
|
||||
!command/agent/config/test-fixtures/config-embedded-type.hcl
|
||||
|
||||
.DS_Store
|
||||
.idea
|
||||
.vscode
|
||||
|
||||
dist/*
|
||||
|
||||
tags
|
||||
|
||||
# Editor backups
|
||||
*~
|
||||
*.sw[a-z]
|
||||
|
||||
# IntelliJ IDEA project files
|
||||
.idea
|
||||
*.ipr
|
||||
*.iml
|
||||
|
||||
# compiled output
|
||||
ui/dist
|
||||
ui/tmp
|
||||
ui/root
|
||||
http/bindata_assetfs.go
|
||||
|
||||
# dependencies
|
||||
ui/node_modules
|
||||
ui/bower_components
|
||||
|
||||
# misc
|
||||
ui/.DS_Store
|
||||
ui/.sass-cache
|
||||
ui/connect.lock
|
||||
ui/coverage/*
|
||||
ui/libpeerconnection.log
|
||||
ui/npm-debug.log
|
||||
ui/testem.log
|
||||
|
||||
# used for JS acceptance tests
|
||||
ui/tests/helpers/vault-keys.js
|
||||
ui/vault-ui-integration-server.pid
|
||||
|
||||
# for building static assets
|
||||
node_modules
|
||||
package-lock.json
|
18
.travis.yml
Normal file
18
.travis.yml
Normal file
|
@ -0,0 +1,18 @@
|
|||
---
|
||||
language: go
|
||||
go:
|
||||
- 1.x
|
||||
- tip
|
||||
matrix:
|
||||
allow_failures:
|
||||
- go: tip
|
||||
fast_finish: true
|
||||
install:
|
||||
- go get github.com/golang/lint/golint
|
||||
- go get honnef.co/go/tools/cmd/staticcheck
|
||||
script:
|
||||
- env GO111MODULE=on make all
|
||||
- env GO111MODULE=on make cover
|
||||
- env GO111MODULE=on make release
|
||||
# after_success:
|
||||
# - bash <(curl -s https://codecov.io/bash)
|
88
3RD-PARTY
Normal file
88
3RD-PARTY
Normal file
|
@ -0,0 +1,88 @@
|
|||
Third Party Licenses
|
||||
|
||||
Go
|
||||
SPDX-License-Identifier: BSD-3-Clause
|
||||
https://golang.org/LICENSE
|
||||
|
||||
Copyright (c) 2009 The Go Authors. All rights reserved.
|
||||
|
||||
Redistribution and use in source and binary forms, with or without
|
||||
modification, are permitted provided that the following conditions are
|
||||
met:
|
||||
|
||||
* Redistributions of source code must retain the above copyright
|
||||
notice, this list of conditions and the following disclaimer.
|
||||
* Redistributions in binary form must reproduce the above
|
||||
copyright notice, this list of conditions and the following disclaimer
|
||||
in the documentation and/or other materials provided with the
|
||||
distribution.
|
||||
* Neither the name of Google Inc. nor the names of its
|
||||
contributors may be used to endorse or promote products derived from
|
||||
this software without specific prior written permission.
|
||||
|
||||
THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
|
||||
"AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
|
||||
LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
|
||||
A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT
|
||||
OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL,
|
||||
SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT
|
||||
LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE,
|
||||
DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY
|
||||
THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
|
||||
(INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
||||
OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
||||
|
||||
Upspin
|
||||
SPDX-License-Identifier: BSD-3-Clause
|
||||
https://github.com/upspin/upspin/blob/master/LICENSE
|
||||
|
||||
Copyright (c) 2016 The Upspin Authors. All rights reserved.
|
||||
|
||||
Redistribution and use in source and binary forms, with or without
|
||||
modification, are permitted provided that the following conditions are
|
||||
met:
|
||||
|
||||
* Redistributions of source code must retain the above copyright
|
||||
notice, this list of conditions and the following disclaimer.
|
||||
* Redistributions in binary form must reproduce the above
|
||||
copyright notice, this list of conditions and the following disclaimer
|
||||
in the documentation and/or other materials provided with the
|
||||
distribution.
|
||||
* Neither the name of Google Inc. nor the names of its
|
||||
contributors may be used to endorse or promote products derived from
|
||||
this software without specific prior written permission.
|
||||
|
||||
THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
|
||||
"AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
|
||||
LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
|
||||
A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT
|
||||
OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL,
|
||||
SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT
|
||||
LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE,
|
||||
DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY
|
||||
THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
|
||||
(INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
||||
OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
||||
|
||||
buzzfeed/sso (fork of bitly/oauth2_proxy)
|
||||
SPDX-License-Identifier: MIT
|
||||
https://github.com/buzzfeed/sso/blob/master/LICENSE
|
||||
https://github.com/bitly/oauth2_proxy/blob/master/LICENSE
|
||||
|
||||
Permission is hereby granted, free of charge, to any person obtaining a copy
|
||||
of this software and associated documentation files (the "Software"), to deal
|
||||
in the Software without restriction, including without limitation the rights
|
||||
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
|
||||
copies of the Software, and to permit persons to whom the Software is
|
||||
furnished to do so, subject to the following conditions:
|
||||
|
||||
The above copyright notice and this permission notice shall be included in
|
||||
all copies or substantial portions of the Software.
|
||||
|
||||
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
|
||||
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
|
||||
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
|
||||
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
|
||||
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
|
||||
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
|
||||
THE SOFTWARE.
|
14
Dockerfile
Normal file
14
Dockerfile
Normal file
|
@ -0,0 +1,14 @@
|
|||
FROM golang:alpine as build
|
||||
RUN apk --update --no-cache add ca-certificates git make
|
||||
ENV CGO_ENABLED=0
|
||||
ENV GO111MODULE=on
|
||||
|
||||
WORKDIR /go/src/github.com/pomerium/pomerium
|
||||
COPY . .
|
||||
|
||||
RUN make
|
||||
|
||||
FROM scratch
|
||||
COPY --from=build /etc/ssl/certs/ca-certificates.crt /etc/ssl/certs/
|
||||
WORKDIR /pomerium
|
||||
COPY --from=build /bin/* /bin/
|
201
LICENSE
Normal file
201
LICENSE
Normal file
|
@ -0,0 +1,201 @@
|
|||
Apache License
|
||||
Version 2.0, January 2004
|
||||
http://www.apache.org/licenses/
|
||||
|
||||
TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
|
||||
|
||||
1. Definitions.
|
||||
|
||||
"License" shall mean the terms and conditions for use, reproduction,
|
||||
and distribution as defined by Sections 1 through 9 of this document.
|
||||
|
||||
"Licensor" shall mean the copyright owner or entity authorized by
|
||||
the copyright owner that is granting the License.
|
||||
|
||||
"Legal Entity" shall mean the union of the acting entity and all
|
||||
other entities that control, are controlled by, or are under common
|
||||
control with that entity. For the purposes of this definition,
|
||||
"control" means (i) the power, direct or indirect, to cause the
|
||||
direction or management of such entity, whether by contract or
|
||||
otherwise, or (ii) ownership of fifty percent (50%) or more of the
|
||||
outstanding shares, or (iii) beneficial ownership of such entity.
|
||||
|
||||
"You" (or "Your") shall mean an individual or Legal Entity
|
||||
exercising permissions granted by this License.
|
||||
|
||||
"Source" form shall mean the preferred form for making modifications,
|
||||
including but not limited to software source code, documentation
|
||||
source, and configuration files.
|
||||
|
||||
"Object" form shall mean any form resulting from mechanical
|
||||
transformation or translation of a Source form, including but
|
||||
not limited to compiled object code, generated documentation,
|
||||
and conversions to other media types.
|
||||
|
||||
"Work" shall mean the work of authorship, whether in Source or
|
||||
Object form, made available under the License, as indicated by a
|
||||
copyright notice that is included in or attached to the work
|
||||
(an example is provided in the Appendix below).
|
||||
|
||||
"Derivative Works" shall mean any work, whether in Source or Object
|
||||
form, that is based on (or derived from) the Work and for which the
|
||||
editorial revisions, annotations, elaborations, or other modifications
|
||||
represent, as a whole, an original work of authorship. For the purposes
|
||||
of this License, Derivative Works shall not include works that remain
|
||||
separable from, or merely link (or bind by name) to the interfaces of,
|
||||
the Work and Derivative Works thereof.
|
||||
|
||||
"Contribution" shall mean any work of authorship, including
|
||||
the original version of the Work and any modifications or additions
|
||||
to that Work or Derivative Works thereof, that is intentionally
|
||||
submitted to Licensor for inclusion in the Work by the copyright owner
|
||||
or by an individual or Legal Entity authorized to submit on behalf of
|
||||
the copyright owner. For the purposes of this definition, "submitted"
|
||||
means any form of electronic, verbal, or written communication sent
|
||||
to the Licensor or its representatives, including but not limited to
|
||||
communication on electronic mailing lists, source code control systems,
|
||||
and issue tracking systems that are managed by, or on behalf of, the
|
||||
Licensor for the purpose of discussing and improving the Work, but
|
||||
excluding communication that is conspicuously marked or otherwise
|
||||
designated in writing by the copyright owner as "Not a Contribution."
|
||||
|
||||
"Contributor" shall mean Licensor and any individual or Legal Entity
|
||||
on behalf of whom a Contribution has been received by Licensor and
|
||||
subsequently incorporated within the Work.
|
||||
|
||||
2. Grant of Copyright License. Subject to the terms and conditions of
|
||||
this License, each Contributor hereby grants to You a perpetual,
|
||||
worldwide, non-exclusive, no-charge, royalty-free, irrevocable
|
||||
copyright license to reproduce, prepare Derivative Works of,
|
||||
publicly display, publicly perform, sublicense, and distribute the
|
||||
Work and such Derivative Works in Source or Object form.
|
||||
|
||||
3. Grant of Patent License. Subject to the terms and conditions of
|
||||
this License, each Contributor hereby grants to You a perpetual,
|
||||
worldwide, non-exclusive, no-charge, royalty-free, irrevocable
|
||||
(except as stated in this section) patent license to make, have made,
|
||||
use, offer to sell, sell, import, and otherwise transfer the Work,
|
||||
where such license applies only to those patent claims licensable
|
||||
by such Contributor that are necessarily infringed by their
|
||||
Contribution(s) alone or by combination of their Contribution(s)
|
||||
with the Work to which such Contribution(s) was submitted. If You
|
||||
institute patent litigation against any entity (including a
|
||||
cross-claim or counterclaim in a lawsuit) alleging that the Work
|
||||
or a Contribution incorporated within the Work constitutes direct
|
||||
or contributory patent infringement, then any patent licenses
|
||||
granted to You under this License for that Work shall terminate
|
||||
as of the date such litigation is filed.
|
||||
|
||||
4. Redistribution. You may reproduce and distribute copies of the
|
||||
Work or Derivative Works thereof in any medium, with or without
|
||||
modifications, and in Source or Object form, provided that You
|
||||
meet the following conditions:
|
||||
|
||||
(a) You must give any other recipients of the Work or
|
||||
Derivative Works a copy of this License; and
|
||||
|
||||
(b) You must cause any modified files to carry prominent notices
|
||||
stating that You changed the files; and
|
||||
|
||||
(c) You must retain, in the Source form of any Derivative Works
|
||||
that You distribute, all copyright, patent, trademark, and
|
||||
attribution notices from the Source form of the Work,
|
||||
excluding those notices that do not pertain to any part of
|
||||
the Derivative Works; and
|
||||
|
||||
(d) If the Work includes a "NOTICE" text file as part of its
|
||||
distribution, then any Derivative Works that You distribute must
|
||||
include a readable copy of the attribution notices contained
|
||||
within such NOTICE file, excluding those notices that do not
|
||||
pertain to any part of the Derivative Works, in at least one
|
||||
of the following places: within a NOTICE text file distributed
|
||||
as part of the Derivative Works; within the Source form or
|
||||
documentation, if provided along with the Derivative Works; or,
|
||||
within a display generated by the Derivative Works, if and
|
||||
wherever such third-party notices normally appear. The contents
|
||||
of the NOTICE file are for informational purposes only and
|
||||
do not modify the License. You may add Your own attribution
|
||||
notices within Derivative Works that You distribute, alongside
|
||||
or as an addendum to the NOTICE text from the Work, provided
|
||||
that such additional attribution notices cannot be construed
|
||||
as modifying the License.
|
||||
|
||||
You may add Your own copyright statement to Your modifications and
|
||||
may provide additional or different license terms and conditions
|
||||
for use, reproduction, or distribution of Your modifications, or
|
||||
for any such Derivative Works as a whole, provided Your use,
|
||||
reproduction, and distribution of the Work otherwise complies with
|
||||
the conditions stated in this License.
|
||||
|
||||
5. Submission of Contributions. Unless You explicitly state otherwise,
|
||||
any Contribution intentionally submitted for inclusion in the Work
|
||||
by You to the Licensor shall be under the terms and conditions of
|
||||
this License, without any additional terms or conditions.
|
||||
Notwithstanding the above, nothing herein shall supersede or modify
|
||||
the terms of any separate license agreement you may have executed
|
||||
with Licensor regarding such Contributions.
|
||||
|
||||
6. Trademarks. This License does not grant permission to use the trade
|
||||
names, trademarks, service marks, or product names of the Licensor,
|
||||
except as required for reasonable and customary use in describing the
|
||||
origin of the Work and reproducing the content of the NOTICE file.
|
||||
|
||||
7. Disclaimer of Warranty. Unless required by applicable law or
|
||||
agreed to in writing, Licensor provides the Work (and each
|
||||
Contributor provides its Contributions) on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
|
||||
implied, including, without limitation, any warranties or conditions
|
||||
of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
|
||||
PARTICULAR PURPOSE. You are solely responsible for determining the
|
||||
appropriateness of using or redistributing the Work and assume any
|
||||
risks associated with Your exercise of permissions under this License.
|
||||
|
||||
8. Limitation of Liability. In no event and under no legal theory,
|
||||
whether in tort (including negligence), contract, or otherwise,
|
||||
unless required by applicable law (such as deliberate and grossly
|
||||
negligent acts) or agreed to in writing, shall any Contributor be
|
||||
liable to You for damages, including any direct, indirect, special,
|
||||
incidental, or consequential damages of any character arising as a
|
||||
result of this License or out of the use or inability to use the
|
||||
Work (including but not limited to damages for loss of goodwill,
|
||||
work stoppage, computer failure or malfunction, or any and all
|
||||
other commercial damages or losses), even if such Contributor
|
||||
has been advised of the possibility of such damages.
|
||||
|
||||
9. Accepting Warranty or Additional Liability. While redistributing
|
||||
the Work or Derivative Works thereof, You may choose to offer,
|
||||
and charge a fee for, acceptance of support, warranty, indemnity,
|
||||
or other liability obligations and/or rights consistent with this
|
||||
License. However, in accepting such obligations, You may act only
|
||||
on Your own behalf and on Your sole responsibility, not on behalf
|
||||
of any other Contributor, and only if You agree to indemnify,
|
||||
defend, and hold each Contributor harmless for any liability
|
||||
incurred by, or claims asserted against, such Contributor by reason
|
||||
of your accepting any such warranty or additional liability.
|
||||
|
||||
END OF TERMS AND CONDITIONS
|
||||
|
||||
APPENDIX: How to apply the Apache License to your work.
|
||||
|
||||
To apply the Apache License to your work, attach the following
|
||||
boilerplate notice, with the fields enclosed by brackets "[]"
|
||||
replaced with your own identifying information. (Don't include
|
||||
the brackets!) The text should be enclosed in the appropriate
|
||||
comment syntax for the file format. We also recommend that a
|
||||
file or class name and description of purpose be included on the
|
||||
same "printed page" as the copyright notice for easier
|
||||
identification within third-party archives.
|
||||
|
||||
Copyright 2019 Pomerium
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
113
Makefile
Normal file
113
Makefile
Normal file
|
@ -0,0 +1,113 @@
|
|||
# Setup name variables for the package/tool
|
||||
PREFIX?=$(shell pwd)
|
||||
|
||||
NAME := pomerium
|
||||
PKG := github.com/pomerium/$(NAME)
|
||||
|
||||
BUILDDIR := ${PREFIX}/dist
|
||||
BINDIR := ${PREFIX}/bin
|
||||
GO111MODULE=on
|
||||
CGO_ENABLED := 0
|
||||
# Set any default go build tags
|
||||
BUILDTAGS :=
|
||||
|
||||
# Populate version variables
|
||||
# Add to compile time flags
|
||||
VERSION := $(shell cat VERSION)
|
||||
GITCOMMIT := $(shell git rev-parse --short HEAD)
|
||||
GITUNTRACKEDCHANGES := $(shell git status --porcelain --untracked-files=no)
|
||||
BUILDMETA:=
|
||||
ifneq ($(GITUNTRACKEDCHANGES),)
|
||||
BUILDMETA := dirty
|
||||
endif
|
||||
CTIMEVAR=-X $(PKG)/internal/version.GitCommit=$(GITCOMMIT) \
|
||||
-X $(PKG)/internal/version.Version=$(VERSION) \
|
||||
-X $(PKG)/internal/version.BuildMeta=$(BUILDMETA) \
|
||||
-X $(PKG)/internal/version.ProjectName=$(NAME) \
|
||||
-X $(PKG)/internal/version.ProjectURL=$(PKG)
|
||||
GO_LDFLAGS=-ldflags "-w $(CTIMEVAR)"
|
||||
GOOSARCHES = linux/amd64 darwin/amd64 windows/amd64
|
||||
|
||||
|
||||
.PHONY: all
|
||||
all: clean build fmt lint vet test ## Runs a clean, build, fmt, lint, test, and vet.
|
||||
|
||||
|
||||
.PHONY: build
|
||||
build: ## Builds dynamic executables and/or packages.
|
||||
@echo "==> $@"
|
||||
@CGO_ENABLED=0 GO111MODULE=on go build -tags "$(BUILDTAGS)" ${GO_LDFLAGS} -o $(BINDIR)/$(NAME) ./cmd/"$(NAME)"
|
||||
|
||||
.PHONY: fmt
|
||||
fmt: ## Verifies all files have been `gofmt`ed.
|
||||
@echo "==> $@"
|
||||
@gofmt -s -l . | grep -v '.pb.go:' | grep -v vendor | tee /dev/stderr
|
||||
|
||||
.PHONY: lint
|
||||
lint: ## Verifies `golint` passes.
|
||||
@echo "==> $@"
|
||||
@golint ./... | grep -v '.pb.go:' | grep -v vendor | tee /dev/stderr
|
||||
|
||||
.PHONY: staticcheck
|
||||
staticcheck: ## Verifies `staticcheck` passes
|
||||
@echo "+ $@"
|
||||
@staticcheck $(shell go list ./... | grep -v vendor) | grep -v '.pb.go:' | tee /dev/stderr
|
||||
|
||||
.PHONY: vet
|
||||
vet: ## Verifies `go vet` passes.
|
||||
@echo "==> $@"
|
||||
@go vet $(shell go list ./... | grep -v vendor) | grep -v '.pb.go:' | tee /dev/stderr
|
||||
|
||||
.PHONY: test
|
||||
test: ## Runs the go tests.
|
||||
@echo "==> $@"
|
||||
@go test -tags "$(BUILDTAGS)" $(shell go list ./... | grep -v vendor)
|
||||
|
||||
|
||||
.PHONY: cover
|
||||
cover: ## Runs go test with coverage
|
||||
@echo "" > coverage.txt
|
||||
@for d in $(shell go list ./... | grep -v vendor); do \
|
||||
go test -race -coverprofile=profile.out -covermode=atomic "$$d"; \
|
||||
if [ -f profile.out ]; then \
|
||||
cat profile.out >> coverage.txt; \
|
||||
rm profile.out; \
|
||||
fi; \
|
||||
done;
|
||||
|
||||
.PHONY: clean
|
||||
clean: ## Cleanup any build binaries or packages.
|
||||
@echo "==> $@"
|
||||
$(RM) -r $(BINDIR)
|
||||
$(RM) -r $(BUILDDIR)
|
||||
|
||||
define buildpretty
|
||||
mkdir -p $(BUILDDIR)/$(1)/$(2);
|
||||
GOOS=$(1) GOARCH=$(2) CGO_ENABLED=0 GO111MODULE=on go build \
|
||||
-o $(BUILDDIR)/$(1)/$(2)/$(NAME) \
|
||||
${GO_LDFLAGS_STATIC} ./cmd/$(NAME);
|
||||
md5sum $(BUILDDIR)/$(1)/$(2)/$(NAME) > $(BUILDDIR)/$(1)/$(2)/$(NAME).md5;
|
||||
sha256sum $(BUILDDIR)/$(1)/$(2)/$(NAME) > $(BUILDDIR)/$(1)/$(2)/$(NAME).sha256;
|
||||
endef
|
||||
|
||||
.PHONY: cross
|
||||
cross: ## Builds the cross-compiled binaries, creating a clean directory structure (eg. GOOS/GOARCH/binary)
|
||||
@echo "+ $@"
|
||||
$(foreach GOOSARCH,$(GOOSARCHES), $(call buildpretty,$(subst /,,$(dir $(GOOSARCH))),$(notdir $(GOOSARCH))))
|
||||
|
||||
define buildrelease
|
||||
GOOS=$(1) GOARCH=$(2) CGO_ENABLED=0 GO111MODULE=on go build \
|
||||
-o $(BUILDDIR)/$(NAME)-$(1)-$(2) \
|
||||
${GO_LDFLAGS_STATIC} ./cmd/$(NAME);
|
||||
md5sum $(BUILDDIR)/$(NAME)-$(1)-$(2) > $(BUILDDIR)/$(NAME)-$(1)-$(2).md5;
|
||||
sha256sum $(BUILDDIR)/$(NAME)-$(1)-$(2) > $(BUILDDIR)/$(NAME)-$(1)-$(2).sha256;
|
||||
endef
|
||||
|
||||
.PHONY: release
|
||||
release: ## Builds the cross-compiled binaries, naming them in such a way for release (eg. binary-GOOS-GOARCH)
|
||||
@echo "+ $@"
|
||||
$(foreach GOOSARCH,$(GOOSARCHES), $(call buildrelease,$(subst /,,$(dir $(GOOSARCH))),$(notdir $(GOOSARCH))))
|
||||
|
||||
.PHONY: help
|
||||
help:
|
||||
@grep -E '^[a-zA-Z_-]+:.*?## .*$$' $(MAKEFILE_LIST) | sort | awk 'BEGIN {FS = ":.*?## "}; {printf "\033[36m%-30s\033[0m %s\n", $$1, $$2}'
|
35
README.md
Normal file
35
README.md
Normal file
|
@ -0,0 +1,35 @@
|
|||
<img height="200" src="./docs/logo.png" alt="logo" align="right" >
|
||||
|
||||
# Pomerium : identity-aware access proxy
|
||||
[](https://travis-ci.org/pomerium/pomerium)
|
||||
[](https://goreportcard.com/report/github.com/pomerium/pomerium)
|
||||
[](https://github.com/pomerium/pomerium/blob/master/LICENSE)
|
||||
|
||||
Pomerium is a tool for managing secure access to internal applications and resources.
|
||||
|
||||
Use Pomerium to:
|
||||
|
||||
- provide a unified ingress gateway to internal corporate applications.
|
||||
- enforce dynamic access policies based on context, identity, and device state.
|
||||
- aggregate logging and telemetry data.
|
||||
|
||||
To learn more about zero-trust / BeyondCorp, check out [awesome-zero-trust].
|
||||
|
||||
## Getting started
|
||||
|
||||
For instructions on getting started with Pomerium, see our getting started docs.
|
||||
|
||||
## To start developing Pomerium
|
||||
|
||||
Assuming you have a working [Go environment].
|
||||
|
||||
```sh
|
||||
$ go get -d github.com/pomerium/pomerium
|
||||
$ cd $GOPATH/src/github.com/pomerium/pomerium
|
||||
$ make
|
||||
$ source ./env # see env.example
|
||||
$ ./bin/pomerium -debug
|
||||
```
|
||||
|
||||
[awesome-zero-trust]: https://github.com/pomerium/awesome-zero-trust
|
||||
[Go environment]: https://golang.org/doc/install
|
1
VERSION
Normal file
1
VERSION
Normal file
|
@ -0,0 +1 @@
|
|||
v0.0.1
|
254
authenticate/authenticate.go
Normal file
254
authenticate/authenticate.go
Normal file
|
@ -0,0 +1,254 @@
|
|||
package authenticate // import "github.com/pomerium/pomerium/authenticate"
|
||||
|
||||
import (
|
||||
"encoding/base64"
|
||||
"errors"
|
||||
"fmt"
|
||||
"html/template"
|
||||
"net/url"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/pomerium/envconfig"
|
||||
|
||||
"github.com/pomerium/pomerium/authenticate/providers"
|
||||
"github.com/pomerium/pomerium/internal/aead"
|
||||
"github.com/pomerium/pomerium/internal/sessions"
|
||||
"github.com/pomerium/pomerium/internal/templates"
|
||||
)
|
||||
|
||||
// Options permits the configuration of the authentication service
|
||||
type Options struct {
|
||||
// e.g.
|
||||
Host string `envconfig:"HOST"`
|
||||
//
|
||||
ProxyClientID string `envconfig:"PROXY_CLIENT_ID"`
|
||||
ProxyClientSecret string `envconfig:"PROXY_CLIENT_SECRET"`
|
||||
|
||||
// Coarse authorization based on user email domain
|
||||
EmailDomains []string `envconfig:"SSO_EMAIL_DOMAIN"`
|
||||
ProxyRootDomains []string `envconfig:"PROXY_ROOT_DOMAIN"`
|
||||
|
||||
// Session/Cookie management
|
||||
CookieName string
|
||||
CookieSecret string `envconfig:"COOKIE_SECRET"`
|
||||
CookieDomain string `envconfig:"COOKIE_DOMAIN"`
|
||||
CookieExpire time.Duration `envconfig:"COOKIE_EXPIRE" default:"168h"`
|
||||
CookieRefresh time.Duration `envconfig:"COOKIE_REFRESH" default:"1h"`
|
||||
CookieSecure bool `envconfig:"COOKIE_SECURE" default:"true"`
|
||||
CookieHTTPOnly bool `envconfig:"COOKIE_HTTP_ONLY" default:"true"`
|
||||
|
||||
AuthCodeSecret string `envconfig:"AUTH_CODE_SECRET"`
|
||||
|
||||
SessionLifetimeTTL time.Duration `envconfig:"SESSION_LIFETIME_TTL" default:"720h"`
|
||||
|
||||
// Authentication provider configuration vars
|
||||
RedirectURL *url.URL `envconfig:"IDP_REDIRECT_URL" ` // e.g. auth.example.com/oauth/callback
|
||||
ClientID string `envconfig:"IDP_CLIENT_ID"` // IdP ClientID
|
||||
ClientSecret string `envconfig:"IDP_CLIENT_SECRET"` // IdP Secret
|
||||
Provider string `envconfig:"IDP_PROVIDER"` //Provider name e.g. "oidc","okta","google",etc
|
||||
ProviderURL *url.URL `envconfig:"IDP_PROVIDER_URL"`
|
||||
Scopes []string `envconfig:"IDP_SCOPE" default:"openid,email,profile"`
|
||||
|
||||
// todo(bdd) : can delete?`
|
||||
ApprovalPrompt string `envconfig:"IDP_APPROVAL_PROMPT" default:"consent"`
|
||||
RequestLogging bool `envconfig:"REQUEST_LOGGING" default:"true"`
|
||||
RequestTimeout time.Duration `envconfig:"REQUEST_TIMEOUT" default:"2s"`
|
||||
}
|
||||
|
||||
var defaultOptions = &Options{
|
||||
EmailDomains: []string{"*"},
|
||||
CookieName: "_pomerium_authenticate",
|
||||
CookieSecure: true,
|
||||
CookieHTTPOnly: true,
|
||||
CookieExpire: time.Duration(168) * time.Hour,
|
||||
CookieRefresh: time.Duration(1) * time.Hour,
|
||||
RequestTimeout: time.Duration(2) * time.Second,
|
||||
SessionLifetimeTTL: time.Duration(720) * time.Hour,
|
||||
|
||||
ApprovalPrompt: "consent",
|
||||
Scopes: []string{"openid", "email", "profile"},
|
||||
}
|
||||
|
||||
// OptionsFromEnvConfig builds the authentication service's configuration
|
||||
// options from provided environmental variables
|
||||
func OptionsFromEnvConfig() (*Options, error) {
|
||||
o := defaultOptions
|
||||
if err := envconfig.Process("", o); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return o, nil
|
||||
}
|
||||
|
||||
// Validate checks to see if configuration values are valid for authentication service.
|
||||
// The checks do not modify the internal state of the Option structure. Function returns
|
||||
// on first error found.
|
||||
func (o *Options) Validate() error {
|
||||
if o.ProviderURL == nil {
|
||||
return errors.New("missing setting: identity provider url")
|
||||
}
|
||||
if o.RedirectURL == nil {
|
||||
return errors.New("missing setting: identity provider redirect url")
|
||||
}
|
||||
redirectPath := "/oauth2/callback"
|
||||
if o.RedirectURL.Path != redirectPath {
|
||||
return fmt.Errorf("setting redirect-url was %s path should be %s", o.RedirectURL.Path, redirectPath)
|
||||
}
|
||||
if o.ClientID == "" {
|
||||
return errors.New("missing setting: client id")
|
||||
}
|
||||
if o.ClientSecret == "" {
|
||||
return errors.New("missing setting: client secret")
|
||||
}
|
||||
if len(o.EmailDomains) == 0 {
|
||||
return errors.New("missing setting email domain")
|
||||
}
|
||||
if len(o.ProxyRootDomains) == 0 {
|
||||
return errors.New("missing setting: proxy root domain")
|
||||
}
|
||||
if o.ProxyClientID == "" {
|
||||
return errors.New("missing setting: proxy client id")
|
||||
}
|
||||
if o.ProxyClientSecret == "" {
|
||||
return errors.New("missing setting: proxy client secret")
|
||||
}
|
||||
|
||||
decodedCookieSecret, err := base64.StdEncoding.DecodeString(o.CookieSecret)
|
||||
if err != nil {
|
||||
return fmt.Errorf("authenticate options: cookie secret invalid"+
|
||||
"must be a base64-encoded, 256 bit key e.g. `head -c32 /dev/urandom | base64`"+
|
||||
"got %q", err)
|
||||
}
|
||||
|
||||
validCookieSecretLength := false
|
||||
for _, i := range []int{32, 64} {
|
||||
if len(decodedCookieSecret) == i {
|
||||
validCookieSecretLength = true
|
||||
}
|
||||
}
|
||||
|
||||
if !validCookieSecretLength {
|
||||
return fmt.Errorf("authenticate options: invalid cookie secret strength want 32 to 64 bytes, got %d bytes", len(decodedCookieSecret))
|
||||
}
|
||||
|
||||
if o.CookieRefresh >= o.CookieExpire {
|
||||
return fmt.Errorf("cookie_refresh (%s) must be less than cookie_expire (%s)", o.CookieRefresh.String(), o.CookieExpire.String())
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// Authenticator stores all the information associated with proxying the request.
|
||||
type Authenticator struct {
|
||||
Validator func(string) bool
|
||||
|
||||
EmailDomains []string
|
||||
ProxyRootDomains []string
|
||||
Host string
|
||||
CookieSecure bool
|
||||
|
||||
ProxyClientID string
|
||||
ProxyClientSecret string
|
||||
|
||||
SessionLifetimeTTL time.Duration
|
||||
|
||||
decodedCookieSecret []byte
|
||||
templates *template.Template
|
||||
// sesion related
|
||||
csrfStore sessions.CSRFStore
|
||||
sessionStore sessions.SessionStore
|
||||
cipher aead.Cipher
|
||||
|
||||
redirectURL *url.URL
|
||||
provider providers.Provider
|
||||
}
|
||||
|
||||
// NewAuthenticator creates a Authenticator struct and applies the optional functions slice to the struct.
|
||||
func NewAuthenticator(opts *Options, optionFuncs ...func(*Authenticator) error) (*Authenticator, error) {
|
||||
if opts == nil {
|
||||
return nil, errors.New("options cannot be nil")
|
||||
}
|
||||
if err := opts.Validate(); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
decodedAuthCodeSecret, err := base64.StdEncoding.DecodeString(opts.AuthCodeSecret)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
cipher, err := aead.NewMiscreantCipher([]byte(decodedAuthCodeSecret))
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
decodedCookieSecret, err := base64.StdEncoding.DecodeString(opts.CookieSecret)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
cookieStore, err := sessions.NewCookieStore(opts.CookieName,
|
||||
sessions.CreateMiscreantCookieCipher(decodedCookieSecret),
|
||||
func(c *sessions.CookieStore) error {
|
||||
c.CookieDomain = opts.CookieDomain
|
||||
c.CookieHTTPOnly = opts.CookieHTTPOnly
|
||||
c.CookieExpire = opts.CookieExpire
|
||||
c.CookieSecure = opts.CookieSecure
|
||||
return nil
|
||||
})
|
||||
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
p := &Authenticator{
|
||||
ProxyClientID: opts.ProxyClientID,
|
||||
ProxyClientSecret: opts.ProxyClientSecret,
|
||||
EmailDomains: opts.EmailDomains,
|
||||
ProxyRootDomains: dotPrependDomains(opts.ProxyRootDomains),
|
||||
CookieSecure: opts.CookieSecure,
|
||||
redirectURL: opts.RedirectURL,
|
||||
templates: templates.New(),
|
||||
csrfStore: cookieStore,
|
||||
sessionStore: cookieStore,
|
||||
cipher: cipher,
|
||||
}
|
||||
// p.ServeMux = p.Handler()
|
||||
p.provider, err = newProvider(opts)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// apply the option functions
|
||||
for _, optFunc := range optionFuncs {
|
||||
err := optFunc(p)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
}
|
||||
return p, nil
|
||||
}
|
||||
|
||||
func newProvider(opts *Options) (providers.Provider, error) {
|
||||
pd := &providers.ProviderData{
|
||||
RedirectURL: opts.RedirectURL,
|
||||
ProviderName: opts.Provider,
|
||||
ClientID: opts.ClientID,
|
||||
ClientSecret: opts.ClientSecret,
|
||||
ApprovalPrompt: opts.ApprovalPrompt,
|
||||
SessionLifetimeTTL: opts.SessionLifetimeTTL,
|
||||
ProviderURL: opts.ProviderURL,
|
||||
Scopes: opts.Scopes,
|
||||
}
|
||||
np, err := providers.New(opts.Provider, pd)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return providers.NewSingleFlightProvider(np), nil
|
||||
|
||||
}
|
||||
|
||||
func dotPrependDomains(d []string) []string {
|
||||
for i := range d {
|
||||
if d[i] != "" && !strings.HasPrefix(d[i], ".") {
|
||||
d[i] = fmt.Sprintf(".%s", d[i])
|
||||
}
|
||||
}
|
||||
return d
|
||||
}
|
1097
authenticate/authenticate_test.go
Normal file
1097
authenticate/authenticate_test.go
Normal file
File diff suppressed because it is too large
Load diff
329
authenticate/circuit/breaker.go
Normal file
329
authenticate/circuit/breaker.go
Normal file
|
@ -0,0 +1,329 @@
|
|||
// Package circuit implements the Circuit Breaker pattern.
|
||||
// https://docs.microsoft.com/en-us/azure/architecture/patterns/circuit-breaker
|
||||
package circuit // import "github.com/pomerium/pomerium/internal/circuit"
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"math"
|
||||
"math/rand"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/benbjohnson/clock"
|
||||
)
|
||||
|
||||
// State is a type that represents a state of Breaker.
|
||||
type State int
|
||||
|
||||
// These constants are states of Breaker.
|
||||
const (
|
||||
StateClosed State = iota
|
||||
StateHalfOpen
|
||||
StateOpen
|
||||
)
|
||||
|
||||
type (
|
||||
// ShouldTripFunc is a function that takes in a Counts and returns true if the circuit breaker should be tripped.
|
||||
ShouldTripFunc func(Counts) bool
|
||||
// ShouldResetFunc is a function that takes in a Counts and returns true if the circuit breaker should be reset.
|
||||
ShouldResetFunc func(Counts) bool
|
||||
// BackoffDurationFunc is a function that takes in a Counts and returns the backoff duration
|
||||
BackoffDurationFunc func(Counts) time.Duration
|
||||
|
||||
// StateChangeHook is a function that represents a state change.
|
||||
StateChangeHook func(prev, to State)
|
||||
// BackoffHook is a function that represents backoff.
|
||||
BackoffHook func(duration time.Duration, reset time.Time)
|
||||
)
|
||||
|
||||
var (
|
||||
// DefaultShouldTripFunc is a default ShouldTripFunc.
|
||||
DefaultShouldTripFunc = func(counts Counts) bool {
|
||||
// Trip into Open after three consecutive failures
|
||||
return counts.ConsecutiveFailures >= 3
|
||||
}
|
||||
// DefaultShouldResetFunc is a default ShouldResetFunc.
|
||||
DefaultShouldResetFunc = func(counts Counts) bool {
|
||||
// Reset after three consecutive successes
|
||||
return counts.ConsecutiveSuccesses >= 3
|
||||
}
|
||||
// DefaultBackoffDurationFunc is an exponential backoff function
|
||||
DefaultBackoffDurationFunc = ExponentialBackoffDuration(time.Duration(100)*time.Second, time.Duration(500)*time.Millisecond)
|
||||
)
|
||||
|
||||
// ErrOpenState is returned when the b state is open
|
||||
type ErrOpenState struct{}
|
||||
|
||||
func (e *ErrOpenState) Error() string { return "circuit breaker is open" }
|
||||
|
||||
// ExponentialBackoffDuration returns a function that uses exponential backoff and full jitter
|
||||
func ExponentialBackoffDuration(maxBackoff, baseTimeout time.Duration) func(Counts) time.Duration {
|
||||
return func(counts Counts) time.Duration {
|
||||
// Full Jitter from https://aws.amazon.com/blogs/architecture/exponential-backoff-and-jitter/
|
||||
// sleep = random_between(0, min(cap, base * 2 ** attempt))
|
||||
backoff := math.Min(float64(maxBackoff), float64(baseTimeout)*math.Exp2(float64(counts.ConsecutiveFailures)))
|
||||
jittered := rand.Float64() * backoff
|
||||
return time.Duration(jittered)
|
||||
}
|
||||
}
|
||||
|
||||
// String implements stringer interface.
|
||||
func (s State) String() string {
|
||||
switch s {
|
||||
case StateClosed:
|
||||
return "closed"
|
||||
case StateHalfOpen:
|
||||
return "half-open"
|
||||
case StateOpen:
|
||||
return "open"
|
||||
default:
|
||||
return fmt.Sprintf("unknown state: %d", s)
|
||||
}
|
||||
}
|
||||
|
||||
// Counts holds the numbers of requests and their successes/failures.
|
||||
type Counts struct {
|
||||
CurrentRequests int
|
||||
ConsecutiveSuccesses int
|
||||
ConsecutiveFailures int
|
||||
}
|
||||
|
||||
func (c *Counts) onRequest() {
|
||||
c.CurrentRequests++
|
||||
}
|
||||
|
||||
func (c *Counts) afterRequest() {
|
||||
c.CurrentRequests--
|
||||
}
|
||||
|
||||
func (c *Counts) onSuccess() {
|
||||
c.ConsecutiveSuccesses++
|
||||
c.ConsecutiveFailures = 0
|
||||
}
|
||||
|
||||
func (c *Counts) onFailure() {
|
||||
c.ConsecutiveFailures++
|
||||
c.ConsecutiveSuccesses = 0
|
||||
}
|
||||
|
||||
func (c *Counts) clear() {
|
||||
c.ConsecutiveSuccesses = 0
|
||||
c.ConsecutiveFailures = 0
|
||||
}
|
||||
|
||||
// Options configures Breaker:
|
||||
//
|
||||
// HalfOpenConcurrentRequests specifies how many concurrent requests to allow while
|
||||
// the circuit is in the half-open state
|
||||
//
|
||||
// ShouldTripFunc specifies when the circuit should trip from the closed state to
|
||||
// the open state. It takes a Counts struct and returns a bool.
|
||||
//
|
||||
// ShouldResetFunc specifies when the circuit should be reset from the half-open state
|
||||
// to the closed state and allow all requests. It takes a Counts struct and returns a bool.
|
||||
//
|
||||
// BackoffDurationFunc specifies how long to set the backoff duration. It takes a
|
||||
// counts struct and returns a time.Duration
|
||||
//
|
||||
// OnStateChange is called whenever the state of the Breaker changes.
|
||||
//
|
||||
// OnBackoff is called whenever a backoff is set with the backoff duration and reset time
|
||||
//
|
||||
// TestClock is used to mock the clock during tests
|
||||
type Options struct {
|
||||
HalfOpenConcurrentRequests int
|
||||
|
||||
ShouldTripFunc ShouldTripFunc
|
||||
ShouldResetFunc ShouldResetFunc
|
||||
BackoffDurationFunc BackoffDurationFunc
|
||||
|
||||
// hooks
|
||||
OnStateChange StateChangeHook
|
||||
OnBackoff BackoffHook
|
||||
|
||||
// used in tests
|
||||
TestClock clock.Clock
|
||||
}
|
||||
|
||||
// Breaker is a state machine to prevent sending requests that are likely to fail.
|
||||
type Breaker struct {
|
||||
halfOpenRequests int
|
||||
|
||||
shouldTripFunc ShouldTripFunc
|
||||
shouldResetFunc ShouldResetFunc
|
||||
backoffDurationFunc BackoffDurationFunc
|
||||
|
||||
// hooks
|
||||
onStateChange StateChangeHook
|
||||
onBackoff BackoffHook
|
||||
|
||||
// used primarly for mocking tests
|
||||
clock clock.Clock
|
||||
|
||||
mutex sync.Mutex
|
||||
state State
|
||||
counts Counts
|
||||
backoffExpires time.Time
|
||||
generation int
|
||||
}
|
||||
|
||||
// NewBreaker returns a new Breaker configured with the given Settings.
|
||||
func NewBreaker(opts *Options) *Breaker {
|
||||
b := new(Breaker)
|
||||
|
||||
if opts == nil {
|
||||
opts = &Options{}
|
||||
}
|
||||
|
||||
// set hooks
|
||||
b.onStateChange = opts.OnStateChange
|
||||
b.onBackoff = opts.OnBackoff
|
||||
|
||||
b.halfOpenRequests = 1
|
||||
if opts.HalfOpenConcurrentRequests > 0 {
|
||||
b.halfOpenRequests = opts.HalfOpenConcurrentRequests
|
||||
}
|
||||
|
||||
b.backoffDurationFunc = DefaultBackoffDurationFunc
|
||||
if opts.BackoffDurationFunc != nil {
|
||||
b.backoffDurationFunc = opts.BackoffDurationFunc
|
||||
}
|
||||
|
||||
b.shouldTripFunc = DefaultShouldTripFunc
|
||||
if opts.ShouldTripFunc != nil {
|
||||
b.shouldTripFunc = opts.ShouldTripFunc
|
||||
}
|
||||
|
||||
b.shouldResetFunc = DefaultShouldResetFunc
|
||||
if opts.ShouldResetFunc != nil {
|
||||
b.shouldResetFunc = opts.ShouldResetFunc
|
||||
}
|
||||
|
||||
b.clock = clock.New()
|
||||
if opts.TestClock != nil {
|
||||
b.clock = opts.TestClock
|
||||
}
|
||||
|
||||
b.setState(StateClosed)
|
||||
|
||||
return b
|
||||
}
|
||||
|
||||
// Call runs the given function if the Breaker allows the call.
|
||||
// Call returns an error instantly if the Breaker rejects the request.
|
||||
// Otherwise, Call returns the result of the request.
|
||||
func (b *Breaker) Call(f func() (interface{}, error)) (interface{}, error) {
|
||||
generation, err := b.beforeRequest()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
result, err := f()
|
||||
b.afterRequest(err == nil, generation)
|
||||
return result, err
|
||||
}
|
||||
|
||||
func (b *Breaker) beforeRequest() (int, error) {
|
||||
b.mutex.Lock()
|
||||
defer b.mutex.Unlock()
|
||||
|
||||
state, generation := b.currentState()
|
||||
|
||||
switch state {
|
||||
case StateOpen:
|
||||
return generation, &ErrOpenState{}
|
||||
case StateHalfOpen:
|
||||
if b.counts.CurrentRequests >= b.halfOpenRequests {
|
||||
return generation, &ErrOpenState{}
|
||||
}
|
||||
}
|
||||
|
||||
b.counts.onRequest()
|
||||
return generation, nil
|
||||
}
|
||||
|
||||
func (b *Breaker) afterRequest(success bool, prevGeneration int) {
|
||||
b.mutex.Lock()
|
||||
defer b.mutex.Unlock()
|
||||
|
||||
b.counts.afterRequest()
|
||||
|
||||
state, generation := b.currentState()
|
||||
if prevGeneration != generation {
|
||||
return
|
||||
}
|
||||
|
||||
if success {
|
||||
b.onSuccess(state)
|
||||
return
|
||||
}
|
||||
|
||||
b.onFailure(state)
|
||||
}
|
||||
|
||||
func (b *Breaker) onSuccess(state State) {
|
||||
b.counts.onSuccess()
|
||||
switch state {
|
||||
case StateHalfOpen:
|
||||
if b.shouldResetFunc(b.counts) {
|
||||
b.setState(StateClosed)
|
||||
b.counts.clear()
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (b *Breaker) onFailure(state State) {
|
||||
b.counts.onFailure()
|
||||
|
||||
switch state {
|
||||
case StateClosed:
|
||||
if b.shouldTripFunc(b.counts) {
|
||||
b.setState(StateOpen)
|
||||
b.counts.clear()
|
||||
b.setBackoff()
|
||||
}
|
||||
case StateOpen:
|
||||
b.setBackoff()
|
||||
case StateHalfOpen:
|
||||
b.setState(StateOpen)
|
||||
b.setBackoff()
|
||||
}
|
||||
}
|
||||
|
||||
func (b *Breaker) setBackoff() {
|
||||
backoffDuration := b.backoffDurationFunc(b.counts)
|
||||
backoffExpires := b.clock.Now().Add(backoffDuration)
|
||||
b.backoffExpires = backoffExpires
|
||||
if b.onBackoff != nil {
|
||||
b.onBackoff(backoffDuration, backoffExpires)
|
||||
}
|
||||
}
|
||||
|
||||
func (b *Breaker) currentState() (State, int) {
|
||||
switch b.state {
|
||||
case StateOpen:
|
||||
if b.clock.Now().After(b.backoffExpires) {
|
||||
b.setState(StateHalfOpen)
|
||||
}
|
||||
}
|
||||
return b.state, b.generation
|
||||
}
|
||||
|
||||
func (b *Breaker) newGeneration() {
|
||||
b.generation++
|
||||
}
|
||||
|
||||
func (b *Breaker) setState(state State) {
|
||||
if b.state == state {
|
||||
return
|
||||
}
|
||||
|
||||
b.newGeneration()
|
||||
|
||||
prev := b.state
|
||||
b.state = state
|
||||
|
||||
if b.onStateChange != nil {
|
||||
b.onStateChange(prev, state)
|
||||
}
|
||||
}
|
187
authenticate/circuit/breaker_test.go
Normal file
187
authenticate/circuit/breaker_test.go
Normal file
|
@ -0,0 +1,187 @@
|
|||
package circuit // import "github.com/pomerium/pomerium/internal/circuit"
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"sync"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/benbjohnson/clock"
|
||||
)
|
||||
|
||||
var errFailed = errors.New("failed")
|
||||
|
||||
func fail() (interface{}, error) {
|
||||
return nil, errFailed
|
||||
}
|
||||
|
||||
func succeed() (interface{}, error) {
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
func TestCircuitBreaker(t *testing.T) {
|
||||
mock := clock.NewMock()
|
||||
threshold := 3
|
||||
timeout := time.Duration(2) * time.Second
|
||||
trip := func(c Counts) bool { return c.ConsecutiveFailures > threshold }
|
||||
reset := func(c Counts) bool { return c.ConsecutiveSuccesses > threshold }
|
||||
backoff := func(c Counts) time.Duration { return timeout }
|
||||
stateChange := func(p, c State) { t.Logf("state change from %s to %s\n", p, c) }
|
||||
cb := NewBreaker(&Options{
|
||||
TestClock: mock,
|
||||
ShouldTripFunc: trip,
|
||||
ShouldResetFunc: reset,
|
||||
BackoffDurationFunc: backoff,
|
||||
OnStateChange: stateChange,
|
||||
})
|
||||
state, _ := cb.currentState()
|
||||
if state != StateClosed {
|
||||
t.Fatalf("expected state to start %s, got %s", StateClosed, state)
|
||||
}
|
||||
|
||||
for i := 0; i <= threshold; i++ {
|
||||
_, err := cb.Call(fail)
|
||||
if err == nil {
|
||||
t.Fatalf("expected to error, got nil")
|
||||
}
|
||||
state, _ := cb.currentState()
|
||||
t.Logf("iteration %#v", i)
|
||||
if i == threshold {
|
||||
// we expect this to be the case to trip the circuit
|
||||
if state != StateOpen {
|
||||
t.Fatalf("expected state to be %s, got %s", StateOpen, state)
|
||||
}
|
||||
} else if state != StateClosed {
|
||||
// this is a normal failure case
|
||||
t.Fatalf("expected state to be %s, got %s", StateClosed, state)
|
||||
}
|
||||
}
|
||||
|
||||
_, err := cb.Call(fail)
|
||||
switch err.(type) {
|
||||
case *ErrOpenState:
|
||||
// this is the expected case
|
||||
break
|
||||
default:
|
||||
t.Errorf("%#v", cb.counts)
|
||||
t.Fatalf("expected to get open state failure, got %s", err)
|
||||
}
|
||||
|
||||
// we advance time by the timeout and a hair
|
||||
mock.Add(timeout + time.Duration(1)*time.Millisecond)
|
||||
state, _ = cb.currentState()
|
||||
if state != StateHalfOpen {
|
||||
t.Fatalf("expected state to be %s, got %s", StateHalfOpen, state)
|
||||
}
|
||||
|
||||
for i := 0; i <= threshold; i++ {
|
||||
_, err := cb.Call(succeed)
|
||||
if err != nil {
|
||||
t.Fatalf("expected to get no error, got %s", err)
|
||||
}
|
||||
state, _ := cb.currentState()
|
||||
t.Logf("iteration %#v", i)
|
||||
if i == threshold {
|
||||
// we expect this to be the case that ressets the circuit
|
||||
if state != StateClosed {
|
||||
t.Fatalf("expected state to be %s, got %s", StateClosed, state)
|
||||
}
|
||||
} else if state != StateHalfOpen {
|
||||
t.Fatalf("expected state to be %s, got %s", StateHalfOpen, state)
|
||||
}
|
||||
}
|
||||
|
||||
state, _ = cb.currentState()
|
||||
if state != StateClosed {
|
||||
t.Fatalf("expected state to be %s, got %s", StateClosed, state)
|
||||
}
|
||||
}
|
||||
|
||||
func TestExponentialBackOffFunc(t *testing.T) {
|
||||
baseTimeout := time.Duration(1) * time.Millisecond
|
||||
// Note Expected is an upper range case
|
||||
cases := []struct {
|
||||
FailureCount int
|
||||
Expected time.Duration
|
||||
}{
|
||||
{
|
||||
FailureCount: 0,
|
||||
Expected: time.Duration(1) * time.Millisecond,
|
||||
},
|
||||
{
|
||||
FailureCount: 1,
|
||||
Expected: time.Duration(2) * time.Millisecond,
|
||||
},
|
||||
{
|
||||
FailureCount: 2,
|
||||
Expected: time.Duration(4) * time.Millisecond,
|
||||
},
|
||||
{
|
||||
FailureCount: 3,
|
||||
Expected: time.Duration(8) * time.Millisecond,
|
||||
},
|
||||
{
|
||||
FailureCount: 4,
|
||||
Expected: time.Duration(16) * time.Millisecond,
|
||||
},
|
||||
{
|
||||
FailureCount: 5,
|
||||
Expected: time.Duration(32) * time.Millisecond,
|
||||
},
|
||||
{
|
||||
FailureCount: 6,
|
||||
Expected: time.Duration(64) * time.Millisecond,
|
||||
},
|
||||
{
|
||||
FailureCount: 7,
|
||||
Expected: time.Duration(128) * time.Millisecond,
|
||||
},
|
||||
{
|
||||
FailureCount: 8,
|
||||
Expected: time.Duration(256) * time.Millisecond,
|
||||
},
|
||||
{
|
||||
FailureCount: 9,
|
||||
Expected: time.Duration(512) * time.Millisecond,
|
||||
},
|
||||
{
|
||||
FailureCount: 10,
|
||||
Expected: time.Duration(1024) * time.Millisecond,
|
||||
},
|
||||
}
|
||||
|
||||
f := ExponentialBackoffDuration(time.Duration(1)*time.Hour, baseTimeout)
|
||||
for _, tc := range cases {
|
||||
got := f(Counts{ConsecutiveFailures: tc.FailureCount})
|
||||
t.Logf("got backoff %#v", got)
|
||||
if got > tc.Expected {
|
||||
t.Errorf("got %#v but expected less than %#v", got, tc.Expected)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestCircuitBreakerClosedParallel(t *testing.T) {
|
||||
cb := NewBreaker(nil)
|
||||
numReqs := 10000
|
||||
wg := &sync.WaitGroup{}
|
||||
routine := func(wg *sync.WaitGroup) {
|
||||
for i := 0; i < numReqs; i++ {
|
||||
cb.Call(succeed)
|
||||
}
|
||||
wg.Done()
|
||||
}
|
||||
|
||||
numRoutines := 10
|
||||
for i := 0; i < numRoutines; i++ {
|
||||
wg.Add(1)
|
||||
go routine(wg)
|
||||
}
|
||||
|
||||
total := numReqs * numRoutines
|
||||
|
||||
wg.Wait()
|
||||
|
||||
if cb.counts.ConsecutiveSuccesses != total {
|
||||
t.Fatalf("expected to get total requests %d, got %d", total, cb.counts.ConsecutiveSuccesses)
|
||||
}
|
||||
}
|
629
authenticate/handlers.go
Normal file
629
authenticate/handlers.go
Normal file
|
@ -0,0 +1,629 @@
|
|||
package authenticate // import "github.com/pomerium/pomerium/authenticate"
|
||||
|
||||
import (
|
||||
"encoding/base64"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"net/http"
|
||||
"net/url"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/pomerium/pomerium/internal/aead"
|
||||
"github.com/pomerium/pomerium/internal/httputil"
|
||||
"github.com/pomerium/pomerium/internal/log"
|
||||
m "github.com/pomerium/pomerium/internal/middleware"
|
||||
"github.com/pomerium/pomerium/internal/sessions"
|
||||
"github.com/pomerium/pomerium/internal/version"
|
||||
)
|
||||
|
||||
var securityHeaders = map[string]string{
|
||||
"Strict-Transport-Security": "max-age=31536000",
|
||||
"X-Frame-Options": "DENY",
|
||||
"X-Content-Type-Options": "nosniff",
|
||||
"X-XSS-Protection": "1; mode=block",
|
||||
"Content-Security-Policy": "default-src 'none'; style-src 'self' 'sha256-pSTVzZsFAqd2U3QYu+BoBDtuJWaPM/+qMy/dBRrhb5Y='; img-src 'self';",
|
||||
"Referrer-Policy": "Same-origin",
|
||||
}
|
||||
|
||||
// Handler returns the Http.Handlers for authentication, callback, and refresh
|
||||
func (p *Authenticator) Handler() http.Handler {
|
||||
mux := http.NewServeMux()
|
||||
// we setup global endpoints that should respond to any hostname
|
||||
mux.HandleFunc("/ping", m.WithMethods(p.PingPage, "GET"))
|
||||
|
||||
serviceMux := http.NewServeMux()
|
||||
// standard rest and healthcheck endpoints
|
||||
serviceMux.HandleFunc("/ping", m.WithMethods(p.PingPage, "GET"))
|
||||
serviceMux.HandleFunc("/robots.txt", m.WithMethods(p.RobotsTxt, "GET"))
|
||||
// Identity Provider (IdP) endpoints and callbacks
|
||||
serviceMux.HandleFunc("/start", m.WithMethods(p.OAuthStart, "GET"))
|
||||
serviceMux.HandleFunc("/oauth2/callback", m.WithMethods(p.OAuthCallback, "GET"))
|
||||
// authenticator-server endpoints, todo(bdd): make gRPC
|
||||
serviceMux.HandleFunc("/sign_in", m.WithMethods(m.ValidateClientID(p.validateSignature(p.SignIn), p.ProxyClientID), "GET"))
|
||||
serviceMux.HandleFunc("/sign_out", m.WithMethods(p.validateSignature(p.SignOut), "GET", "POST"))
|
||||
serviceMux.HandleFunc("/profile", m.WithMethods(p.validateExisting(p.GetProfile), "GET"))
|
||||
serviceMux.HandleFunc("/validate", m.WithMethods(p.validateExisting(p.ValidateToken), "GET"))
|
||||
serviceMux.HandleFunc("/redeem", m.WithMethods(p.validateExisting(p.Redeem), "POST"))
|
||||
serviceMux.HandleFunc("/refresh", m.WithMethods(p.validateExisting(p.Refresh), "POST"))
|
||||
|
||||
// NOTE: we have to include trailing slash for the router to match the host header
|
||||
host := p.Host
|
||||
if !strings.HasSuffix(host, "/") {
|
||||
host = fmt.Sprintf("%s/", host)
|
||||
}
|
||||
mux.Handle(host, serviceMux) // setup our service mux to only handle our required host header
|
||||
|
||||
return m.SetHeaders(mux, securityHeaders)
|
||||
}
|
||||
|
||||
// validateSignature wraps a common collection of middlewares to validate signatures
|
||||
func (p *Authenticator) validateSignature(f http.HandlerFunc) http.HandlerFunc {
|
||||
return validateRedirectURI(validateSignature(f, p.ProxyClientSecret), p.ProxyRootDomains)
|
||||
|
||||
}
|
||||
|
||||
// validateSignature wraps a common collection of middlewares to validate
|
||||
// a (presumably) existing user session
|
||||
func (p *Authenticator) validateExisting(f http.HandlerFunc) http.HandlerFunc {
|
||||
return m.ValidateClientID(m.ValidateClientSecret(f, p.ProxyClientSecret), p.ProxyClientID)
|
||||
}
|
||||
|
||||
// RobotsTxt handles the /robots.txt route.
|
||||
func (p *Authenticator) RobotsTxt(rw http.ResponseWriter, req *http.Request) {
|
||||
rw.WriteHeader(http.StatusOK)
|
||||
fmt.Fprintf(rw, "User-agent: *\nDisallow: /")
|
||||
}
|
||||
|
||||
// PingPage handles the /ping route
|
||||
func (p *Authenticator) PingPage(rw http.ResponseWriter, req *http.Request) {
|
||||
rw.WriteHeader(http.StatusOK)
|
||||
fmt.Fprintf(rw, "OK")
|
||||
}
|
||||
|
||||
// SignInPage directs the user to the sign in page
|
||||
func (p *Authenticator) SignInPage(rw http.ResponseWriter, req *http.Request, code int) {
|
||||
requestLog := log.WithRequest(req, "authenticate.SignInPage")
|
||||
rw.WriteHeader(code)
|
||||
redirectURL := p.redirectURL.ResolveReference(req.URL)
|
||||
// validateRedirectURI middleware already ensures that this is a valid URL
|
||||
destinationURL, _ := url.Parse(redirectURL.Query().Get("redirect_uri"))
|
||||
t := struct {
|
||||
ProviderName string
|
||||
EmailDomains []string
|
||||
Redirect string
|
||||
Destination string
|
||||
Version string
|
||||
}{
|
||||
ProviderName: p.provider.Data().ProviderName,
|
||||
EmailDomains: p.EmailDomains,
|
||||
Redirect: redirectURL.String(),
|
||||
Destination: destinationURL.Host,
|
||||
Version: version.FullVersion(),
|
||||
}
|
||||
requestLog.Info().
|
||||
Str("ProviderName", p.provider.Data().ProviderName).
|
||||
Str("Redirect", redirectURL.String()).
|
||||
Str("Destination", destinationURL.Host).
|
||||
Str("EmailDomains", strings.Join(p.EmailDomains, ", ")).
|
||||
Msg("authenticate.SignInPage")
|
||||
p.templates.ExecuteTemplate(rw, "sign_in.html", t)
|
||||
}
|
||||
|
||||
func (p *Authenticator) authenticate(rw http.ResponseWriter, req *http.Request) (*sessions.SessionState, error) {
|
||||
requestLog := log.WithRequest(req, "authenticate.authenticate")
|
||||
|
||||
session, err := p.sessionStore.LoadSession(req)
|
||||
if err != nil {
|
||||
log.Error().Err(err).Msg("authenticate.authenticate")
|
||||
p.sessionStore.ClearSession(rw, req)
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// ensure sessions lifetime has not expired
|
||||
if session.LifetimePeriodExpired() {
|
||||
requestLog.Warn().Msg("lifetime expired")
|
||||
p.sessionStore.ClearSession(rw, req)
|
||||
return nil, sessions.ErrLifetimeExpired
|
||||
}
|
||||
// check if session refresh period is up
|
||||
if session.RefreshPeriodExpired() {
|
||||
ok, err := p.provider.RefreshSessionIfNeeded(session)
|
||||
if err != nil {
|
||||
requestLog.Error().Err(err).Msg("failed to refresh session")
|
||||
p.sessionStore.ClearSession(rw, req)
|
||||
return nil, err
|
||||
}
|
||||
if !ok {
|
||||
requestLog.Error().Msg("user unauthorized after refresh")
|
||||
p.sessionStore.ClearSession(rw, req)
|
||||
return nil, httputil.ErrUserNotAuthorized
|
||||
}
|
||||
// update refresh'd session in cookie
|
||||
err = p.sessionStore.SaveSession(rw, req, session)
|
||||
if err != nil {
|
||||
// We refreshed the session successfully, but failed to save it.
|
||||
// This could be from failing to encode the session properly.
|
||||
// But, we clear the session cookie and reject the request
|
||||
requestLog.Error().Err(err).Msg("could not save refreshed session")
|
||||
p.sessionStore.ClearSession(rw, req)
|
||||
return nil, err
|
||||
}
|
||||
} else {
|
||||
// The session has not exceeded it's lifetime or requires refresh
|
||||
ok := p.provider.ValidateSessionState(session)
|
||||
if !ok {
|
||||
requestLog.Error().Msg("invalid session state")
|
||||
p.sessionStore.ClearSession(rw, req)
|
||||
return nil, httputil.ErrUserNotAuthorized
|
||||
}
|
||||
err = p.sessionStore.SaveSession(rw, req, session)
|
||||
if err != nil {
|
||||
requestLog.Error().Err(err).Msg("failed to save valid session")
|
||||
p.sessionStore.ClearSession(rw, req)
|
||||
return nil, err
|
||||
}
|
||||
}
|
||||
|
||||
if !p.Validator(session.Email) {
|
||||
requestLog.Error().Msg("invalid email user")
|
||||
return nil, httputil.ErrUserNotAuthorized
|
||||
}
|
||||
|
||||
return session, nil
|
||||
}
|
||||
|
||||
// SignIn handles the /sign_in endpoint. It attempts to authenticate the user,
|
||||
// and if the user is not authenticated, it renders a sign in page.
|
||||
func (p *Authenticator) SignIn(rw http.ResponseWriter, req *http.Request) {
|
||||
// We attempt to authenticate the user. If they cannot be authenticated, we render a sign-in
|
||||
// page.
|
||||
//
|
||||
// If the user is authenticated, we redirect back to the proxy application
|
||||
// at the `redirect_uri`, with a temporary token.
|
||||
//
|
||||
// TODO: It is possible for a user to visit this page without a redirect destination.
|
||||
// Should we allow the user to authenticate? If not, what should be the proposed workflow?
|
||||
|
||||
session, err := p.authenticate(rw, req)
|
||||
switch err {
|
||||
case nil:
|
||||
// User is authenticated, redirect back to the proxy application
|
||||
// with the necessary state
|
||||
p.ProxyOAuthRedirect(rw, req, session)
|
||||
case http.ErrNoCookie:
|
||||
log.Error().Err(err).Msg("authenticate.SignIn : err no cookie")
|
||||
p.SignInPage(rw, req, http.StatusOK)
|
||||
case sessions.ErrLifetimeExpired, sessions.ErrInvalidSession:
|
||||
log.Error().Err(err).Msg("authenticate.SignIn : invalid cookie cookie")
|
||||
p.sessionStore.ClearSession(rw, req)
|
||||
p.SignInPage(rw, req, http.StatusOK)
|
||||
default:
|
||||
log.Error().Err(err).Msg("authenticate.SignIn : unknown error cookie")
|
||||
httputil.ErrorResponse(rw, req, err.Error(), httputil.CodeForError(err))
|
||||
}
|
||||
}
|
||||
|
||||
// ProxyOAuthRedirect redirects the user back to sso proxy's redirection endpoint.
|
||||
func (p *Authenticator) ProxyOAuthRedirect(rw http.ResponseWriter, req *http.Request, session *sessions.SessionState) {
|
||||
// This workflow corresponds to Section 3.1.2 of the OAuth2 RFC.
|
||||
// See https://tools.ietf.org/html/rfc6749#section-3.1.2 for more specific information.
|
||||
//
|
||||
// We redirect the user back to the proxy application's redirection endpoint; in the
|
||||
// sso proxy, this is the `/oauth/callback` endpoint.
|
||||
//
|
||||
// We must provide the proxy with a temporary authorization code via the `code` parameter,
|
||||
// which they can use to redeem an access token for subsequent API calls.
|
||||
//
|
||||
// We must also include the original `state` parameter received from the proxy application.
|
||||
|
||||
err := req.ParseForm()
|
||||
if err != nil {
|
||||
httputil.ErrorResponse(rw, req, err.Error(), http.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
|
||||
state := req.Form.Get("state")
|
||||
if state == "" {
|
||||
httputil.ErrorResponse(rw, req, "no state parameter supplied", http.StatusForbidden)
|
||||
return
|
||||
}
|
||||
|
||||
redirectURI := req.Form.Get("redirect_uri")
|
||||
if redirectURI == "" {
|
||||
httputil.ErrorResponse(rw, req, "no redirect_uri parameter supplied", http.StatusForbidden)
|
||||
return
|
||||
}
|
||||
|
||||
redirectURL, err := url.Parse(redirectURI)
|
||||
if err != nil {
|
||||
httputil.ErrorResponse(rw, req, "malformed redirect_uri parameter passed", http.StatusBadRequest)
|
||||
return
|
||||
}
|
||||
|
||||
encrypted, err := sessions.MarshalSession(session, p.cipher)
|
||||
if err != nil {
|
||||
httputil.ErrorResponse(rw, req, err.Error(), http.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
http.Redirect(rw, req, getAuthCodeRedirectURL(redirectURL, state, string(encrypted)), http.StatusFound)
|
||||
}
|
||||
|
||||
func getAuthCodeRedirectURL(redirectURL *url.URL, state, authCode string) string {
|
||||
u, _ := url.Parse(redirectURL.String())
|
||||
params, _ := url.ParseQuery(u.RawQuery)
|
||||
params.Set("code", authCode)
|
||||
params.Set("state", state)
|
||||
|
||||
u.RawQuery = params.Encode()
|
||||
|
||||
if u.Scheme == "" {
|
||||
u.Scheme = "https"
|
||||
}
|
||||
|
||||
return u.String()
|
||||
}
|
||||
|
||||
// SignOut signs the user out.
|
||||
func (p *Authenticator) SignOut(rw http.ResponseWriter, req *http.Request) {
|
||||
redirectURI := req.Form.Get("redirect_uri")
|
||||
if req.Method == "GET" {
|
||||
p.SignOutPage(rw, req, "")
|
||||
return
|
||||
}
|
||||
|
||||
session, err := p.sessionStore.LoadSession(req)
|
||||
switch err {
|
||||
case nil:
|
||||
break
|
||||
case http.ErrNoCookie: // if there's no cookie in the session we can just redirect
|
||||
http.Redirect(rw, req, redirectURI, http.StatusFound)
|
||||
return
|
||||
default:
|
||||
// a different error, clear the session cookie and redirect
|
||||
log.Error().Err(err).Msg("authenticate.SignOut : error loading cookie session")
|
||||
p.sessionStore.ClearSession(rw, req)
|
||||
http.Redirect(rw, req, redirectURI, http.StatusFound)
|
||||
return
|
||||
}
|
||||
|
||||
err = p.provider.Revoke(session)
|
||||
if err != nil {
|
||||
log.Error().Err(err).Msg("authenticate.SignOut : error revoking session")
|
||||
p.SignOutPage(rw, req, "An error occurred during sign out. Please try again.")
|
||||
return
|
||||
}
|
||||
p.sessionStore.ClearSession(rw, req)
|
||||
http.Redirect(rw, req, redirectURI, http.StatusFound)
|
||||
}
|
||||
|
||||
// SignOutPage renders a sign out page with a message
|
||||
func (p *Authenticator) SignOutPage(rw http.ResponseWriter, req *http.Request, message string) {
|
||||
// validateRedirectURI middleware already ensures that this is a valid URL
|
||||
redirectURI := req.Form.Get("redirect_uri")
|
||||
session, err := p.sessionStore.LoadSession(req)
|
||||
if err != nil {
|
||||
http.Redirect(rw, req, redirectURI, http.StatusFound)
|
||||
return
|
||||
}
|
||||
|
||||
signature := req.Form.Get("sig")
|
||||
timestamp := req.Form.Get("ts")
|
||||
destinationURL, _ := url.Parse(redirectURI)
|
||||
|
||||
// An error message indicates that an internal server error occurred
|
||||
if message != "" {
|
||||
rw.WriteHeader(http.StatusInternalServerError)
|
||||
}
|
||||
|
||||
t := struct {
|
||||
Redirect string
|
||||
Signature string
|
||||
Timestamp string
|
||||
Message string
|
||||
Destination string
|
||||
Email string
|
||||
Version string
|
||||
}{
|
||||
Redirect: redirectURI,
|
||||
Signature: signature,
|
||||
Timestamp: timestamp,
|
||||
Message: message,
|
||||
Destination: destinationURL.Host,
|
||||
Email: session.Email,
|
||||
Version: version.FullVersion(),
|
||||
}
|
||||
p.templates.ExecuteTemplate(rw, "sign_out.html", t)
|
||||
return
|
||||
}
|
||||
|
||||
// OAuthStart starts the authentication process by redirecting to the provider. It provides a
|
||||
// `redirectURI`, allowing the provider to redirect back to the sso proxy after authentication.
|
||||
func (p *Authenticator) OAuthStart(rw http.ResponseWriter, req *http.Request) {
|
||||
|
||||
nonce := fmt.Sprintf("%x", aead.GenerateKey())
|
||||
p.csrfStore.SetCSRF(rw, req, nonce)
|
||||
|
||||
authRedirectURL, err := url.Parse(req.URL.Query().Get("redirect_uri"))
|
||||
if err != nil || !validRedirectURI(authRedirectURL.String(), p.ProxyRootDomains) {
|
||||
httputil.ErrorResponse(rw, req, "Invalid redirect parameter", http.StatusBadRequest)
|
||||
return
|
||||
}
|
||||
|
||||
proxyRedirectURL, err := url.Parse(authRedirectURL.Query().Get("redirect_uri"))
|
||||
if err != nil || !validRedirectURI(proxyRedirectURL.String(), p.ProxyRootDomains) {
|
||||
httputil.ErrorResponse(rw, req, "Invalid redirect parameter", http.StatusBadRequest)
|
||||
return
|
||||
}
|
||||
|
||||
proxyRedirectSig := authRedirectURL.Query().Get("sig")
|
||||
ts := authRedirectURL.Query().Get("ts")
|
||||
if !validSignature(proxyRedirectURL.String(), proxyRedirectSig, ts, p.ProxyClientSecret) {
|
||||
httputil.ErrorResponse(rw, req, "Invalid redirect parameter", http.StatusBadRequest)
|
||||
return
|
||||
}
|
||||
|
||||
state := base64.URLEncoding.EncodeToString([]byte(fmt.Sprintf("%v:%v", nonce, authRedirectURL.String())))
|
||||
|
||||
signInURL := p.provider.GetSignInURL(state)
|
||||
|
||||
http.Redirect(rw, req, signInURL, http.StatusFound)
|
||||
}
|
||||
|
||||
func (p *Authenticator) redeemCode(host, code string) (*sessions.SessionState, error) {
|
||||
session, err := p.provider.Redeem(code)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if session.Email == "" {
|
||||
return nil, fmt.Errorf("no email included in session")
|
||||
}
|
||||
|
||||
return session, nil
|
||||
|
||||
}
|
||||
|
||||
// getOAuthCallback completes the oauth cycle from an identity provider's callback
|
||||
func (p *Authenticator) getOAuthCallback(rw http.ResponseWriter, req *http.Request) (string, error) {
|
||||
requestLog := log.WithRequest(req, "authenticate.getOAuthCallback")
|
||||
// finish the oauth cycle
|
||||
err := req.ParseForm()
|
||||
if err != nil {
|
||||
return "", httputil.HTTPError{Code: http.StatusInternalServerError, Message: err.Error()}
|
||||
}
|
||||
errorString := req.Form.Get("error")
|
||||
if errorString != "" {
|
||||
return "", httputil.HTTPError{Code: http.StatusForbidden, Message: errorString}
|
||||
}
|
||||
code := req.Form.Get("code")
|
||||
if code == "" {
|
||||
return "", httputil.HTTPError{Code: http.StatusBadRequest, Message: "Missing Code"}
|
||||
}
|
||||
|
||||
session, err := p.redeemCode(req.Host, code)
|
||||
if err != nil {
|
||||
requestLog.Error().Err(err).Msg("error redeeming authentication code")
|
||||
return "", err
|
||||
}
|
||||
|
||||
bytes, err := base64.URLEncoding.DecodeString(req.Form.Get("state"))
|
||||
if err != nil {
|
||||
return "", httputil.HTTPError{Code: http.StatusInternalServerError, Message: "Invalid State"}
|
||||
}
|
||||
s := strings.SplitN(string(bytes), ":", 2)
|
||||
if len(s) != 2 {
|
||||
return "", httputil.HTTPError{Code: http.StatusInternalServerError, Message: "Invalid State"}
|
||||
}
|
||||
nonce := s[0]
|
||||
redirect := s[1]
|
||||
c, err := p.csrfStore.GetCSRF(req)
|
||||
if err != nil {
|
||||
return "", httputil.HTTPError{Code: http.StatusForbidden, Message: "Missing CSRF token"}
|
||||
}
|
||||
p.csrfStore.ClearCSRF(rw, req)
|
||||
if c.Value != nonce {
|
||||
requestLog.Error().Err(err).Msg("csrf token mismatch")
|
||||
return "", httputil.HTTPError{Code: http.StatusForbidden, Message: "csrf failed"}
|
||||
}
|
||||
|
||||
if !validRedirectURI(redirect, p.ProxyRootDomains) {
|
||||
return "", httputil.HTTPError{Code: http.StatusForbidden, Message: "Invalid Redirect URI"}
|
||||
}
|
||||
|
||||
// Set cookie, or deny: The authenticator validates the session email and group
|
||||
// - for p.Validator see validator.go#newValidatorImpl for more info
|
||||
// - for p.provider.ValidateGroup see providers/google.go#ValidateGroup for more info
|
||||
if !p.Validator(session.Email) {
|
||||
requestLog.Error().Err(err).Str("email", session.Email).Msg("invalid email permissions denied")
|
||||
return "", httputil.HTTPError{Code: http.StatusForbidden, Message: "Invalid Account"}
|
||||
}
|
||||
requestLog.Info().Str("email", session.Email).Msg("authentication complete")
|
||||
err = p.sessionStore.SaveSession(rw, req, session)
|
||||
if err != nil {
|
||||
requestLog.Error().Err(err).Msg("internal error")
|
||||
return "", httputil.HTTPError{Code: http.StatusInternalServerError, Message: "Internal Error"}
|
||||
}
|
||||
return redirect, nil
|
||||
}
|
||||
|
||||
// OAuthCallback handles the callback from the provider, and returns an error response if there is an error.
|
||||
// If there is no error it will redirect to the redirect url.
|
||||
func (p *Authenticator) OAuthCallback(rw http.ResponseWriter, req *http.Request) {
|
||||
redirect, err := p.getOAuthCallback(rw, req)
|
||||
switch h := err.(type) {
|
||||
case nil:
|
||||
break
|
||||
case httputil.HTTPError:
|
||||
httputil.ErrorResponse(rw, req, h.Message, h.Code)
|
||||
return
|
||||
default:
|
||||
httputil.ErrorResponse(rw, req, "Internal Error", http.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
http.Redirect(rw, req, redirect, http.StatusFound)
|
||||
}
|
||||
|
||||
// Redeem has a signed access token, and provides the user information associated with the access token.
|
||||
func (p *Authenticator) Redeem(rw http.ResponseWriter, req *http.Request) {
|
||||
// The auth code is redeemed by the sso proxy for an access token, refresh token,
|
||||
// expiration, and email.
|
||||
requestLog := log.WithRequest(req, "authenticate.Redeem")
|
||||
err := req.ParseForm()
|
||||
if err != nil {
|
||||
http.Error(rw, fmt.Sprintf("Bad Request: %s", err.Error()), http.StatusBadRequest)
|
||||
return
|
||||
}
|
||||
|
||||
session, err := sessions.UnmarshalSession(req.Form.Get("code"), p.cipher)
|
||||
if err != nil {
|
||||
requestLog.Error().Err(err).Int("http-status", http.StatusUnauthorized).Msg("invalid auth code")
|
||||
http.Error(rw, fmt.Sprintf("invalid auth code: %s", err.Error()), http.StatusUnauthorized)
|
||||
return
|
||||
}
|
||||
|
||||
if session == nil {
|
||||
requestLog.Error().Err(err).Int("http-status", http.StatusUnauthorized).Msg("invalid session")
|
||||
http.Error(rw, fmt.Sprintf("invalid session: %s", err.Error()), http.StatusUnauthorized)
|
||||
return
|
||||
}
|
||||
|
||||
if session != nil && (session.RefreshPeriodExpired() || session.LifetimePeriodExpired()) {
|
||||
requestLog.Error().Msg("expired session")
|
||||
p.sessionStore.ClearSession(rw, req)
|
||||
http.Error(rw, fmt.Sprintf("expired session"), http.StatusUnauthorized)
|
||||
return
|
||||
}
|
||||
|
||||
response := struct {
|
||||
AccessToken string `json:"access_token"`
|
||||
RefreshToken string `json:"refresh_token"`
|
||||
IDToken string `json:"id_token"`
|
||||
ExpiresIn int64 `json:"expires_in"`
|
||||
Email string `json:"email"`
|
||||
}{
|
||||
AccessToken: session.AccessToken,
|
||||
RefreshToken: session.RefreshToken,
|
||||
IDToken: session.IDToken,
|
||||
ExpiresIn: int64(session.RefreshDeadline.Sub(time.Now()).Seconds()),
|
||||
Email: session.Email,
|
||||
}
|
||||
|
||||
jsonBytes, err := json.Marshal(response)
|
||||
if err != nil {
|
||||
rw.WriteHeader(http.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
rw.Header().Set("GAP-Auth", session.Email)
|
||||
rw.Header().Set("Content-Type", "application/json")
|
||||
rw.Write(jsonBytes)
|
||||
|
||||
}
|
||||
|
||||
// Refresh takes a refresh token and returns a new access token
|
||||
func (p *Authenticator) Refresh(rw http.ResponseWriter, req *http.Request) {
|
||||
err := req.ParseForm()
|
||||
if err != nil {
|
||||
http.Error(rw, fmt.Sprintf("Bad Request: %s", err.Error()), http.StatusBadRequest)
|
||||
return
|
||||
}
|
||||
|
||||
refreshToken := req.Form.Get("refresh_token")
|
||||
if refreshToken == "" {
|
||||
http.Error(rw, "Bad Request: No Refresh Token", http.StatusBadRequest)
|
||||
return
|
||||
}
|
||||
|
||||
accessToken, expiresIn, err := p.provider.RefreshAccessToken(refreshToken)
|
||||
if err != nil {
|
||||
httputil.ErrorResponse(rw, req, err.Error(), httputil.CodeForError(err))
|
||||
return
|
||||
}
|
||||
|
||||
response := struct {
|
||||
AccessToken string `json:"access_token"`
|
||||
ExpiresIn int64 `json:"expires_in"`
|
||||
}{
|
||||
AccessToken: accessToken,
|
||||
ExpiresIn: int64(expiresIn.Seconds()),
|
||||
}
|
||||
|
||||
bytes, err := json.Marshal(response)
|
||||
if err != nil {
|
||||
rw.WriteHeader(http.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
|
||||
rw.WriteHeader(http.StatusCreated)
|
||||
rw.Header().Set("Content-Type", "application/json")
|
||||
rw.Write(bytes)
|
||||
}
|
||||
|
||||
// GetProfile gets a list of groups of which a user is a member.
|
||||
func (p *Authenticator) GetProfile(rw http.ResponseWriter, req *http.Request) {
|
||||
// The sso proxy sends the user's email to this endpoint to get a list of Google groups that
|
||||
// the email is a member of. The proxy will compare these groups to the list of allowed
|
||||
// groups for the upstream service the user is trying to access.
|
||||
|
||||
email := req.FormValue("email")
|
||||
if email == "" {
|
||||
http.Error(rw, "no email address included", http.StatusBadRequest)
|
||||
return
|
||||
}
|
||||
|
||||
// groupsFormValue := req.FormValue("groups")
|
||||
// allowedGroups := []string{}
|
||||
// if groupsFormValue != "" {
|
||||
// allowedGroups = strings.Split(groupsFormValue, ",")
|
||||
// }
|
||||
|
||||
// groups, err := p.provider.ValidateGroupMembership(email, allowedGroups)
|
||||
// if err != nil {
|
||||
// log.Error().Err(err).Msg("authenticate.GetProfile : error retrieving groups")
|
||||
// httputil.ErrorResponse(rw, req, err.Error(), httputil.CodeForError(err))
|
||||
// return
|
||||
// }
|
||||
|
||||
response := struct {
|
||||
Email string `json:"email"`
|
||||
}{
|
||||
Email: email,
|
||||
}
|
||||
|
||||
jsonBytes, err := json.Marshal(response)
|
||||
if err != nil {
|
||||
http.Error(rw, fmt.Sprintf("error marshaling response: %s", err.Error()), http.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
rw.Header().Set("GAP-Auth", email)
|
||||
rw.Header().Set("Content-Type", "application/json")
|
||||
rw.Write(jsonBytes)
|
||||
}
|
||||
|
||||
// ValidateToken validates the X-Access-Token from the header and returns an error response
|
||||
// if it's invalid
|
||||
func (p *Authenticator) ValidateToken(rw http.ResponseWriter, req *http.Request) {
|
||||
accessToken := req.Header.Get("X-Access-Token")
|
||||
idToken := req.Header.Get("X-Id-Token")
|
||||
|
||||
if accessToken == "" {
|
||||
rw.WriteHeader(http.StatusBadRequest)
|
||||
return
|
||||
}
|
||||
if idToken == "" {
|
||||
rw.WriteHeader(http.StatusBadRequest)
|
||||
return
|
||||
}
|
||||
|
||||
ok := p.provider.ValidateSessionState(&sessions.SessionState{
|
||||
AccessToken: accessToken,
|
||||
IDToken: idToken,
|
||||
})
|
||||
|
||||
if !ok {
|
||||
rw.WriteHeader(http.StatusUnauthorized)
|
||||
return
|
||||
}
|
||||
|
||||
rw.WriteHeader(http.StatusOK)
|
||||
return
|
||||
}
|
98
authenticate/middleware.go
Normal file
98
authenticate/middleware.go
Normal file
|
@ -0,0 +1,98 @@
|
|||
package authenticate // import "github.com/pomerium/pomerium/authenticate"
|
||||
|
||||
import (
|
||||
"crypto/hmac"
|
||||
"crypto/sha256"
|
||||
"encoding/base64"
|
||||
"fmt"
|
||||
"net/http"
|
||||
"net/url"
|
||||
"strconv"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/pomerium/pomerium/internal/httputil"
|
||||
)
|
||||
|
||||
// validateRedirectURI checks the redirect uri in the query parameters and ensures that
|
||||
// the url's domain is one in the list of proxy root domains.
|
||||
func validateRedirectURI(f http.HandlerFunc, proxyRootDomains []string) http.HandlerFunc {
|
||||
return func(rw http.ResponseWriter, req *http.Request) {
|
||||
err := req.ParseForm()
|
||||
if err != nil {
|
||||
httputil.ErrorResponse(rw, req, err.Error(), http.StatusBadRequest)
|
||||
return
|
||||
}
|
||||
redirectURI := req.Form.Get("redirect_uri")
|
||||
if !validRedirectURI(redirectURI, proxyRootDomains) {
|
||||
httputil.ErrorResponse(rw, req, "Invalid redirect parameter", http.StatusBadRequest)
|
||||
return
|
||||
}
|
||||
|
||||
f(rw, req)
|
||||
}
|
||||
}
|
||||
|
||||
func validRedirectURI(uri string, rootDomains []string) bool {
|
||||
redirectURL, err := url.Parse(uri)
|
||||
if uri == "" || err != nil || redirectURL.Host == "" {
|
||||
return false
|
||||
}
|
||||
for _, domain := range rootDomains {
|
||||
if strings.HasSuffix(redirectURL.Hostname(), domain) {
|
||||
return true
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
func validateSignature(f http.HandlerFunc, proxyClientSecret string) http.HandlerFunc {
|
||||
return func(rw http.ResponseWriter, req *http.Request) {
|
||||
err := req.ParseForm()
|
||||
if err != nil {
|
||||
httputil.ErrorResponse(rw, req, err.Error(), http.StatusBadRequest)
|
||||
return
|
||||
}
|
||||
redirectURI := req.Form.Get("redirect_uri")
|
||||
sigVal := req.Form.Get("sig")
|
||||
timestamp := req.Form.Get("ts")
|
||||
if !validSignature(redirectURI, sigVal, timestamp, proxyClientSecret) {
|
||||
httputil.ErrorResponse(rw, req, "Invalid redirect parameter", http.StatusBadRequest)
|
||||
return
|
||||
}
|
||||
|
||||
f(rw, req)
|
||||
}
|
||||
}
|
||||
|
||||
func validSignature(redirectURI, sigVal, timestamp, secret string) bool {
|
||||
if redirectURI == "" || sigVal == "" || timestamp == "" || secret == "" {
|
||||
return false
|
||||
}
|
||||
_, err := url.Parse(redirectURI)
|
||||
if err != nil {
|
||||
return false
|
||||
}
|
||||
requestSig, err := base64.URLEncoding.DecodeString(sigVal)
|
||||
if err != nil {
|
||||
return false
|
||||
}
|
||||
i, err := strconv.ParseInt(timestamp, 10, 64)
|
||||
if err != nil {
|
||||
return false
|
||||
}
|
||||
tm := time.Unix(i, 0)
|
||||
ttl := 5 * time.Minute
|
||||
if time.Now().Sub(tm) > ttl {
|
||||
return false
|
||||
}
|
||||
localSig := redirectURLSignature(redirectURI, tm, secret)
|
||||
return hmac.Equal(requestSig, localSig)
|
||||
}
|
||||
|
||||
func redirectURLSignature(rawRedirect string, timestamp time.Time, secret string) []byte {
|
||||
h := hmac.New(sha256.New, []byte(secret))
|
||||
h.Write([]byte(rawRedirect))
|
||||
h.Write([]byte(fmt.Sprint(timestamp.Unix())))
|
||||
return h.Sum(nil)
|
||||
}
|
100
authenticate/providers/google.go
Normal file
100
authenticate/providers/google.go
Normal file
|
@ -0,0 +1,100 @@
|
|||
package providers // import "github.com/pomerium/pomerium/internal/providers"
|
||||
|
||||
import (
|
||||
"context"
|
||||
"net/url"
|
||||
"time"
|
||||
|
||||
oidc "github.com/pomerium/go-oidc"
|
||||
"github.com/pomerium/pomerium/authenticate/circuit"
|
||||
"github.com/pomerium/pomerium/internal/httputil"
|
||||
"github.com/pomerium/pomerium/internal/log"
|
||||
"github.com/pomerium/pomerium/internal/sessions"
|
||||
"github.com/pomerium/pomerium/internal/version"
|
||||
"golang.org/x/oauth2"
|
||||
)
|
||||
|
||||
// GoogleProvider is an implementation of the Provider interface.
|
||||
type GoogleProvider struct {
|
||||
*ProviderData
|
||||
cb *circuit.Breaker
|
||||
// non-standard oidc fields
|
||||
RevokeURL *url.URL
|
||||
}
|
||||
|
||||
// NewGoogleProvider returns a new GoogleProvider and sets the provider url endpoints.
|
||||
func NewGoogleProvider(p *ProviderData) (*GoogleProvider, error) {
|
||||
ctx := context.Background()
|
||||
provider, err := oidc.NewProvider(ctx, "https://accounts.google.com")
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
p.verifier = provider.Verifier(&oidc.Config{ClientID: p.ClientID})
|
||||
p.oauth = &oauth2.Config{
|
||||
ClientID: p.ClientID,
|
||||
ClientSecret: p.ClientSecret,
|
||||
Endpoint: provider.Endpoint(),
|
||||
RedirectURL: p.RedirectURL.String(),
|
||||
Scopes: []string{oidc.ScopeOpenID, "profile", "email"},
|
||||
}
|
||||
|
||||
googleProvider := &GoogleProvider{
|
||||
ProviderData: p,
|
||||
}
|
||||
// google supports a revokation endpoint
|
||||
var claims struct {
|
||||
RevokeURL string `json:"revocation_endpoint"`
|
||||
}
|
||||
|
||||
if err := provider.Claims(&claims); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
googleProvider.RevokeURL, err = url.Parse(claims.RevokeURL)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
googleProvider.cb = circuit.NewBreaker(&circuit.Options{
|
||||
HalfOpenConcurrentRequests: 2,
|
||||
OnStateChange: googleProvider.cbStateChange,
|
||||
OnBackoff: googleProvider.cbBackoff,
|
||||
ShouldTripFunc: func(c circuit.Counts) bool { return c.ConsecutiveFailures >= 3 },
|
||||
ShouldResetFunc: func(c circuit.Counts) bool { return c.ConsecutiveSuccesses >= 6 },
|
||||
BackoffDurationFunc: circuit.ExponentialBackoffDuration(
|
||||
time.Duration(200)*time.Second,
|
||||
time.Duration(500)*time.Millisecond),
|
||||
})
|
||||
|
||||
return googleProvider, nil
|
||||
}
|
||||
|
||||
func (p *GoogleProvider) cbBackoff(duration time.Duration, reset time.Time) {
|
||||
log.Info().Dur("duration", duration).Msg("authenticate/providers/google.cbBackoff")
|
||||
|
||||
}
|
||||
|
||||
func (p *GoogleProvider) cbStateChange(from, to circuit.State) {
|
||||
log.Info().Str("from", from.String()).Str("to", to.String()).Msg("authenticate/providers/google.cbStateChange")
|
||||
}
|
||||
|
||||
// Revoke revokes the access token a given session state.
|
||||
//
|
||||
// https://developers.google.com/identity/protocols/OAuth2WebServer#tokenrevoke
|
||||
// https://github.com/googleapis/google-api-dotnet-client/issues/1285
|
||||
func (p *GoogleProvider) Revoke(s *sessions.SessionState) error {
|
||||
params := url.Values{}
|
||||
params.Add("token", s.AccessToken)
|
||||
err := httputil.Client("POST", p.RevokeURL.String(), version.UserAgent(), params, nil)
|
||||
if err != nil && err != httputil.ErrTokenRevoked {
|
||||
return err
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// GetSignInURL returns the sign in url with typical oauth parameters
|
||||
// Google requires access type offline
|
||||
func (p *GoogleProvider) GetSignInURL(state string) string {
|
||||
return p.oauth.AuthCodeURL(state, oauth2.AccessTypeOffline, oauth2.ApprovalForce)
|
||||
}
|
32
authenticate/providers/oidc.go
Normal file
32
authenticate/providers/oidc.go
Normal file
|
@ -0,0 +1,32 @@
|
|||
package providers // import "github.com/pomerium/pomerium/internal/providers"
|
||||
|
||||
import (
|
||||
"context"
|
||||
|
||||
oidc "github.com/pomerium/go-oidc"
|
||||
"golang.org/x/oauth2"
|
||||
)
|
||||
|
||||
// OIDCProvider provides a standard, OpenID Connect implementation
|
||||
// of an authorization identity provider.
|
||||
type OIDCProvider struct {
|
||||
*ProviderData
|
||||
}
|
||||
|
||||
// NewOIDCProvider creates a new instance of an OpenID Connect provider.
|
||||
func NewOIDCProvider(p *ProviderData) (*OIDCProvider, error) {
|
||||
ctx := context.Background()
|
||||
provider, err := oidc.NewProvider(ctx, "https://accounts.google.com")
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
p.verifier = provider.Verifier(&oidc.Config{ClientID: p.ClientID})
|
||||
p.oauth = &oauth2.Config{
|
||||
ClientID: p.ClientID,
|
||||
ClientSecret: p.ClientSecret,
|
||||
Endpoint: provider.Endpoint(),
|
||||
RedirectURL: p.RedirectURL.String(),
|
||||
Scopes: []string{oidc.ScopeOpenID, "profile", "email"},
|
||||
}
|
||||
return &OIDCProvider{ProviderData: p}, nil
|
||||
}
|
69
authenticate/providers/okta.go
Normal file
69
authenticate/providers/okta.go
Normal file
|
@ -0,0 +1,69 @@
|
|||
package providers // import "github.com/pomerium/pomerium/internal/providers"
|
||||
|
||||
import (
|
||||
"context"
|
||||
"net/url"
|
||||
|
||||
oidc "github.com/pomerium/go-oidc"
|
||||
"github.com/pomerium/pomerium/internal/httputil"
|
||||
"github.com/pomerium/pomerium/internal/sessions"
|
||||
"github.com/pomerium/pomerium/internal/version"
|
||||
"golang.org/x/oauth2"
|
||||
)
|
||||
|
||||
// OktaProvider provides a standard, OpenID Connect implementation
|
||||
// of an authorization identity provider.
|
||||
type OktaProvider struct {
|
||||
*ProviderData
|
||||
|
||||
// non-standard oidc fields
|
||||
RevokeURL *url.URL
|
||||
}
|
||||
|
||||
// NewOktaProvider creates a new instance of an OpenID Connect provider.
|
||||
func NewOktaProvider(p *ProviderData) (*OktaProvider, error) {
|
||||
ctx := context.Background()
|
||||
provider, err := oidc.NewProvider(ctx, "https://dev-108295.oktapreview.com/oauth2/default")
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
p.verifier = provider.Verifier(&oidc.Config{ClientID: p.ClientID})
|
||||
p.oauth = &oauth2.Config{
|
||||
ClientID: p.ClientID,
|
||||
ClientSecret: p.ClientSecret,
|
||||
Endpoint: provider.Endpoint(),
|
||||
RedirectURL: p.RedirectURL.String(),
|
||||
Scopes: []string{oidc.ScopeOpenID, "profile", "email"},
|
||||
}
|
||||
oktaProvider := OktaProvider{ProviderData: p}
|
||||
|
||||
// okta supports a revokation endpoint
|
||||
var claims struct {
|
||||
RevokeURL string `json:"revocation_endpoint"`
|
||||
}
|
||||
|
||||
if err := provider.Claims(&claims); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
oktaProvider.RevokeURL, err = url.Parse(claims.RevokeURL)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return &oktaProvider, nil
|
||||
}
|
||||
|
||||
// Revoke revokes the access token a given session state.
|
||||
// https://developer.okta.com/docs/api/resources/oidc#revoke
|
||||
func (p *OktaProvider) Revoke(s *sessions.SessionState) error {
|
||||
params := url.Values{}
|
||||
params.Add("client_id", p.ClientID)
|
||||
params.Add("client_secret", p.ClientSecret)
|
||||
params.Add("token", s.IDToken)
|
||||
params.Add("token_type_hint", "refresh_token")
|
||||
err := httputil.Client("POST", p.RevokeURL.String(), version.UserAgent(), params, nil)
|
||||
if err != nil && err != httputil.ErrTokenRevoked {
|
||||
return err
|
||||
}
|
||||
return nil
|
||||
}
|
256
authenticate/providers/providers.go
Normal file
256
authenticate/providers/providers.go
Normal file
|
@ -0,0 +1,256 @@
|
|||
package providers // import "github.com/pomerium/pomerium/internal/providers"
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"fmt"
|
||||
"net/url"
|
||||
"time"
|
||||
|
||||
oidc "github.com/pomerium/go-oidc"
|
||||
"github.com/pomerium/pomerium/internal/log"
|
||||
"github.com/pomerium/pomerium/internal/sessions"
|
||||
"golang.org/x/oauth2"
|
||||
)
|
||||
|
||||
const (
|
||||
// GoogleProviderName identifies the Google provider
|
||||
GoogleProviderName = "google"
|
||||
// OIDCProviderName identifes a generic OpenID connect provider
|
||||
OIDCProviderName = "oidc"
|
||||
// OktaProviderName identifes the Okta identity provider
|
||||
OktaProviderName = "okta"
|
||||
)
|
||||
|
||||
// Provider is an interface exposing functions necessary to authenticate with a given provider.
|
||||
type Provider interface {
|
||||
Data() *ProviderData
|
||||
Redeem(string) (*sessions.SessionState, error)
|
||||
ValidateSessionState(*sessions.SessionState) bool
|
||||
GetSignInURL(state string) string
|
||||
RefreshSessionIfNeeded(*sessions.SessionState) (bool, error)
|
||||
Revoke(*sessions.SessionState) error
|
||||
RefreshAccessToken(string) (string, time.Duration, error)
|
||||
// Stop()
|
||||
}
|
||||
|
||||
// New returns a new identity provider based on available name.
|
||||
// Defaults to google.
|
||||
func New(provider string, p *ProviderData) (Provider, error) {
|
||||
switch provider {
|
||||
case OIDCProviderName:
|
||||
p, err := NewOIDCProvider(p)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return p, nil
|
||||
case OktaProviderName:
|
||||
log.Info().Msg("Okta!")
|
||||
p, err := NewOktaProvider(p)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return p, nil
|
||||
default:
|
||||
p, err := NewGoogleProvider(p)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return p, nil
|
||||
}
|
||||
}
|
||||
|
||||
// ProviderData holds the fields associated with providers
|
||||
// necessary to implement the Provider interface.
|
||||
type ProviderData struct {
|
||||
RedirectURL *url.URL
|
||||
ProviderName string
|
||||
ClientID string
|
||||
ClientSecret string
|
||||
ProviderURL *url.URL
|
||||
Scopes []string
|
||||
ApprovalPrompt string
|
||||
SessionLifetimeTTL time.Duration
|
||||
|
||||
verifier *oidc.IDTokenVerifier
|
||||
oauth *oauth2.Config
|
||||
}
|
||||
|
||||
// Data returns a ProviderData.
|
||||
func (p *ProviderData) Data() *ProviderData { return p }
|
||||
|
||||
// GetSignInURL returns the sign in url with typical oauth parameters
|
||||
func (p *ProviderData) GetSignInURL(state string) string {
|
||||
return p.oauth.AuthCodeURL(state)
|
||||
}
|
||||
|
||||
// ValidateSessionState validates a given session's from it's JWT token
|
||||
// The function verifies it's been signed by the provider, preforms
|
||||
// any additional checks depending on the Config, and returns the payload.
|
||||
//
|
||||
// ValidateSessionState does NOT do nonce validation.
|
||||
func (p *ProviderData) ValidateSessionState(s *sessions.SessionState) bool {
|
||||
ctx := context.Background()
|
||||
_, err := p.verifier.Verify(ctx, s.IDToken)
|
||||
if err != nil {
|
||||
log.Error().Err(err).Msg("authenticate/providers.ValidateSessionState : failed to verify session state")
|
||||
return false
|
||||
}
|
||||
return true
|
||||
}
|
||||
|
||||
// Redeem creates a session with an identity provider from a authorization code
|
||||
func (p *ProviderData) Redeem(code string) (*sessions.SessionState, error) {
|
||||
ctx := context.Background()
|
||||
// convert authorization code into a token
|
||||
token, err := p.oauth.Exchange(ctx, code)
|
||||
if err != nil {
|
||||
log.Error().Err(err).Msg("authenticate/providers.Redeem : token exchange failed")
|
||||
return nil, fmt.Errorf("token exchange: %v", err)
|
||||
}
|
||||
s, err := p.createSessionState(ctx, token)
|
||||
if err != nil {
|
||||
log.Error().Err(err).Msg("authenticate/providers.Redeem : unable to update session")
|
||||
return nil, fmt.Errorf("unable to update session: %v", err)
|
||||
}
|
||||
return s, nil
|
||||
}
|
||||
|
||||
// RefreshSessionIfNeeded will refresh the session state if it's deadline is expired
|
||||
func (p *ProviderData) RefreshSessionIfNeeded(s *sessions.SessionState) (bool, error) {
|
||||
if !sessionRefreshRequired(s) {
|
||||
log.Info().Msg("authenticate/providers.RefreshSessionIfNeeded : session refresh not needed")
|
||||
return false, nil
|
||||
}
|
||||
origExpiration := s.RefreshDeadline
|
||||
err := p.redeemRefreshToken(s)
|
||||
if err != nil {
|
||||
log.Error().Err(err).Msg("authenticate/providers.RefreshSession")
|
||||
return false, fmt.Errorf("unable to redeem refresh token: %v", err)
|
||||
}
|
||||
|
||||
log.Info().Msgf("authenticate/providers.Redeem refreshed id token %s (expired on %s)", s, origExpiration)
|
||||
return true, nil
|
||||
}
|
||||
|
||||
func (p *ProviderData) redeemRefreshToken(s *sessions.SessionState) error {
|
||||
log.Info().Msg("authenticate/providers.oidc.redeemRefreshToken 1")
|
||||
ctx := context.Background()
|
||||
t := &oauth2.Token{
|
||||
RefreshToken: s.RefreshToken,
|
||||
Expiry: time.Now().Add(-time.Hour),
|
||||
}
|
||||
log.Info().Msg("authenticate/providers.oidc.redeemRefreshToken 3")
|
||||
|
||||
// returns a TokenSource automatically refreshing it as necessary using the provided context
|
||||
token, err := p.oauth.TokenSource(ctx, t).Token()
|
||||
if err != nil {
|
||||
log.Error().Err(err).Msg("authenticate/providers failed to get token")
|
||||
return fmt.Errorf("failed to get token: %v", err)
|
||||
}
|
||||
log.Info().Msg("authenticate/providers.oidc.redeemRefreshToken 4")
|
||||
|
||||
newSession, err := p.createSessionState(ctx, token)
|
||||
if err != nil {
|
||||
log.Error().Err(err).Msg("authenticate/providers unable to update session")
|
||||
return fmt.Errorf("unable to update session: %v", err)
|
||||
}
|
||||
s.AccessToken = newSession.AccessToken
|
||||
s.IDToken = newSession.IDToken
|
||||
s.RefreshToken = newSession.RefreshToken
|
||||
s.RefreshDeadline = newSession.RefreshDeadline
|
||||
s.Email = newSession.Email
|
||||
|
||||
log.Info().
|
||||
Str("AccessToken", s.AccessToken).
|
||||
Str("IdToken", s.IDToken).
|
||||
Time("RefreshDeadline", s.RefreshDeadline).
|
||||
Str("RefreshToken", s.RefreshToken).
|
||||
Str("Email", s.Email).
|
||||
Msg("authenticate/providers.redeemRefreshToken")
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (p *ProviderData) createSessionState(ctx context.Context, token *oauth2.Token) (*sessions.SessionState, error) {
|
||||
rawIDToken, ok := token.Extra("id_token").(string)
|
||||
if !ok {
|
||||
return nil, fmt.Errorf("token response did not contain an id_token")
|
||||
}
|
||||
log.Info().
|
||||
Bool("ctx", ctx == nil).
|
||||
Bool("Verifier", p.verifier == nil).
|
||||
Str("rawIDToken", rawIDToken).
|
||||
Msg("authenticate/providers.oidc.createSessionState 2")
|
||||
|
||||
// Parse and verify ID Token payload.
|
||||
idToken, err := p.verifier.Verify(ctx, rawIDToken)
|
||||
if err != nil {
|
||||
log.Error().Err(err).Msg("authenticate/providers could not verify id_token")
|
||||
return nil, fmt.Errorf("could not verify id_token: %v", err)
|
||||
}
|
||||
|
||||
// Extract custom claims.
|
||||
var claims struct {
|
||||
Email string `json:"email"`
|
||||
Verified *bool `json:"email_verified"`
|
||||
}
|
||||
// parse claims from the raw, encoded jwt token
|
||||
if err := idToken.Claims(&claims); err != nil {
|
||||
return nil, fmt.Errorf("failed to parse id_token claims: %v", err)
|
||||
}
|
||||
|
||||
if claims.Email == "" {
|
||||
return nil, fmt.Errorf("id_token did not contain an email")
|
||||
}
|
||||
if claims.Verified != nil && !*claims.Verified {
|
||||
return nil, fmt.Errorf("email in id_token (%s) isn't verified", claims.Email)
|
||||
}
|
||||
|
||||
return &sessions.SessionState{
|
||||
AccessToken: token.AccessToken,
|
||||
IDToken: rawIDToken,
|
||||
RefreshToken: token.RefreshToken,
|
||||
RefreshDeadline: token.Expiry,
|
||||
LifetimeDeadline: token.Expiry,
|
||||
Email: claims.Email,
|
||||
}, nil
|
||||
}
|
||||
|
||||
// RefreshAccessToken allows the service to refresh an access token without
|
||||
// prompting the user for permission.
|
||||
func (p *ProviderData) RefreshAccessToken(refreshToken string) (string, time.Duration, error) {
|
||||
if refreshToken == "" {
|
||||
return "", 0, errors.New("missing refresh token")
|
||||
}
|
||||
ctx := context.Background()
|
||||
c := oauth2.Config{
|
||||
ClientID: p.ClientID,
|
||||
ClientSecret: p.ClientSecret,
|
||||
Endpoint: oauth2.Endpoint{TokenURL: p.ProviderURL.String()},
|
||||
}
|
||||
t := oauth2.Token{RefreshToken: refreshToken}
|
||||
ts := c.TokenSource(ctx, &t)
|
||||
log.Info().
|
||||
Str("RefreshToken", refreshToken).
|
||||
Msg("authenticate/providers.RefreshAccessToken")
|
||||
|
||||
newToken, err := ts.Token()
|
||||
if err != nil {
|
||||
log.Error().Err(err).Msg("authenticate/providers.RefreshAccessToken")
|
||||
return "", 0, err
|
||||
}
|
||||
return newToken.AccessToken, newToken.Expiry.Sub(time.Now()), nil
|
||||
}
|
||||
|
||||
// Revoke enables a user to revoke her tokenn. Though many providers such as
|
||||
// google and okta provide revoke endpoints, since it's not officially supported
|
||||
// as part of OpenID Connect, the default implementation throws an error.
|
||||
func (p *ProviderData) Revoke(s *sessions.SessionState) error {
|
||||
return errors.New("revoke not implemented")
|
||||
}
|
||||
|
||||
func sessionRefreshRequired(s *sessions.SessionState) bool {
|
||||
return s == nil || s.RefreshDeadline.After(time.Now()) || s.RefreshToken == ""
|
||||
|
||||
}
|
142
authenticate/providers/singleflight_middleware.go
Normal file
142
authenticate/providers/singleflight_middleware.go
Normal file
|
@ -0,0 +1,142 @@
|
|||
package providers // import "github.com/pomerium/pomerium/internal/providers"
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"fmt"
|
||||
"time"
|
||||
|
||||
"github.com/pomerium/pomerium/internal/sessions"
|
||||
"github.com/pomerium/pomerium/internal/singleflight"
|
||||
)
|
||||
|
||||
var (
|
||||
_ Provider = &SingleFlightProvider{}
|
||||
)
|
||||
|
||||
// ErrUnexpectedReturnType is an error for an unexpected return type
|
||||
var (
|
||||
ErrUnexpectedReturnType = errors.New("received unexpected return type from single flight func call")
|
||||
)
|
||||
|
||||
// SingleFlightProvider middleware provider that multiple requests for the same object
|
||||
// to be processed as a single request. This is often called request collpasing or coalesce.
|
||||
// This middleware leverages the golang singlelflight provider, with modifications for metrics.
|
||||
//
|
||||
// It's common among HTTP reverse proxy cache servers such as nginx, Squid or Varnish - they all call it something else but works similarly.
|
||||
//
|
||||
// * https://www.varnish-cache.org/docs/3.0/tutorial/handling_misbehaving_servers.html
|
||||
// * http://nginx.org/en/docs/http/ngx_http_proxy_module.html#proxy_cache_lock
|
||||
// * http://wiki.squid-cache.org/Features/CollapsedForwarding
|
||||
type SingleFlightProvider struct {
|
||||
provider Provider
|
||||
|
||||
single *singleflight.Group
|
||||
}
|
||||
|
||||
// NewSingleFlightProvider returns a new SingleFlightProvider
|
||||
func NewSingleFlightProvider(provider Provider) *SingleFlightProvider {
|
||||
return &SingleFlightProvider{
|
||||
provider: provider,
|
||||
single: &singleflight.Group{},
|
||||
}
|
||||
}
|
||||
|
||||
func (p *SingleFlightProvider) do(endpoint, key string, fn func() (interface{}, error)) (interface{}, error) {
|
||||
compositeKey := fmt.Sprintf("%s/%s", endpoint, key)
|
||||
resp, _, err := p.single.Do(compositeKey, fn)
|
||||
return resp, err
|
||||
}
|
||||
|
||||
// Data returns the provider data
|
||||
func (p *SingleFlightProvider) Data() *ProviderData {
|
||||
return p.provider.Data()
|
||||
}
|
||||
|
||||
// Redeem wraps the provider's Redeem function.
|
||||
func (p *SingleFlightProvider) Redeem(code string) (*sessions.SessionState, error) {
|
||||
return p.provider.Redeem(code)
|
||||
}
|
||||
|
||||
// ValidateSessionState wraps the provider's ValidateSessionState in a single flight call.
|
||||
func (p *SingleFlightProvider) ValidateSessionState(s *sessions.SessionState) bool {
|
||||
response, err := p.do("ValidateSessionState", s.AccessToken, func() (interface{}, error) {
|
||||
valid := p.provider.ValidateSessionState(s)
|
||||
return valid, nil
|
||||
})
|
||||
if err != nil {
|
||||
return false
|
||||
}
|
||||
|
||||
valid, ok := response.(bool)
|
||||
if !ok {
|
||||
return false
|
||||
}
|
||||
|
||||
return valid
|
||||
}
|
||||
|
||||
// GetSignInURL calls the provider's GetSignInURL function.
|
||||
func (p *SingleFlightProvider) GetSignInURL(finalRedirect string) string {
|
||||
return p.provider.GetSignInURL(finalRedirect)
|
||||
}
|
||||
|
||||
// RefreshSessionIfNeeded wraps the provider's RefreshSessionIfNeeded function in a single flight
|
||||
// call.
|
||||
func (p *SingleFlightProvider) RefreshSessionIfNeeded(s *sessions.SessionState) (bool, error) {
|
||||
response, err := p.do("RefreshSessionIfNeeded", s.RefreshToken, func() (interface{}, error) {
|
||||
return p.provider.RefreshSessionIfNeeded(s)
|
||||
})
|
||||
if err != nil {
|
||||
return false, err
|
||||
}
|
||||
|
||||
r, ok := response.(bool)
|
||||
if !ok {
|
||||
return false, ErrUnexpectedReturnType
|
||||
}
|
||||
|
||||
return r, nil
|
||||
}
|
||||
|
||||
// Revoke wraps the provider's Revoke function in a single flight call.
|
||||
func (p *SingleFlightProvider) Revoke(s *sessions.SessionState) error {
|
||||
_, err := p.do("Revoke", s.AccessToken, func() (interface{}, error) {
|
||||
err := p.provider.Revoke(s)
|
||||
return nil, err
|
||||
})
|
||||
return err
|
||||
}
|
||||
|
||||
// RefreshAccessToken wraps the provider's RefreshAccessToken function in a single flight call.
|
||||
func (p *SingleFlightProvider) RefreshAccessToken(refreshToken string) (string, time.Duration, error) {
|
||||
type Response struct {
|
||||
AccessToken string
|
||||
ExpiresIn time.Duration
|
||||
}
|
||||
response, err := p.do("RefreshAccessToken", refreshToken, func() (interface{}, error) {
|
||||
accessToken, expiresIn, err := p.provider.RefreshAccessToken(refreshToken)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return &Response{
|
||||
AccessToken: accessToken,
|
||||
ExpiresIn: expiresIn,
|
||||
}, nil
|
||||
})
|
||||
if err != nil {
|
||||
return "", 0, err
|
||||
}
|
||||
|
||||
r, ok := response.(*Response)
|
||||
if !ok {
|
||||
return "", 0, ErrUnexpectedReturnType
|
||||
}
|
||||
|
||||
return r.AccessToken, r.ExpiresIn, nil
|
||||
}
|
||||
|
||||
// // Stop calls the provider's stop function
|
||||
// func (p *SingleFlightProvider) Stop() {
|
||||
// p.provider.Stop()
|
||||
// }
|
81
authenticate/providers/test_provider.go
Normal file
81
authenticate/providers/test_provider.go
Normal file
|
@ -0,0 +1,81 @@
|
|||
package providers // import "github.com/pomerium/pomerium/internal/providers"
|
||||
|
||||
import (
|
||||
"net/url"
|
||||
"time"
|
||||
|
||||
"github.com/pomerium/pomerium/internal/sessions"
|
||||
)
|
||||
|
||||
// TestProvider is a test implementation of the Provider interface.
|
||||
type TestProvider struct {
|
||||
*ProviderData
|
||||
|
||||
ValidToken bool
|
||||
ValidGroup bool
|
||||
SignInURL string
|
||||
Refresh bool
|
||||
RefreshFunc func(string) (string, time.Duration, error)
|
||||
RefreshError error
|
||||
Session *sessions.SessionState
|
||||
RedeemError error
|
||||
RevokeError error
|
||||
Groups []string
|
||||
GroupsError error
|
||||
GroupsCall int
|
||||
}
|
||||
|
||||
// NewTestProvider creates a new mock test provider.
|
||||
func NewTestProvider(providerURL *url.URL) *TestProvider {
|
||||
return &TestProvider{
|
||||
ProviderData: &ProviderData{
|
||||
ProviderName: "Test Provider",
|
||||
ProviderURL: &url.URL{
|
||||
Scheme: "http",
|
||||
Host: providerURL.Host,
|
||||
Path: "/authorize",
|
||||
},
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
// ValidateSessionState returns the mock provider's ValidToken field value.
|
||||
func (tp *TestProvider) ValidateSessionState(*sessions.SessionState) bool {
|
||||
return tp.ValidToken
|
||||
}
|
||||
|
||||
// GetSignInURL returns the mock provider's SignInURL field value.
|
||||
func (tp *TestProvider) GetSignInURL(finalRedirect string) string {
|
||||
return tp.SignInURL
|
||||
}
|
||||
|
||||
// RefreshSessionIfNeeded returns the mock provider's Refresh value, or an error.
|
||||
func (tp *TestProvider) RefreshSessionIfNeeded(*sessions.SessionState) (bool, error) {
|
||||
return tp.Refresh, tp.RefreshError
|
||||
}
|
||||
|
||||
// RefreshAccessToken returns the mock provider's refresh access token information
|
||||
func (tp *TestProvider) RefreshAccessToken(s string) (string, time.Duration, error) {
|
||||
return tp.RefreshFunc(s)
|
||||
}
|
||||
|
||||
// Revoke returns nil
|
||||
func (tp *TestProvider) Revoke(*sessions.SessionState) error {
|
||||
return tp.RevokeError
|
||||
}
|
||||
|
||||
// ValidateGroupMembership returns the mock provider's GroupsError if not nil, or the Groups field value.
|
||||
func (tp *TestProvider) ValidateGroupMembership(string, []string) ([]string, error) {
|
||||
return tp.Groups, tp.GroupsError
|
||||
}
|
||||
|
||||
// Redeem returns the mock provider's Session and RedeemError field value.
|
||||
func (tp *TestProvider) Redeem(code string) (*sessions.SessionState, error) {
|
||||
return tp.Session, tp.RedeemError
|
||||
|
||||
}
|
||||
|
||||
// Stop fulfills the Provider interface
|
||||
func (tp *TestProvider) Stop() {
|
||||
return
|
||||
}
|
12
authorize/README.md
Normal file
12
authorize/README.md
Normal file
|
@ -0,0 +1,12 @@
|
|||
# Authorize
|
||||
|
||||
## What's this package do?
|
||||
The authorize packages makes a binary determination of access.
|
||||
|
||||
Authorization is on trust from:
|
||||
- Device state (vulnerability scanned?, MDM?, BYOD? Encrypted?)
|
||||
- User standing (HR status, Groups, etc)
|
||||
- Context (time, location, role)
|
||||
Driven by:
|
||||
- Dynamic "policy as code", fine grained policy
|
||||
- Machine Learning & anomaly detection based on multiple input sources
|
71
cmd/pomerium/main.go
Normal file
71
cmd/pomerium/main.go
Normal file
|
@ -0,0 +1,71 @@
|
|||
package main
|
||||
|
||||
import (
|
||||
"flag"
|
||||
"fmt"
|
||||
"net/http"
|
||||
"os"
|
||||
|
||||
"github.com/rs/zerolog"
|
||||
|
||||
"github.com/pomerium/pomerium/authenticate"
|
||||
"github.com/pomerium/pomerium/internal/https"
|
||||
"github.com/pomerium/pomerium/internal/log"
|
||||
"github.com/pomerium/pomerium/internal/options"
|
||||
"github.com/pomerium/pomerium/internal/version"
|
||||
"github.com/pomerium/pomerium/proxy"
|
||||
)
|
||||
|
||||
var (
|
||||
debugFlag = flag.Bool("debug", false, "run server in debug mode, changes log output to STDOUT and level to info")
|
||||
versionFlag = flag.Bool("version", false, "prints the version")
|
||||
)
|
||||
|
||||
func main() {
|
||||
flag.Parse()
|
||||
if *debugFlag {
|
||||
log.Logger = log.Output(zerolog.ConsoleWriter{Out: os.Stdout})
|
||||
}
|
||||
if *versionFlag {
|
||||
fmt.Printf("%s", version.FullVersion())
|
||||
os.Exit(0)
|
||||
}
|
||||
log.Info().Str("version", version.FullVersion()).Str("user-agent", version.UserAgent()).Msg("cmd/pomerium")
|
||||
authOpts, err := authenticate.OptionsFromEnvConfig()
|
||||
if err != nil {
|
||||
log.Fatal().Err(err).Msg("cmd/pomerium : failed to parse authenticator settings")
|
||||
}
|
||||
emailValidator := func(p *authenticate.Authenticator) error {
|
||||
p.Validator = options.NewEmailValidator(authOpts.EmailDomains)
|
||||
return nil
|
||||
}
|
||||
|
||||
authenticator, err := authenticate.NewAuthenticator(authOpts, emailValidator)
|
||||
if err != nil {
|
||||
log.Fatal().Err(err).Msg("cmd/pomerium : failed to create authenticator")
|
||||
}
|
||||
|
||||
proxyOpts, err := proxy.OptionsFromEnvConfig()
|
||||
if err != nil {
|
||||
log.Fatal().Err(err).Msg("cmd/pomerium : failed to parse proxy settings")
|
||||
}
|
||||
|
||||
validator := func(p *proxy.Proxy) error {
|
||||
p.EmailValidator = options.NewEmailValidator(proxyOpts.EmailDomains)
|
||||
return nil
|
||||
}
|
||||
|
||||
p, err := proxy.NewProxy(proxyOpts, validator)
|
||||
if err != nil {
|
||||
log.Fatal().Err(err).Msg("cmd/pomerium : failed to create proxy")
|
||||
}
|
||||
|
||||
// proxyHandler := log.NewLoggingHandler(p.Handler())
|
||||
authHandler := http.TimeoutHandler(authenticator.Handler(), authOpts.RequestTimeout, "")
|
||||
|
||||
topMux := http.NewServeMux()
|
||||
topMux.Handle(authOpts.Host+"/", authHandler)
|
||||
topMux.Handle("/", p.Handler())
|
||||
log.Fatal().Err(https.ListenAndServeTLS(nil, topMux))
|
||||
|
||||
}
|
BIN
docs/logo.png
Normal file
BIN
docs/logo.png
Normal file
Binary file not shown.
After Width: | Height: | Size: 1.8 KiB |
26
env.example
Normal file
26
env.example
Normal file
|
@ -0,0 +1,26 @@
|
|||
#!/bin/bash
|
||||
|
||||
export HOST="sso-auth.corp.beyondperimeter.com"
|
||||
export REDIRECT_URL="https://sso-auth.corp.beyondperimeter.com/oauth2/callback"
|
||||
export PROXY_ROOT_DOMAIN=beyondperimeter.com
|
||||
export PROXY_CLIENT_ID=WLgwUNIJW6DtsnAM2ck510znU2T3l+WufPg67e50oVM=
|
||||
export PROXY_CLIENT_SECRET=gFB0qsSxxPqCtoNMuF7Q1VupJSNEq0BguxlUfT0PE+Y=
|
||||
|
||||
# Generate 256 bitrandom key to encrypt the cookie `head -c32 /dev/urandom | base64`
|
||||
export AUTH_CODE_SECRET=9wiTZq4qvmS/plYQyvzGKWPlH/UBy0DMYMA2x/zngrM=
|
||||
export COOKIE_SECRET=uPGHo1ujND/k3B9V6yr52Gweq3RRYfFho98jxDG5Br8=
|
||||
export COOKIE_SECURE=true
|
||||
|
||||
# Valid email domains
|
||||
export EMAIL_DOMAIN=*
|
||||
export SSO_EMAIL_DOMAIN=*
|
||||
|
||||
# IdP configuration
|
||||
export IDP_PROVIDER="google"
|
||||
export IDP_PROVIDER_URL="https://sso-auth.corp.beyondperimeter.com"
|
||||
export IDP_CLIENT_ID="xxx.apps.googleusercontent.com"
|
||||
export IDP_CLIENT_SECRET="xxx"
|
||||
export IDP_REDIRECT_URL="https://sso-auth.corp.beyondperimeter.com/oauth2/callback"
|
||||
|
||||
# proxy'd routes
|
||||
export ROUTES='news.corp.beyondperimeter.com':'news.ycombinator.com','github.corp.beyondperimeter.com':'github.com'
|
19
go.mod
Normal file
19
go.mod
Normal file
|
@ -0,0 +1,19 @@
|
|||
module github.com/pomerium/pomerium
|
||||
|
||||
require (
|
||||
github.com/benbjohnson/clock v0.0.0-20161215174838-7dc76406b6d3
|
||||
github.com/davecgh/go-spew v1.1.1 // indirect
|
||||
github.com/miscreant/miscreant-go v0.0.0-20181010193435-325cbd69228b
|
||||
github.com/pmezard/go-difflib v1.0.0 // indirect
|
||||
github.com/pomerium/envconfig v1.3.1-0.20180517194557-dd1402a4d99d
|
||||
github.com/pomerium/go-oidc v2.0.0+incompatible
|
||||
github.com/pquerna/cachecontrol v0.0.0-20180517163645-1555304b9b35 // indirect
|
||||
github.com/rs/zerolog v1.11.0
|
||||
github.com/stretchr/testify v1.2.2 // indirect
|
||||
golang.org/x/crypto v0.0.0-20181203042331-505ab145d0a9
|
||||
golang.org/x/net v0.0.0-20181220203305-927f97764cc3 // indirect
|
||||
golang.org/x/oauth2 v0.0.0-20181203162652-d668ce993890
|
||||
golang.org/x/sync v0.0.0-20181221193216-37e7f081c4d4 // indirect
|
||||
google.golang.org/appengine v1.4.0 // indirect
|
||||
gopkg.in/square/go-jose.v2 v2.2.1 // indirect
|
||||
)
|
34
go.sum
Normal file
34
go.sum
Normal file
|
@ -0,0 +1,34 @@
|
|||
github.com/benbjohnson/clock v0.0.0-20161215174838-7dc76406b6d3 h1:wOysYcIdqv3WnvwqFFzrYCFALPED7qkUGaLXu359GSc=
|
||||
github.com/benbjohnson/clock v0.0.0-20161215174838-7dc76406b6d3/go.mod h1:UMqtWQTnOe4byzwe7Zhwh8f8s+36uszN51sJrSIZlTE=
|
||||
github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c=
|
||||
github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
|
||||
github.com/golang/protobuf v1.2.0 h1:P3YflyNX/ehuJFLhxviNdFxQPkGK5cDcApsge1SqnvM=
|
||||
github.com/golang/protobuf v1.2.0/go.mod h1:6lQm79b+lXiMfvg/cZm0SGofjICqVBUtrP5yJMmIC1U=
|
||||
github.com/miscreant/miscreant-go v0.0.0-20181010193435-325cbd69228b h1:VPhrxAgvd0d0xSZP4P8zzWZntso2NE0m69dmDEXM53I=
|
||||
github.com/miscreant/miscreant-go v0.0.0-20181010193435-325cbd69228b/go.mod h1:Vj6lPE3LxPymcFxg7hm9aDIJWCyhJMnxSNC/y9ZHtN8=
|
||||
github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM=
|
||||
github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4=
|
||||
github.com/pomerium/envconfig v1.3.1-0.20180517194557-dd1402a4d99d h1:iAavGhspmmDa8LCcQLpHV/JBPWKL7EfvMN7YrGHBN3o=
|
||||
github.com/pomerium/envconfig v1.3.1-0.20180517194557-dd1402a4d99d/go.mod h1:1Kz8Ca8PhJDtLYqgvbDZGn6GsJCvrT52SxQ3sPNJkDc=
|
||||
github.com/pomerium/go-oidc v2.0.0+incompatible h1:gVvG/ExWsHQqatV+uceROnGmbVYF44mDNx5nayBhC0o=
|
||||
github.com/pomerium/go-oidc v2.0.0+incompatible/go.mod h1:DRsGVw6MOgxbfq4Y57jKOE8lbEfayxeiY0A8/4vxjBM=
|
||||
github.com/pquerna/cachecontrol v0.0.0-20180517163645-1555304b9b35 h1:J9b7z+QKAmPf4YLrFg6oQUotqHQeUNWwkvo7jZp1GLU=
|
||||
github.com/pquerna/cachecontrol v0.0.0-20180517163645-1555304b9b35/go.mod h1:prYjPmNq4d1NPVmpShWobRqXY3q7Vp+80DqgxxUrUIA=
|
||||
github.com/rs/zerolog v1.11.0 h1:DRuq/S+4k52uJzBQciUcofXx45GrMC6yrEbb/CoK6+M=
|
||||
github.com/rs/zerolog v1.11.0/go.mod h1:YbFCdg8HfsridGWAh22vktObvhZbQsZXe4/zB0OKkWU=
|
||||
github.com/stretchr/testify v1.2.2 h1:bSDNvY7ZPG5RlJ8otE/7V6gMiyenm9RtJ7IUVIAoJ1w=
|
||||
github.com/stretchr/testify v1.2.2/go.mod h1:a8OnRcib4nhh0OaRAV+Yts87kKdq0PP7pXfy6kDkUVs=
|
||||
golang.org/x/crypto v0.0.0-20181203042331-505ab145d0a9 h1:mKdxBk7AujPs8kU4m80U72y/zjbZ3UcXC7dClwKbUI0=
|
||||
golang.org/x/crypto v0.0.0-20181203042331-505ab145d0a9/go.mod h1:6SG95UA2DQfeDnfUPMdvaQW0Q7yPrPDi9nlGo2tz2b4=
|
||||
golang.org/x/net v0.0.0-20180724234803-3673e40ba225/go.mod h1:mL1N/T3taQHkDXs73rZJwtUhF3w3ftmwwsq0BUmARs4=
|
||||
golang.org/x/net v0.0.0-20181220203305-927f97764cc3 h1:eH6Eip3UpmR+yM/qI9Ijluzb1bNv/cAU/n+6l8tRSis=
|
||||
golang.org/x/net v0.0.0-20181220203305-927f97764cc3/go.mod h1:mL1N/T3taQHkDXs73rZJwtUhF3w3ftmwwsq0BUmARs4=
|
||||
golang.org/x/oauth2 v0.0.0-20181203162652-d668ce993890 h1:uESlIz09WIHT2I+pasSXcpLYqYK8wHcdCetU3VuMBJE=
|
||||
golang.org/x/oauth2 v0.0.0-20181203162652-d668ce993890/go.mod h1:N/0e6XlmueqKjAGxoOufVs8QHGRruUQn6yWY3a++T0U=
|
||||
golang.org/x/sync v0.0.0-20181221193216-37e7f081c4d4 h1:YUO/7uOKsKeq9UokNS62b8FYywz3ker1l1vDZRCRefw=
|
||||
golang.org/x/sync v0.0.0-20181221193216-37e7f081c4d4/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM=
|
||||
golang.org/x/text v0.3.0/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ=
|
||||
google.golang.org/appengine v1.4.0 h1:/wp5JvzpHIxhs/dumFmF7BXTf3Z+dd4uXta4kVyO508=
|
||||
google.golang.org/appengine v1.4.0/go.mod h1:xpcJRLb0r/rnEns0DIKYYv+WjYCduHsrkT7/EB5XEv4=
|
||||
gopkg.in/square/go-jose.v2 v2.2.1 h1:uRIz/V7RfMsMgGnCp+YybIdstDIz8wc0H283wHQfwic=
|
||||
gopkg.in/square/go-jose.v2 v2.2.1/go.mod h1:M9dMgbHiYLoDGQrXy7OpJDJWiKiU//h+vD76mk0e1AI=
|
172
internal/aead/aead.go
Normal file
172
internal/aead/aead.go
Normal file
|
@ -0,0 +1,172 @@
|
|||
package aead // import "github.com/pomerium/pomerium/internal/aead"
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"compress/gzip"
|
||||
"crypto/cipher"
|
||||
"encoding/base64"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"io"
|
||||
"sync"
|
||||
|
||||
miscreant "github.com/miscreant/miscreant-go"
|
||||
)
|
||||
|
||||
const miscreantNonceSize = 16
|
||||
|
||||
var algorithmType = "AES-CMAC-SIV"
|
||||
|
||||
// Cipher provides methods to encrypt and decrypt values.
|
||||
type Cipher interface {
|
||||
Encrypt([]byte) ([]byte, error)
|
||||
Decrypt([]byte) ([]byte, error)
|
||||
Marshal(interface{}) (string, error)
|
||||
Unmarshal(string, interface{}) error
|
||||
}
|
||||
|
||||
// MiscreantCipher provides methods to encrypt and decrypt values.
|
||||
// Using an AEAD is a cipher providing authenticated encryption with associated data.
|
||||
// For a description of the methodology, see https://en.wikipedia.org/wiki/Authenticated_encryption
|
||||
type MiscreantCipher struct {
|
||||
aead cipher.AEAD
|
||||
|
||||
mux sync.Mutex
|
||||
}
|
||||
|
||||
// NewMiscreantCipher returns a new AES Cipher for encrypting values
|
||||
func NewMiscreantCipher(secret []byte) (*MiscreantCipher, error) {
|
||||
aead, err := miscreant.NewAEAD(algorithmType, secret, miscreantNonceSize)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return &MiscreantCipher{
|
||||
aead: aead,
|
||||
}, nil
|
||||
}
|
||||
|
||||
// GenerateKey wraps miscreant's GenerateKey function
|
||||
func GenerateKey() []byte {
|
||||
return miscreant.GenerateKey(32)
|
||||
}
|
||||
|
||||
// Encrypt a value using AES-CMAC-SIV
|
||||
func (c *MiscreantCipher) Encrypt(plaintext []byte) (joined []byte, err error) {
|
||||
c.mux.Lock()
|
||||
defer c.mux.Unlock()
|
||||
|
||||
defer func() {
|
||||
if r := recover(); r != nil {
|
||||
err = fmt.Errorf("miscreant error encrypting bytes: %v", r)
|
||||
}
|
||||
}()
|
||||
nonce := miscreant.GenerateNonce(c.aead)
|
||||
ciphertext := c.aead.Seal(nil, nonce, plaintext, nil)
|
||||
|
||||
// we return the nonce as part of the returned value
|
||||
joined = append(ciphertext[:], nonce[:]...)
|
||||
return joined, nil
|
||||
}
|
||||
|
||||
// Decrypt a value using AES-CMAC-SIV
|
||||
func (c *MiscreantCipher) Decrypt(joined []byte) ([]byte, error) {
|
||||
c.mux.Lock()
|
||||
defer c.mux.Unlock()
|
||||
|
||||
if len(joined) <= miscreantNonceSize {
|
||||
return nil, fmt.Errorf("invalid input size: %d", len(joined))
|
||||
}
|
||||
// grab out the nonce
|
||||
pivot := len(joined) - miscreantNonceSize
|
||||
ciphertext := joined[:pivot]
|
||||
nonce := joined[pivot:]
|
||||
|
||||
plaintext, err := c.aead.Open(nil, nonce, ciphertext, nil)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return plaintext, nil
|
||||
}
|
||||
|
||||
// Marshal marshals the interface state as JSON, encrypts the JSON using the cipher
|
||||
// and base64 encodes the binary value as a string and returns the result
|
||||
func (c *MiscreantCipher) Marshal(s interface{}) (string, error) {
|
||||
// encode json value
|
||||
plaintext, err := json.Marshal(s)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
compressed, err := compress(plaintext)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
// encrypt the JSON
|
||||
ciphertext, err := c.Encrypt(compressed)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
|
||||
// base64-encode the result
|
||||
encoded := base64.RawURLEncoding.EncodeToString(ciphertext)
|
||||
return encoded, nil
|
||||
}
|
||||
|
||||
// Unmarshal takes the marshaled string, base64-decodes into a byte slice, decrypts the
|
||||
// byte slice the pased cipher, and unmarshals the resulting JSON into the struct pointer passed
|
||||
func (c *MiscreantCipher) Unmarshal(value string, s interface{}) error {
|
||||
// convert base64 string value to bytes
|
||||
ciphertext, err := base64.RawURLEncoding.DecodeString(value)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// decrypt the bytes
|
||||
compressed, err := c.Decrypt(ciphertext)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
// decompress
|
||||
plaintext, err := decompress(compressed)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
// unmarshal bytes
|
||||
err = json.Unmarshal(plaintext, s)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func compress(data []byte) ([]byte, error) {
|
||||
var buf bytes.Buffer
|
||||
writer, err := gzip.NewWriterLevel(&buf, gzip.DefaultCompression)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("aead/compress: failed to create a gzip writer: %q", err)
|
||||
}
|
||||
if writer == nil {
|
||||
return nil, fmt.Errorf("aead/compress: failed to create a gzip writer")
|
||||
}
|
||||
if _, err = writer.Write(data); err != nil {
|
||||
return nil, fmt.Errorf("aead/compress: failed to compress data with err: %q", err)
|
||||
}
|
||||
if err = writer.Close(); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return buf.Bytes(), nil
|
||||
}
|
||||
|
||||
func decompress(data []byte) ([]byte, error) {
|
||||
reader, err := gzip.NewReader(bytes.NewReader(data))
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("aead/compress: failed to create a gzip reader: %q", err)
|
||||
}
|
||||
defer reader.Close()
|
||||
var buf bytes.Buffer
|
||||
if _, err = io.Copy(&buf, reader); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return buf.Bytes(), nil
|
||||
}
|
173
internal/aead/aead_test.go
Normal file
173
internal/aead/aead_test.go
Normal file
|
@ -0,0 +1,173 @@
|
|||
package aead // import "github.com/pomerium/pomerium/internal/aead"
|
||||
|
||||
import (
|
||||
"crypto/rand"
|
||||
"crypto/sha1"
|
||||
"fmt"
|
||||
"reflect"
|
||||
"sync"
|
||||
"testing"
|
||||
)
|
||||
|
||||
func TestEncodeAndDecodeAccessToken(t *testing.T) {
|
||||
plaintext := []byte("my plain text value")
|
||||
|
||||
key := GenerateKey()
|
||||
c, err := NewMiscreantCipher([]byte(key))
|
||||
if err != nil {
|
||||
t.Fatalf("unexpected err: %v", err)
|
||||
}
|
||||
|
||||
ciphertext, err := c.Encrypt(plaintext)
|
||||
if err != nil {
|
||||
t.Fatalf("unexpected err: %v", err)
|
||||
}
|
||||
|
||||
if reflect.DeepEqual(plaintext, ciphertext) {
|
||||
t.Fatalf("plaintext is not encrypted plaintext:%v ciphertext:%x", plaintext, ciphertext)
|
||||
}
|
||||
|
||||
got, err := c.Decrypt(ciphertext)
|
||||
if err != nil {
|
||||
t.Fatalf("unexpected err decrypting: %v", err)
|
||||
}
|
||||
|
||||
if !reflect.DeepEqual(got, plaintext) {
|
||||
t.Logf(" got: %v", got)
|
||||
t.Logf("want: %v", plaintext)
|
||||
t.Fatal("got unexpected decrypted value")
|
||||
}
|
||||
}
|
||||
|
||||
func TestMarshalAndUnmarshalStruct(t *testing.T) {
|
||||
key := GenerateKey()
|
||||
|
||||
c, err := NewMiscreantCipher([]byte(key))
|
||||
if err != nil {
|
||||
t.Fatalf("unexpected err: %v", err)
|
||||
}
|
||||
|
||||
type TC struct {
|
||||
Field string `json:"field"`
|
||||
}
|
||||
|
||||
tc := &TC{
|
||||
Field: "my plain text value",
|
||||
}
|
||||
|
||||
value1, err := c.Marshal(tc)
|
||||
if err != nil {
|
||||
t.Fatalf("unexpected err: %v", err)
|
||||
}
|
||||
|
||||
value2, err := c.Marshal(tc)
|
||||
if err != nil {
|
||||
t.Fatalf("unexpected err: %v", err)
|
||||
}
|
||||
|
||||
if value1 == value2 {
|
||||
t.Fatalf("expected marshaled values to not be equal %v != %v", value1, value2)
|
||||
}
|
||||
|
||||
got1 := &TC{}
|
||||
err = c.Unmarshal(value1, got1)
|
||||
if err != nil {
|
||||
t.Fatalf("unexpected err unmarshalling struct: %v", err)
|
||||
}
|
||||
|
||||
if !reflect.DeepEqual(got1, tc) {
|
||||
t.Logf("want: %#v", tc)
|
||||
t.Logf(" got: %#v", got1)
|
||||
t.Fatalf("expected structs to be equal")
|
||||
}
|
||||
|
||||
got2 := &TC{}
|
||||
err = c.Unmarshal(value2, got2)
|
||||
if err != nil {
|
||||
t.Fatalf("unexpected err unmarshalling struct: %v", err)
|
||||
}
|
||||
|
||||
if !reflect.DeepEqual(got1, got2) {
|
||||
t.Logf("got2: %#v", got2)
|
||||
t.Logf("got1: %#v", got1)
|
||||
t.Fatalf("expected structs to be equal")
|
||||
}
|
||||
}
|
||||
|
||||
// TestCipherDataRace exercises a simple concurrency test for the MicscreantCipher.
|
||||
// In https://github.com/pomerium/pomerium/pull/75 we investigated why, on random occasion,
|
||||
// unmarshalling session states would fail, triggering users to get kicked out of
|
||||
// authenticated states. We narrowed our investigation to a data race we uncovered
|
||||
// from our misuse of the underlying miscreant library which makes no attempt
|
||||
// at thread-safety.
|
||||
//
|
||||
// In https://github.com/pomerium/pomerium/pull/77 we added this test to exercise the
|
||||
// data race condition and resolved said race by introducing a simple mutex.
|
||||
func TestCipherDataRace(t *testing.T) {
|
||||
miscreantCipher, err := NewMiscreantCipher(GenerateKey())
|
||||
if err != nil {
|
||||
t.Fatalf("unexpected generating cipher err: %v", err)
|
||||
}
|
||||
|
||||
type TC struct {
|
||||
Field string `json:"field"`
|
||||
}
|
||||
|
||||
wg := &sync.WaitGroup{}
|
||||
for i := 0; i < 100; i++ {
|
||||
wg.Add(1)
|
||||
go func(c *MiscreantCipher, wg *sync.WaitGroup) {
|
||||
defer wg.Done()
|
||||
b := make([]byte, 32)
|
||||
_, err := rand.Read(b)
|
||||
if err != nil {
|
||||
t.Fatalf("unexecpted error reading random bytes: %v", err)
|
||||
}
|
||||
|
||||
sha := fmt.Sprintf("%x", sha1.New().Sum(b))
|
||||
tc := &TC{
|
||||
Field: sha,
|
||||
}
|
||||
|
||||
value1, err := c.Marshal(tc)
|
||||
if err != nil {
|
||||
t.Fatalf("unexpected err: %v", err)
|
||||
}
|
||||
|
||||
value2, err := c.Marshal(tc)
|
||||
if err != nil {
|
||||
t.Fatalf("unexpected err: %v", err)
|
||||
}
|
||||
|
||||
if value1 == value2 {
|
||||
t.Fatalf("expected marshaled values to not be equal %v != %v", value1, value2)
|
||||
}
|
||||
|
||||
got1 := &TC{}
|
||||
err = c.Unmarshal(value1, got1)
|
||||
if err != nil {
|
||||
t.Fatalf("unexpected err unmarshalling struct: %v", err)
|
||||
}
|
||||
|
||||
if !reflect.DeepEqual(got1, tc) {
|
||||
t.Logf("want: %#v", tc)
|
||||
t.Logf(" got: %#v", got1)
|
||||
t.Fatalf("expected structs to be equal")
|
||||
}
|
||||
|
||||
got2 := &TC{}
|
||||
err = c.Unmarshal(value2, got2)
|
||||
if err != nil {
|
||||
t.Fatalf("unexpected err unmarshalling struct: %v", err)
|
||||
}
|
||||
|
||||
if !reflect.DeepEqual(got1, got2) {
|
||||
t.Logf("got2: %#v", got2)
|
||||
t.Logf("got1: %#v", got1)
|
||||
t.Fatalf("expected structs to be equal")
|
||||
}
|
||||
|
||||
}(miscreantCipher, wg)
|
||||
}
|
||||
wg.Wait()
|
||||
}
|
34
internal/aead/mock_cipher.go
Normal file
34
internal/aead/mock_cipher.go
Normal file
|
@ -0,0 +1,34 @@
|
|||
package aead // import "github.com/pomerium/pomerium/internal/aead"
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
)
|
||||
|
||||
// MockCipher is a mock of the cipher interface
|
||||
type MockCipher struct {
|
||||
MarshalError error
|
||||
MarshalString string
|
||||
UnmarshalError error
|
||||
UnmarshalBytes []byte
|
||||
}
|
||||
|
||||
// Encrypt returns an empty byte array and nil
|
||||
func (mc *MockCipher) Encrypt([]byte) ([]byte, error) {
|
||||
return []byte{}, nil
|
||||
}
|
||||
|
||||
// Decrypt returns an empty byte array and nil
|
||||
func (mc *MockCipher) Decrypt([]byte) ([]byte, error) {
|
||||
return []byte{}, nil
|
||||
}
|
||||
|
||||
// Marshal returns the marshal string and marsha error
|
||||
func (mc *MockCipher) Marshal(interface{}) (string, error) {
|
||||
return mc.MarshalString, mc.MarshalError
|
||||
}
|
||||
|
||||
// Unmarshal unmarshals the unmarshal bytes to be set in s and returns the unmarshal error
|
||||
func (mc *MockCipher) Unmarshal(b string, s interface{}) error {
|
||||
json.Unmarshal(mc.UnmarshalBytes, s)
|
||||
return mc.UnmarshalError
|
||||
}
|
36
internal/cryptutil/README.md
Normal file
36
internal/cryptutil/README.md
Normal file
|
@ -0,0 +1,36 @@
|
|||
|
||||
## Generating random seeds
|
||||
In order of preference:
|
||||
- `head -c32 /dev/urandom | base64`
|
||||
- `openssl rand -base64 32 | head -c 32 | base64`
|
||||
## Encrypting data
|
||||
|
||||
TL;DR -- Nonce reuse is a problem. AEAD isn't a clear choice right now.
|
||||
|
||||
[Miscreant](https://github.com/miscreant/miscreant.go)
|
||||
+ AES-GCM-SIV seems to have ideal properties
|
||||
+ random nonces
|
||||
- ~30% slower encryption
|
||||
- [not maintained by a BigCo](https://github.com/miscreant/miscreant.go/graphs/contributors)
|
||||
|
||||
[nacl/secretbot](https://godoc.org/golang.org/x/crypto/nacl/secretbox)
|
||||
+ Fast
|
||||
+ XSalsa20 wutg Poly1305 MAC provides encryption and authentication together
|
||||
+ A newer standard and may not be considered acceptable in environments that require high levels of review.
|
||||
-/+ maintained as an [/x/ package](https://godoc.org/golang.org/x/crypto/nacl/secretbox)
|
||||
- doesn't use the underlying cipher.AEAD api.
|
||||
|
||||
|
||||
GCM with random nonces
|
||||
+ Fastest
|
||||
+ Go standard library, supported by google $
|
||||
- Easy to get wrong
|
||||
- IV reuse is a known weakness so keys must be rotated before birthday attack. [NIST SP 800-38D](http://csrc.nist.gov/publications/nistpubs/800-38D/SP-800-38D.pdf) recommends using the same key with random 96-bit nonces (the default nonce length) no more than 2^32 times
|
||||
|
||||
Further reading on tradeoffs:
|
||||
- [Introducing Miscreant](https://tonyarcieri.com/introducing-miscreant-a-multi-language-misuse-resistant-encryption-library)
|
||||
- [agl's post AES-GCM-SIV](https://www.imperialviolet.org/2017/05/14/aesgcmsiv.html)
|
||||
- [x/crypto: add chacha20, xchacha20](https://github.com/golang/go/issues/24485s)
|
||||
- [GCM cannot be used with random nonces](https://github.com/gtank/cryptopasta/issues/14s)
|
||||
- [proposal: x/crypto/chacha20poly1305: add support for XChaCha20](https://github.com/golang/go/issues/23885)
|
||||
- [kubernetes](https://kubernetes.io/docs/tasks/administer-cluster/encrypt-data/#providers)
|
30
internal/cryptutil/hash.go
Normal file
30
internal/cryptutil/hash.go
Normal file
|
@ -0,0 +1,30 @@
|
|||
package cryptutil // import "github.com/pomerium/pomerium/internal/cryptutil"
|
||||
|
||||
import (
|
||||
"crypto/hmac"
|
||||
"crypto/sha512"
|
||||
|
||||
"golang.org/x/crypto/bcrypt"
|
||||
)
|
||||
|
||||
// Hash generates a hash of data using HMAC-SHA-512/256. The tag is intended to
|
||||
// be a natural-language string describing the purpose of the hash, such as
|
||||
// "hash file for lookup key" or "master secret to client secret". It serves
|
||||
// as an HMAC "key" and ensures that different purposes will have different
|
||||
// hash output. This function is NOT suitable for hashing passwords.
|
||||
func Hash(tag string, data []byte) []byte {
|
||||
h := hmac.New(sha512.New512_256, []byte(tag))
|
||||
h.Write(data)
|
||||
return h.Sum(nil)
|
||||
}
|
||||
|
||||
// HashPassword generates a bcrypt hash of the password using work factor 14.
|
||||
func HashPassword(password []byte) ([]byte, error) {
|
||||
return bcrypt.GenerateFromPassword(password, 14)
|
||||
}
|
||||
|
||||
// CheckPasswordHash securely compares a bcrypt hashed password with its possible
|
||||
// plaintext equivalent. Returns nil on success, or an error on failure.
|
||||
func CheckPasswordHash(hash, password []byte) error {
|
||||
return bcrypt.CompareHashAndPassword(hash, password)
|
||||
}
|
80
internal/cryptutil/hash_test.go
Normal file
80
internal/cryptutil/hash_test.go
Normal file
|
@ -0,0 +1,80 @@
|
|||
package cryptutil // import "github.com/pomerium/pomerium/internal/cryptutil"
|
||||
|
||||
import (
|
||||
"crypto/sha256"
|
||||
"crypto/sha512"
|
||||
"encoding/hex"
|
||||
"fmt"
|
||||
"io/ioutil"
|
||||
"os"
|
||||
"testing"
|
||||
)
|
||||
|
||||
func TestPasswordHashing(t *testing.T) {
|
||||
bcryptTests := []struct {
|
||||
plaintext []byte
|
||||
hash []byte
|
||||
}{
|
||||
{
|
||||
plaintext: []byte("password"),
|
||||
hash: []byte("$2a$14$uALAQb/Lwl59oHVbuUa5m.xEFmQBc9ME/IiSgJK/VHtNJJXASCDoS"),
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range bcryptTests {
|
||||
hashed, err := HashPassword(tt.plaintext)
|
||||
if err != nil {
|
||||
t.Error(err)
|
||||
}
|
||||
|
||||
if err = CheckPasswordHash(hashed, tt.plaintext); err != nil {
|
||||
t.Error(err)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Benchmarks SHA256 on 16K of random data.
|
||||
func BenchmarkSHA256(b *testing.B) {
|
||||
data, err := ioutil.ReadFile("testdata/random")
|
||||
if err != nil {
|
||||
b.Fatal(err)
|
||||
}
|
||||
b.SetBytes(int64(len(data)))
|
||||
for i := 0; i < b.N; i++ {
|
||||
_ = sha256.Sum256(data)
|
||||
}
|
||||
}
|
||||
|
||||
// Benchmarks SHA512/256 on 16K of random data.
|
||||
func BenchmarkSHA512_256(b *testing.B) {
|
||||
data, err := ioutil.ReadFile("testdata/random")
|
||||
if err != nil {
|
||||
b.Fatal(err)
|
||||
}
|
||||
b.SetBytes(int64(len(data)))
|
||||
for i := 0; i < b.N; i++ {
|
||||
_ = sha512.Sum512_256(data)
|
||||
}
|
||||
}
|
||||
|
||||
func BenchmarkBcrypt(b *testing.B) {
|
||||
for i := 0; i < b.N; i++ {
|
||||
_, err := HashPassword([]byte("thisisareallybadpassword"))
|
||||
if err != nil {
|
||||
b.Error(err)
|
||||
break
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func ExampleHash() {
|
||||
tag := "hashing file for lookup key"
|
||||
contents, err := ioutil.ReadFile("testdata/random")
|
||||
if err != nil {
|
||||
fmt.Printf("could not read file: %v\n", err)
|
||||
os.Exit(1)
|
||||
}
|
||||
digest := Hash(tag, contents)
|
||||
fmt.Println(hex.EncodeToString(digest))
|
||||
// Output: 9f4c795d8ae5c207f19184ccebee6a606c1fdfe509c793614066d613580f03e1
|
||||
}
|
BIN
internal/cryptutil/testdata/random
vendored
Normal file
BIN
internal/cryptutil/testdata/random
vendored
Normal file
Binary file not shown.
32
internal/fileutil/fileutil.go
Normal file
32
internal/fileutil/fileutil.go
Normal file
|
@ -0,0 +1,32 @@
|
|||
package fileutil // import "github.com/pomerium/pomerium/internal/fileutil"
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"os"
|
||||
)
|
||||
|
||||
// IsReadableFile reports whether the file exists and is readable.
|
||||
// If the error is non-nil, it means there might be a file or directory
|
||||
// with that name but we cannot read it.
|
||||
//
|
||||
// Adapted from the upspin.io source code.
|
||||
func IsReadableFile(path string) (bool, error) {
|
||||
// Is it stattable and is it a plain file?
|
||||
info, err := os.Stat(path)
|
||||
if err != nil {
|
||||
if os.IsNotExist(err) {
|
||||
return false, nil // Item does not exist.
|
||||
}
|
||||
return false, err // Item is problematic.
|
||||
}
|
||||
if info.IsDir() {
|
||||
return false, errors.New("is directory")
|
||||
}
|
||||
// Is it readable?
|
||||
fd, err := os.Open(path)
|
||||
if err != nil {
|
||||
return false, errors.New("permission denied")
|
||||
}
|
||||
fd.Close()
|
||||
return true, nil // Item exists and is readable.
|
||||
}
|
29
internal/fileutil/fileutil_test.go
Normal file
29
internal/fileutil/fileutil_test.go
Normal file
|
@ -0,0 +1,29 @@
|
|||
package fileutil // import "github.com/pomerium/pomerium/internal/fileutil"
|
||||
|
||||
import "testing"
|
||||
|
||||
func TestIsReadableFile(t *testing.T) {
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
args string
|
||||
want bool
|
||||
wantErr bool
|
||||
}{
|
||||
{"good file", "fileutil.go", true, false},
|
||||
{"file doesn't exist", "file-no-exist/nope", false, false},
|
||||
{"can't read dir", "./", false, true},
|
||||
}
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
got, err := IsReadableFile(tt.args)
|
||||
if (err != nil) != tt.wantErr {
|
||||
t.Errorf("IsReadableFile() error = %+v, wantErr %v", err, tt.wantErr)
|
||||
return
|
||||
}
|
||||
if got != tt.want {
|
||||
t.Errorf("IsReadableFile() = %v, want %v", got, tt.want)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
125
internal/https/https.go
Normal file
125
internal/https/https.go
Normal file
|
@ -0,0 +1,125 @@
|
|||
package https // import "github.com/pomerium/pomerium/internal/https"
|
||||
|
||||
import (
|
||||
"crypto/tls"
|
||||
"fmt"
|
||||
"net"
|
||||
"net/http"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"time"
|
||||
|
||||
"github.com/pomerium/pomerium/internal/fileutil"
|
||||
)
|
||||
|
||||
// Options contains the configurations settings for a TLS http server
|
||||
type Options struct {
|
||||
// Addr specifies the host and port on which the server should serve
|
||||
// HTTPS requests. If empty, ":https" is used.
|
||||
Addr string
|
||||
|
||||
// CertFile and KeyFile specifies the TLS certificates to use.
|
||||
CertFile string
|
||||
KeyFile string
|
||||
}
|
||||
|
||||
var defaultOptions = &Options{
|
||||
Addr: ":https",
|
||||
CertFile: filepath.Join(findKeyDir(), "cert.pem"),
|
||||
KeyFile: filepath.Join(findKeyDir(), "privkey.pem"),
|
||||
}
|
||||
|
||||
func findKeyDir() string {
|
||||
p, err := os.Getwd()
|
||||
if err != nil {
|
||||
return "."
|
||||
}
|
||||
return p
|
||||
}
|
||||
|
||||
func (opt *Options) applyDefaults() {
|
||||
if opt.Addr == "" {
|
||||
opt.Addr = defaultOptions.Addr
|
||||
}
|
||||
if opt.CertFile == "" {
|
||||
opt.CertFile = defaultOptions.CertFile
|
||||
}
|
||||
if opt.KeyFile == "" {
|
||||
opt.KeyFile = defaultOptions.KeyFile
|
||||
}
|
||||
}
|
||||
|
||||
// ListenAndServeTLS serves the provided handlers by HTTPS
|
||||
// using the provided options.
|
||||
func ListenAndServeTLS(opt *Options, handler http.Handler) error {
|
||||
if opt == nil {
|
||||
opt = defaultOptions
|
||||
} else {
|
||||
opt.applyDefaults()
|
||||
}
|
||||
|
||||
config, err := newDefaultTLSConfig(opt.CertFile, opt.KeyFile)
|
||||
if err != nil {
|
||||
return fmt.Errorf("https: setting up TLS config: %v", err)
|
||||
}
|
||||
|
||||
ln, err := net.Listen("tcp", opt.Addr)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
ln = tls.NewListener(ln, config)
|
||||
|
||||
// Set up the main server.
|
||||
server := &http.Server{
|
||||
ReadHeaderTimeout: 5 * time.Second,
|
||||
ReadTimeout: 15 * time.Second,
|
||||
// WriteTimeout is set to 0 because it also pertains to
|
||||
// streaming replies, e.g., the DirServer.Watch interface.
|
||||
WriteTimeout: 0,
|
||||
IdleTimeout: 60 * time.Second,
|
||||
TLSConfig: config,
|
||||
Handler: handler,
|
||||
}
|
||||
|
||||
return server.Serve(ln)
|
||||
}
|
||||
|
||||
// newDefaultTLSConfig creates a new TLS config based on the certificate files given.
|
||||
func newDefaultTLSConfig(certFile string, certKeyFile string) (*tls.Config, error) {
|
||||
certReadable, err := fileutil.IsReadableFile(certFile)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("TLS certificate in %q: %q", certFile, err)
|
||||
}
|
||||
if !certReadable {
|
||||
return nil, fmt.Errorf("certificate file %q not readable", certFile)
|
||||
}
|
||||
keyReadable, err := fileutil.IsReadableFile(certKeyFile)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("TLS key in %q: %v", certKeyFile, err)
|
||||
}
|
||||
if !keyReadable {
|
||||
return nil, fmt.Errorf("certificate key file %q not readable", certKeyFile)
|
||||
}
|
||||
|
||||
cert, err := tls.LoadX509KeyPair(certFile, certKeyFile)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
tlsConfig := &tls.Config{
|
||||
CipherSuites: []uint16{
|
||||
tls.TLS_ECDHE_ECDSA_WITH_AES_256_GCM_SHA384,
|
||||
tls.TLS_ECDHE_ECDSA_WITH_AES_128_GCM_SHA256,
|
||||
tls.TLS_ECDHE_ECDSA_WITH_CHACHA20_POLY1305,
|
||||
tls.TLS_ECDHE_RSA_WITH_AES_256_GCM_SHA384,
|
||||
tls.TLS_ECDHE_RSA_WITH_AES_128_GCM_SHA256,
|
||||
tls.TLS_ECDHE_RSA_WITH_CHACHA20_POLY1305,
|
||||
},
|
||||
MinVersion: tls.VersionTLS12,
|
||||
PreferServerCipherSuites: true,
|
||||
Certificates: []tls.Certificate{cert},
|
||||
}
|
||||
tlsConfig.BuildNameToCertificate()
|
||||
return tlsConfig, nil
|
||||
}
|
87
internal/httputil/client.go
Normal file
87
internal/httputil/client.go
Normal file
|
@ -0,0 +1,87 @@
|
|||
package httputil // import "github.com/pomerium/pomerium/internal/httputil"
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"fmt"
|
||||
"io"
|
||||
"io/ioutil"
|
||||
"net"
|
||||
"net/http"
|
||||
"net/url"
|
||||
"time"
|
||||
)
|
||||
|
||||
// ErrTokenRevoked signifies a token revokation or expiration error
|
||||
var ErrTokenRevoked = errors.New("Token expired or revoked")
|
||||
|
||||
var httpClient = &http.Client{
|
||||
Timeout: time.Second * 5,
|
||||
Transport: &http.Transport{
|
||||
Dial: (&net.Dialer{
|
||||
Timeout: 2 * time.Second,
|
||||
}).Dial,
|
||||
TLSHandshakeTimeout: 2 * time.Second,
|
||||
},
|
||||
}
|
||||
|
||||
// Client provides a simple helper interface to make HTTP requests
|
||||
func Client(method, endpoint, userAgent string, params url.Values, response interface{}) error {
|
||||
var body io.Reader
|
||||
switch method {
|
||||
case "POST":
|
||||
body = bytes.NewBufferString(params.Encode())
|
||||
case "GET":
|
||||
// error checking skipped because we are just parsing in
|
||||
// order to make a copy of an existing URL
|
||||
u, _ := url.Parse(endpoint)
|
||||
u.RawQuery = params.Encode()
|
||||
endpoint = u.String()
|
||||
default:
|
||||
return fmt.Errorf(http.StatusText(http.StatusBadRequest))
|
||||
}
|
||||
req, err := http.NewRequest(method, endpoint, body)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
req.Header.Set("Content-Type", "application/x-www-form-urlencoded")
|
||||
req.Header.Set("User-Agent", userAgent)
|
||||
|
||||
resp, err := httpClient.Do(req)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
var respBody []byte
|
||||
respBody, err = ioutil.ReadAll(resp.Body)
|
||||
resp.Body.Close()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if resp.StatusCode != http.StatusOK {
|
||||
switch resp.StatusCode {
|
||||
case http.StatusBadRequest:
|
||||
var response struct {
|
||||
Error string `json:"error"`
|
||||
ErrorDescription string `json:"error_description"`
|
||||
}
|
||||
e := json.Unmarshal(respBody, &response)
|
||||
if e == nil && response.ErrorDescription == "Token expired or revoked" {
|
||||
return ErrTokenRevoked
|
||||
}
|
||||
return fmt.Errorf(http.StatusText(http.StatusBadRequest))
|
||||
default:
|
||||
return fmt.Errorf(http.StatusText(resp.StatusCode))
|
||||
}
|
||||
}
|
||||
|
||||
if response != nil {
|
||||
err := json.Unmarshal(respBody, &response)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
82
internal/httputil/errors.go
Normal file
82
internal/httputil/errors.go
Normal file
|
@ -0,0 +1,82 @@
|
|||
package httputil // import "github.com/pomerium/pomerium/internal/httputil"
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"fmt"
|
||||
"io"
|
||||
"net/http"
|
||||
|
||||
"github.com/pomerium/pomerium/internal/log"
|
||||
"github.com/pomerium/pomerium/internal/templates"
|
||||
"github.com/pomerium/pomerium/internal/version"
|
||||
)
|
||||
|
||||
var (
|
||||
// ErrUserNotAuthorized is an error for unauthorized users.
|
||||
ErrUserNotAuthorized = errors.New("user not authorized")
|
||||
)
|
||||
|
||||
// HTTPError stores the status code and a message for a given HTTP error.
|
||||
type HTTPError struct {
|
||||
Code int
|
||||
Message string
|
||||
}
|
||||
|
||||
// Error fulfills the error interface, returning a string representation of the error.
|
||||
func (h HTTPError) Error() string {
|
||||
return fmt.Sprintf("%d %s: %s", h.Code, http.StatusText(h.Code), h.Message)
|
||||
}
|
||||
|
||||
// CodeForError maps an error type and returns a corresponding http.Status
|
||||
func CodeForError(err error) int {
|
||||
switch err {
|
||||
case ErrTokenRevoked:
|
||||
return http.StatusUnauthorized
|
||||
}
|
||||
return http.StatusInternalServerError
|
||||
}
|
||||
|
||||
// ErrorResponse renders an error page for errors given a message and a status code.
|
||||
func ErrorResponse(rw http.ResponseWriter, req *http.Request, message string, code int) {
|
||||
if req.Header.Get("Accept") == "application/json" {
|
||||
var response struct {
|
||||
Error string `json:"error"`
|
||||
}
|
||||
response.Error = message
|
||||
writeJSONResponse(rw, code, response)
|
||||
} else {
|
||||
title := http.StatusText(code)
|
||||
|
||||
log.Error().
|
||||
Int("http-status", code).
|
||||
Str("page-title", title).
|
||||
Str("page-message", message).
|
||||
Msg("authenticate/errors.ErrorResponse")
|
||||
|
||||
rw.WriteHeader(code)
|
||||
t := struct {
|
||||
Code int
|
||||
Title string
|
||||
Message string
|
||||
Version string
|
||||
}{
|
||||
Code: code,
|
||||
Title: title,
|
||||
Message: message,
|
||||
Version: version.FullVersion(),
|
||||
}
|
||||
templates.New().ExecuteTemplate(rw, "error.html", t)
|
||||
}
|
||||
}
|
||||
|
||||
// writeJSONResponse is a helper that sets the application/json header and writes a response.
|
||||
func writeJSONResponse(rw http.ResponseWriter, code int, response interface{}) {
|
||||
rw.Header().Set("Content-Type", "application/json")
|
||||
rw.WriteHeader(code)
|
||||
|
||||
err := json.NewEncoder(rw).Encode(response)
|
||||
if err != nil {
|
||||
io.WriteString(rw, err.Error())
|
||||
}
|
||||
}
|
129
internal/log/log.go
Normal file
129
internal/log/log.go
Normal file
|
@ -0,0 +1,129 @@
|
|||
// Package log provides a global logger for zerolog.
|
||||
package log // import "github.com/pomerium/pomerium/internal/log"
|
||||
|
||||
import (
|
||||
"context"
|
||||
"io"
|
||||
"net/http"
|
||||
"os"
|
||||
|
||||
"github.com/rs/zerolog"
|
||||
)
|
||||
|
||||
// Logger is the global logger.
|
||||
var Logger = zerolog.New(os.Stderr).With().Timestamp().Logger()
|
||||
|
||||
// Output duplicates the global logger and sets w as its output.
|
||||
func Output(w io.Writer) zerolog.Logger {
|
||||
return Logger.Output(w)
|
||||
}
|
||||
|
||||
// With creates a child logger with the field added to its context.
|
||||
func With() zerolog.Context {
|
||||
return Logger.With()
|
||||
}
|
||||
|
||||
// WithRequest creates a child logger with the remote user added to its context.
|
||||
func WithRequest(req *http.Request, function string) zerolog.Logger {
|
||||
remoteUser := getRemoteAddr(req)
|
||||
return Logger.With().
|
||||
Str("function", function).
|
||||
Str("req-remote-user", remoteUser).
|
||||
Str("req-http-method", req.Method).
|
||||
Str("req-host", req.Host).
|
||||
Str("req-url", req.URL.String()).
|
||||
Str("req-user-agent", req.Header.Get("User-Agent")).
|
||||
Logger()
|
||||
}
|
||||
|
||||
// Level creates a child logger with the minimum accepted level set to level.
|
||||
func Level(level zerolog.Level) zerolog.Logger {
|
||||
return Logger.Level(level)
|
||||
}
|
||||
|
||||
// Sample returns a logger with the s sampler.
|
||||
func Sample(s zerolog.Sampler) zerolog.Logger {
|
||||
return Logger.Sample(s)
|
||||
}
|
||||
|
||||
// Hook returns a logger with the h Hook.
|
||||
func Hook(h zerolog.Hook) zerolog.Logger {
|
||||
return Logger.Hook(h)
|
||||
}
|
||||
|
||||
// Debug starts a new message with debug level.
|
||||
//
|
||||
// You must call Msg on the returned event in order to send the event.
|
||||
func Debug() *zerolog.Event {
|
||||
return Logger.Debug()
|
||||
}
|
||||
|
||||
// Info starts a new message with info level.
|
||||
//
|
||||
// You must call Msg on the returned event in order to send the event.
|
||||
func Info() *zerolog.Event {
|
||||
return Logger.Info()
|
||||
}
|
||||
|
||||
// Warn starts a new message with warn level.
|
||||
//
|
||||
// You must call Msg on the returned event in order to send the event.
|
||||
func Warn() *zerolog.Event {
|
||||
return Logger.Warn()
|
||||
}
|
||||
|
||||
// Error starts a new message with error level.
|
||||
//
|
||||
// You must call Msg on the returned event in order to send the event.
|
||||
func Error() *zerolog.Event {
|
||||
return Logger.Error()
|
||||
}
|
||||
|
||||
// Fatal starts a new message with fatal level. The os.Exit(1) function
|
||||
// is called by the Msg method.
|
||||
//
|
||||
// You must call Msg on the returned event in order to send the event.
|
||||
func Fatal() *zerolog.Event {
|
||||
return Logger.Fatal()
|
||||
}
|
||||
|
||||
// Panic starts a new message with panic level. The message is also sent
|
||||
// to the panic function.
|
||||
//
|
||||
// You must call Msg on the returned event in order to send the event.
|
||||
func Panic() *zerolog.Event {
|
||||
return Logger.Panic()
|
||||
}
|
||||
|
||||
// WithLevel starts a new message with level.
|
||||
//
|
||||
// You must call Msg on the returned event in order to send the event.
|
||||
func WithLevel(level zerolog.Level) *zerolog.Event {
|
||||
return Logger.WithLevel(level)
|
||||
}
|
||||
|
||||
// Log starts a new message with no level. Setting zerolog.GlobalLevel to
|
||||
// zerolog.Disabled will still disable events produced by this method.
|
||||
//
|
||||
// You must call Msg on the returned event in order to send the event.
|
||||
func Log() *zerolog.Event {
|
||||
return Logger.Log()
|
||||
}
|
||||
|
||||
// Print sends a log event using debug level and no extra field.
|
||||
// Arguments are handled in the manner of fmt.Print.
|
||||
func Print(v ...interface{}) {
|
||||
Logger.Print(v...)
|
||||
}
|
||||
|
||||
// Printf sends a log event using debug level and no extra field.
|
||||
// Arguments are handled in the manner of fmt.Printf.
|
||||
func Printf(format string, v ...interface{}) {
|
||||
Logger.Printf(format, v...)
|
||||
}
|
||||
|
||||
// Ctx returns the Logger associated with the ctx. If no logger
|
||||
// is associated, a disabled logger is returned.
|
||||
func Ctx(ctx context.Context) *zerolog.Logger {
|
||||
return zerolog.Ctx(ctx)
|
||||
}
|
145
internal/log/request_logger.go
Normal file
145
internal/log/request_logger.go
Normal file
|
@ -0,0 +1,145 @@
|
|||
package log // import "github.com/pomerium/pomerium/internal/log"
|
||||
|
||||
import (
|
||||
"net/http"
|
||||
"net/url"
|
||||
"strings"
|
||||
"time"
|
||||
)
|
||||
|
||||
// Used to stash the authenticated user in the response for access when logging requests.
|
||||
const loggingUserHeader = "SSO-Authenticated-User"
|
||||
const gapMetaDataHeader = "GAP-Auth"
|
||||
|
||||
// responseLogger is wrapper of http.ResponseWriter that keeps track of its HTTP status
|
||||
// code and body size
|
||||
type responseLogger struct {
|
||||
w http.ResponseWriter
|
||||
status int
|
||||
size int
|
||||
proxyHost string
|
||||
authInfo string
|
||||
}
|
||||
|
||||
func (l *responseLogger) Header() http.Header {
|
||||
return l.w.Header()
|
||||
}
|
||||
|
||||
func (l *responseLogger) extractUser() {
|
||||
authInfo := l.w.Header().Get(loggingUserHeader)
|
||||
if authInfo != "" {
|
||||
l.authInfo = authInfo
|
||||
l.w.Header().Del(loggingUserHeader)
|
||||
}
|
||||
}
|
||||
|
||||
func (l *responseLogger) ExtractGAPMetadata() {
|
||||
authInfo := l.w.Header().Get(gapMetaDataHeader)
|
||||
if authInfo != "" {
|
||||
l.authInfo = authInfo
|
||||
|
||||
l.w.Header().Del(gapMetaDataHeader)
|
||||
}
|
||||
}
|
||||
|
||||
func (l *responseLogger) Write(b []byte) (int, error) {
|
||||
if l.status == 0 {
|
||||
// The status will be StatusOK if WriteHeader has not been called yet
|
||||
l.status = http.StatusOK
|
||||
}
|
||||
l.extractUser()
|
||||
l.ExtractGAPMetadata()
|
||||
|
||||
size, err := l.w.Write(b)
|
||||
l.size += size
|
||||
return size, err
|
||||
}
|
||||
|
||||
func (l *responseLogger) WriteHeader(s int) {
|
||||
l.extractUser()
|
||||
l.ExtractGAPMetadata()
|
||||
|
||||
l.w.WriteHeader(s)
|
||||
l.status = s
|
||||
}
|
||||
|
||||
func (l *responseLogger) Status() int {
|
||||
return l.status
|
||||
}
|
||||
|
||||
func (l *responseLogger) Size() int {
|
||||
return l.size
|
||||
}
|
||||
|
||||
func (l *responseLogger) Flush() {
|
||||
f := l.w.(http.Flusher)
|
||||
f.Flush()
|
||||
}
|
||||
|
||||
// loggingHandler is the http.Handler implementation for LoggingHandlerTo and its friends
|
||||
type loggingHandler struct {
|
||||
handler http.Handler
|
||||
}
|
||||
|
||||
// NewLoggingHandler returns a new loggingHandler that wraps a handler, and writer.
|
||||
func NewLoggingHandler(h http.Handler) http.Handler {
|
||||
return loggingHandler{
|
||||
handler: h,
|
||||
}
|
||||
}
|
||||
|
||||
func (h loggingHandler) ServeHTTP(w http.ResponseWriter, req *http.Request) {
|
||||
t := time.Now()
|
||||
url := *req.URL
|
||||
logger := &responseLogger{w: w, proxyHost: getProxyHost(req)}
|
||||
h.handler.ServeHTTP(logger, req)
|
||||
requestDuration := time.Since(t)
|
||||
|
||||
logRequest(logger.proxyHost, logger.authInfo, req, url, requestDuration, logger.Status())
|
||||
}
|
||||
|
||||
// logRequest logs information about a request
|
||||
func logRequest(proxyHost, username string, req *http.Request, url url.URL, requestDuration time.Duration, status int) {
|
||||
uri := req.Host + url.RequestURI()
|
||||
Info().
|
||||
Int("http-status", status).
|
||||
Str("request-method", req.Method).
|
||||
Str("request-uri", uri).
|
||||
Str("proxy-host", proxyHost).
|
||||
Str("user-agent", req.Header.Get("User-Agent")).
|
||||
Str("remote-address", getRemoteAddr(req)).
|
||||
Dur("duration", requestDuration).
|
||||
Str("user", username).
|
||||
Msg("request")
|
||||
|
||||
}
|
||||
|
||||
// getRemoteAddr returns the client IP address from a request. If present, the
|
||||
// X-Forwarded-For header is assumed to be set by a load balancer, and its
|
||||
// rightmost entry (the client IP that connected to the LB) is returned.
|
||||
func getRemoteAddr(req *http.Request) string {
|
||||
addr := req.RemoteAddr
|
||||
forwardedHeader := req.Header.Get("X-Forwarded-For")
|
||||
if forwardedHeader != "" {
|
||||
forwardedList := strings.Split(forwardedHeader, ",")
|
||||
forwardedAddr := strings.TrimSpace(forwardedList[len(forwardedList)-1])
|
||||
if forwardedAddr != "" {
|
||||
addr = forwardedAddr
|
||||
}
|
||||
}
|
||||
return addr
|
||||
}
|
||||
|
||||
// getProxyHost attempts to get the proxy host from the redirect_uri parameter
|
||||
func getProxyHost(req *http.Request) string {
|
||||
err := req.ParseForm()
|
||||
if err != nil {
|
||||
return ""
|
||||
}
|
||||
redirect := req.Form.Get("redirect_uri")
|
||||
redirectURL, err := url.Parse(redirect)
|
||||
if err != nil {
|
||||
return ""
|
||||
}
|
||||
return redirectURL.Host
|
||||
}
|
72
internal/log/request_logger_test.go
Normal file
72
internal/log/request_logger_test.go
Normal file
|
@ -0,0 +1,72 @@
|
|||
package log // import "github.com/pomerium/pomerium/internal/log"
|
||||
|
||||
import (
|
||||
"net/http/httptest"
|
||||
"testing"
|
||||
)
|
||||
|
||||
func TestGetRemoteAddr(t *testing.T) {
|
||||
testCases := []struct {
|
||||
name string
|
||||
remoteAddr string
|
||||
forwardedHeader string
|
||||
expectedAddr string
|
||||
}{
|
||||
{
|
||||
name: "RemoteAddr used when no X-Forwarded-For header is given",
|
||||
remoteAddr: "1.1.1.1",
|
||||
expectedAddr: "1.1.1.1",
|
||||
},
|
||||
{
|
||||
name: "RemoteAddr used when no X-Forwarded-For header is only whitespace",
|
||||
remoteAddr: "1.1.1.1",
|
||||
forwardedHeader: " ",
|
||||
expectedAddr: "1.1.1.1",
|
||||
},
|
||||
{
|
||||
name: "RemoteAddr used when no X-Forwarded-For header is only comma-separated whitespace",
|
||||
remoteAddr: "1.1.1.1",
|
||||
forwardedHeader: " , , ",
|
||||
expectedAddr: "1.1.1.1",
|
||||
},
|
||||
{
|
||||
name: "X-Forwarded-For header is preferred to RemoteAddr",
|
||||
remoteAddr: "1.1.1.1",
|
||||
forwardedHeader: "9.9.9.9",
|
||||
expectedAddr: "9.9.9.9",
|
||||
},
|
||||
{
|
||||
name: "rightmost entry in X-Forwarded-For header is used",
|
||||
remoteAddr: "1.1.1.1",
|
||||
forwardedHeader: "2.2.2.2, 3.3.3.3, 4.4.4.4.4, 5.5.5.5",
|
||||
expectedAddr: "5.5.5.5",
|
||||
},
|
||||
{
|
||||
name: "RemoteAddr is used if rightmost entry in X-Forwarded-For header is empty",
|
||||
remoteAddr: "1.1.1.1",
|
||||
forwardedHeader: "2.2.2.2, 3.3.3.3, ",
|
||||
expectedAddr: "1.1.1.1",
|
||||
},
|
||||
{
|
||||
name: "X-Forwaded-For header entries are stripped",
|
||||
remoteAddr: "1.1.1.1",
|
||||
forwardedHeader: " 2.2.2.2, 3.3.3.3, 4.4.4.4, 5.5.5.5 ",
|
||||
expectedAddr: "5.5.5.5",
|
||||
},
|
||||
}
|
||||
|
||||
for _, tc := range testCases {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
req := httptest.NewRequest("GET", "/", nil)
|
||||
req.RemoteAddr = tc.remoteAddr
|
||||
if tc.forwardedHeader != "" {
|
||||
req.Header.Set("X-Forwarded-For", tc.forwardedHeader)
|
||||
}
|
||||
|
||||
addr := getRemoteAddr(req)
|
||||
if addr != tc.expectedAddr {
|
||||
t.Errorf("expected remote addr = %q, got %q", tc.expectedAddr, addr)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
205
internal/middleware/middleware.go
Normal file
205
internal/middleware/middleware.go
Normal file
|
@ -0,0 +1,205 @@
|
|||
package middleware // import "github.com/pomerium/pomerium/internal/middleware"
|
||||
|
||||
import (
|
||||
"crypto/hmac"
|
||||
"crypto/sha256"
|
||||
"encoding/base64"
|
||||
"fmt"
|
||||
"net/http"
|
||||
"net/url"
|
||||
"strconv"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/pomerium/pomerium/internal/httputil"
|
||||
"github.com/pomerium/pomerium/internal/log"
|
||||
)
|
||||
|
||||
// SetHeaders ensures that every response includes some basic security headers
|
||||
func SetHeaders(h http.Handler, securityHeaders map[string]string) http.Handler {
|
||||
return http.HandlerFunc(func(rw http.ResponseWriter, req *http.Request) {
|
||||
for key, val := range securityHeaders {
|
||||
rw.Header().Set(key, val)
|
||||
}
|
||||
h.ServeHTTP(rw, req)
|
||||
})
|
||||
}
|
||||
|
||||
// WithMethods writes an error response if the method of the request is not included.
|
||||
func WithMethods(f http.HandlerFunc, methods ...string) http.HandlerFunc {
|
||||
methodMap := make(map[string]struct{}, len(methods))
|
||||
for _, m := range methods {
|
||||
methodMap[m] = struct{}{}
|
||||
}
|
||||
return func(rw http.ResponseWriter, req *http.Request) {
|
||||
if _, ok := methodMap[req.Method]; !ok {
|
||||
httputil.ErrorResponse(rw, req, fmt.Sprintf("method %s not allowed", req.Method), http.StatusMethodNotAllowed)
|
||||
return
|
||||
}
|
||||
f(rw, req)
|
||||
}
|
||||
}
|
||||
|
||||
// ValidateClientID checks the request body or url for the client id and returns an error
|
||||
// if it does not match the proxy client id
|
||||
func ValidateClientID(f http.HandlerFunc, proxyClientID string) http.HandlerFunc {
|
||||
return func(rw http.ResponseWriter, req *http.Request) {
|
||||
// try to get the client id from the request body
|
||||
err := req.ParseForm()
|
||||
if err != nil {
|
||||
httputil.ErrorResponse(rw, req, err.Error(), http.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
|
||||
clientID := req.FormValue("client_id")
|
||||
if clientID == "" {
|
||||
// try to get the clientID from the query parameters
|
||||
clientID = req.URL.Query().Get("client_id")
|
||||
}
|
||||
|
||||
if clientID != proxyClientID {
|
||||
httputil.ErrorResponse(rw, req, "Invalid client_id parameter", http.StatusUnauthorized)
|
||||
return
|
||||
}
|
||||
f(rw, req)
|
||||
}
|
||||
}
|
||||
|
||||
// ValidateClientSecret checks the request header for the client secret and returns
|
||||
// an error if it does not match the proxy client secret
|
||||
func ValidateClientSecret(f http.HandlerFunc, proxyClientSecret string) http.HandlerFunc {
|
||||
return func(rw http.ResponseWriter, req *http.Request) {
|
||||
err := req.ParseForm()
|
||||
if err != nil {
|
||||
httputil.ErrorResponse(rw, req, err.Error(), http.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
clientSecret := req.Form.Get("client_secret")
|
||||
// check the request header for the client secret
|
||||
if clientSecret == "" {
|
||||
clientSecret = req.Header.Get("X-Client-Secret")
|
||||
}
|
||||
if clientSecret != proxyClientSecret {
|
||||
log.Error().Str("clientSecret", clientSecret).Str("proxyClientSecret", proxyClientSecret).Msg("oh")
|
||||
httputil.ErrorResponse(rw, req, "Invalid client secret", http.StatusUnauthorized)
|
||||
return
|
||||
}
|
||||
f(rw, req)
|
||||
}
|
||||
}
|
||||
|
||||
// ValidateRedirectURI checks the redirect uri in the query parameters and ensures that
|
||||
// the url's domain is one in the list of proxy root domains.
|
||||
func ValidateRedirectURI(f http.HandlerFunc, proxyRootDomains []string) http.HandlerFunc {
|
||||
return func(rw http.ResponseWriter, req *http.Request) {
|
||||
err := req.ParseForm()
|
||||
if err != nil {
|
||||
httputil.ErrorResponse(rw, req, err.Error(), http.StatusBadRequest)
|
||||
return
|
||||
}
|
||||
redirectURI := req.Form.Get("redirect_uri")
|
||||
if !validRedirectURI(redirectURI, proxyRootDomains) {
|
||||
httputil.ErrorResponse(rw, req, "Invalid redirect parameter", http.StatusBadRequest)
|
||||
return
|
||||
}
|
||||
|
||||
f(rw, req)
|
||||
}
|
||||
}
|
||||
|
||||
func validRedirectURI(uri string, rootDomains []string) bool {
|
||||
redirectURL, err := url.Parse(uri)
|
||||
if uri == "" || err != nil || redirectURL.Host == "" {
|
||||
return false
|
||||
}
|
||||
for _, domain := range rootDomains {
|
||||
if strings.HasSuffix(redirectURL.Hostname(), domain) {
|
||||
return true
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
// ValidateSignature ensures the request is valid and has been signed with
|
||||
// the correspdoning client secret key
|
||||
func ValidateSignature(f http.HandlerFunc, proxyClientSecret string) http.HandlerFunc {
|
||||
return func(rw http.ResponseWriter, req *http.Request) {
|
||||
err := req.ParseForm()
|
||||
if err != nil {
|
||||
httputil.ErrorResponse(rw, req, err.Error(), http.StatusBadRequest)
|
||||
return
|
||||
}
|
||||
redirectURI := req.Form.Get("redirect_uri")
|
||||
sigVal := req.Form.Get("sig")
|
||||
timestamp := req.Form.Get("ts")
|
||||
if !validSignature(redirectURI, sigVal, timestamp, proxyClientSecret) {
|
||||
httputil.ErrorResponse(rw, req, "Invalid redirect parameter", http.StatusBadRequest)
|
||||
return
|
||||
}
|
||||
|
||||
f(rw, req)
|
||||
}
|
||||
}
|
||||
|
||||
// ValidateHost ensures that each request's host is valid
|
||||
func ValidateHost(h http.Handler, mux map[string]*http.Handler) http.Handler {
|
||||
return http.HandlerFunc(func(rw http.ResponseWriter, req *http.Request) {
|
||||
if _, ok := mux[req.Host]; !ok {
|
||||
httputil.ErrorResponse(rw, req, "Unknown host to route", http.StatusNotFound)
|
||||
return
|
||||
}
|
||||
h.ServeHTTP(rw, req)
|
||||
})
|
||||
}
|
||||
|
||||
// RequireHTTPS reroutes a HTTP request to HTTPS
|
||||
// todo(bdd) : this is unreliable unless behind another reverser proxy
|
||||
func RequireHTTPS(h http.Handler) http.Handler {
|
||||
return http.HandlerFunc(func(rw http.ResponseWriter, req *http.Request) {
|
||||
rw.Header().Set("Strict-Transport-Security", "max-age=31536000")
|
||||
// todo(bdd) : scheme and x-forwarded-proto cannot be trusted if not behind another load balancer
|
||||
if (req.URL.Scheme == "http" && req.Header.Get("X-Forwarded-Proto") == "http") || &req.TLS == nil {
|
||||
dest := &url.URL{
|
||||
Scheme: "https",
|
||||
Host: req.Host,
|
||||
Path: req.URL.Path,
|
||||
RawQuery: req.URL.RawQuery,
|
||||
}
|
||||
http.Redirect(rw, req, dest.String(), http.StatusMovedPermanently)
|
||||
return
|
||||
}
|
||||
h.ServeHTTP(rw, req)
|
||||
})
|
||||
}
|
||||
|
||||
func validSignature(redirectURI, sigVal, timestamp, secret string) bool {
|
||||
if redirectURI == "" || sigVal == "" || timestamp == "" || secret == "" {
|
||||
return false
|
||||
}
|
||||
_, err := url.Parse(redirectURI)
|
||||
if err != nil {
|
||||
return false
|
||||
}
|
||||
requestSig, err := base64.URLEncoding.DecodeString(sigVal)
|
||||
if err != nil {
|
||||
return false
|
||||
}
|
||||
i, err := strconv.ParseInt(timestamp, 10, 64)
|
||||
if err != nil {
|
||||
return false
|
||||
}
|
||||
tm := time.Unix(i, 0)
|
||||
ttl := 5 * time.Minute
|
||||
if time.Since(tm) > ttl {
|
||||
return false
|
||||
}
|
||||
localSig := redirectURLSignature(redirectURI, tm, secret)
|
||||
return hmac.Equal(requestSig, localSig)
|
||||
}
|
||||
|
||||
func redirectURLSignature(rawRedirect string, timestamp time.Time, secret string) []byte {
|
||||
h := hmac.New(sha256.New, []byte(secret))
|
||||
h.Write([]byte(rawRedirect))
|
||||
h.Write([]byte(fmt.Sprint(timestamp.Unix())))
|
||||
return h.Sum(nil)
|
||||
}
|
35
internal/options/email_validator.go
Normal file
35
internal/options/email_validator.go
Normal file
|
@ -0,0 +1,35 @@
|
|||
package options // import "github.com/pomerium/pomerium/internal/options"
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"strings"
|
||||
)
|
||||
|
||||
// NewEmailValidator returns a function that checks whether a given email is valid based on a list
|
||||
// of domains. The domain "*" is a wild card that matches any non-empty email.
|
||||
func NewEmailValidator(domains []string) func(string) bool {
|
||||
allowAll := false
|
||||
for i, domain := range domains {
|
||||
if domain == "*" {
|
||||
allowAll = true
|
||||
}
|
||||
domains[i] = fmt.Sprintf("@%s", strings.ToLower(domain))
|
||||
}
|
||||
|
||||
if allowAll {
|
||||
return func(email string) bool { return email != "" }
|
||||
}
|
||||
|
||||
return func(email string) bool {
|
||||
if email == "" {
|
||||
return false
|
||||
}
|
||||
email = strings.ToLower(email)
|
||||
for _, domain := range domains {
|
||||
if strings.HasSuffix(email, domain) {
|
||||
return true
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
}
|
121
internal/options/email_validator_test.go
Normal file
121
internal/options/email_validator_test.go
Normal file
|
@ -0,0 +1,121 @@
|
|||
package options // import "github.com/pomerium/pomerium/internal/options"
|
||||
|
||||
import (
|
||||
"testing"
|
||||
)
|
||||
|
||||
func TestEmailValidatorValidator(t *testing.T) {
|
||||
testCases := []struct {
|
||||
name string
|
||||
domains []string
|
||||
email string
|
||||
expectValid bool
|
||||
}{
|
||||
{
|
||||
name: "nothing should validate when domain list is empty",
|
||||
domains: []string(nil),
|
||||
email: "foo@example.com",
|
||||
expectValid: false,
|
||||
},
|
||||
{
|
||||
name: "single domain validation",
|
||||
domains: []string{"example.com"},
|
||||
email: "foo@example.com",
|
||||
expectValid: true,
|
||||
},
|
||||
{
|
||||
name: "substring matches are rejected",
|
||||
domains: []string{"example.com"},
|
||||
email: "foo@hackerexample.com",
|
||||
expectValid: false,
|
||||
},
|
||||
{
|
||||
name: "no subdomain rollup happens",
|
||||
domains: []string{"example.com"},
|
||||
email: "foo@bar.example.com",
|
||||
expectValid: false,
|
||||
},
|
||||
{
|
||||
name: "multiple domain validation still rejects other domains",
|
||||
domains: []string{"abc.com", "xyz.com"},
|
||||
email: "foo@example.com",
|
||||
expectValid: false,
|
||||
},
|
||||
{
|
||||
name: "multiple domain validation still accepts emails from either domain",
|
||||
domains: []string{"abc.com", "xyz.com"},
|
||||
email: "foo@abc.com",
|
||||
expectValid: true,
|
||||
},
|
||||
{
|
||||
name: "multiple domain validation still rejects other domains",
|
||||
domains: []string{"abc.com", "xyz.com"},
|
||||
email: "bar@xyz.com",
|
||||
expectValid: true,
|
||||
},
|
||||
{
|
||||
name: "comparisons are case insensitive",
|
||||
domains: []string{"Example.Com"},
|
||||
email: "foo@example.com",
|
||||
expectValid: true,
|
||||
},
|
||||
{
|
||||
name: "comparisons are case insensitive",
|
||||
domains: []string{"Example.Com"},
|
||||
email: "foo@EXAMPLE.COM",
|
||||
expectValid: true,
|
||||
},
|
||||
{
|
||||
name: "comparisons are case insensitive",
|
||||
domains: []string{"example.com"},
|
||||
email: "foo@ExAmPlE.CoM",
|
||||
expectValid: true,
|
||||
},
|
||||
{
|
||||
name: "single wildcard allows all",
|
||||
domains: []string{"*"},
|
||||
email: "foo@example.com",
|
||||
expectValid: true,
|
||||
},
|
||||
{
|
||||
name: "single wildcard allows all",
|
||||
domains: []string{"*"},
|
||||
email: "bar@gmail.com",
|
||||
expectValid: true,
|
||||
},
|
||||
{
|
||||
name: "wildcard in list allows all",
|
||||
domains: []string{"example.com", "*"},
|
||||
email: "foo@example.com",
|
||||
expectValid: true,
|
||||
},
|
||||
{
|
||||
name: "wildcard in list allows all",
|
||||
domains: []string{"example.com", "*"},
|
||||
email: "foo@gmail.com",
|
||||
expectValid: true,
|
||||
},
|
||||
{
|
||||
name: "empty email rejected",
|
||||
domains: []string{"example.com"},
|
||||
email: "",
|
||||
expectValid: false,
|
||||
},
|
||||
{
|
||||
name: "wildcard still rejects empty emails",
|
||||
domains: []string{"*"},
|
||||
email: "",
|
||||
expectValid: false,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tc := range testCases {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
emailValidator := NewEmailValidator(tc.domains)
|
||||
valid := emailValidator(tc.email)
|
||||
if valid != tc.expectValid {
|
||||
t.Fatalf("expected %v, got %v", tc.expectValid, valid)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
163
internal/sessions/cookie_store.go
Normal file
163
internal/sessions/cookie_store.go
Normal file
|
@ -0,0 +1,163 @@
|
|||
package sessions // import "github.com/pomerium/pomerium/internal/sessions"
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"fmt"
|
||||
"net"
|
||||
"net/http"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/pomerium/pomerium/internal/aead"
|
||||
"github.com/pomerium/pomerium/internal/log"
|
||||
)
|
||||
|
||||
// ErrInvalidSession is an error for invalid sessions.
|
||||
var ErrInvalidSession = errors.New("invalid session")
|
||||
|
||||
// CSRFStore has the functions for setting, getting, and clearing the CSRF cookie
|
||||
type CSRFStore interface {
|
||||
SetCSRF(http.ResponseWriter, *http.Request, string)
|
||||
GetCSRF(*http.Request) (*http.Cookie, error)
|
||||
ClearCSRF(http.ResponseWriter, *http.Request)
|
||||
}
|
||||
|
||||
// SessionStore has the functions for setting, getting, and clearing the Session cookie
|
||||
type SessionStore interface {
|
||||
ClearSession(http.ResponseWriter, *http.Request)
|
||||
LoadSession(*http.Request) (*SessionState, error)
|
||||
SaveSession(http.ResponseWriter, *http.Request, *SessionState) error
|
||||
}
|
||||
|
||||
// CookieStore represents all the cookie related configurations
|
||||
type CookieStore struct {
|
||||
Name string
|
||||
CSRFCookieName string
|
||||
CookieExpire time.Duration
|
||||
CookieRefresh time.Duration
|
||||
CookieSecure bool
|
||||
CookieHTTPOnly bool
|
||||
CookieDomain string
|
||||
CookieCipher aead.Cipher
|
||||
SessionLifetimeTTL time.Duration
|
||||
}
|
||||
|
||||
// CreateMiscreantCookieCipher creates a new miscreant cipher with the cookie secret
|
||||
func CreateMiscreantCookieCipher(cookieSecret []byte) func(s *CookieStore) error {
|
||||
return func(s *CookieStore) error {
|
||||
cipher, err := aead.NewMiscreantCipher(cookieSecret)
|
||||
if err != nil {
|
||||
return fmt.Errorf("miscreant cookie-secret error: %s", err.Error())
|
||||
}
|
||||
s.CookieCipher = cipher
|
||||
return nil
|
||||
}
|
||||
}
|
||||
|
||||
// NewCookieStore returns a new session with ciphers for each of the cookie secrets
|
||||
func NewCookieStore(cookieName string, optFuncs ...func(*CookieStore) error) (*CookieStore, error) {
|
||||
c := &CookieStore{
|
||||
Name: cookieName,
|
||||
CookieSecure: true,
|
||||
CookieHTTPOnly: true,
|
||||
CookieExpire: 168 * time.Hour,
|
||||
CSRFCookieName: fmt.Sprintf("%v_%v", cookieName, "csrf"),
|
||||
}
|
||||
|
||||
for _, f := range optFuncs {
|
||||
err := f(c)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
}
|
||||
|
||||
domain := c.CookieDomain
|
||||
if domain == "" {
|
||||
domain = "<default>"
|
||||
}
|
||||
|
||||
return c, nil
|
||||
}
|
||||
|
||||
func (s *CookieStore) makeCookie(req *http.Request, name string, value string, expiration time.Duration, now time.Time) *http.Cookie {
|
||||
domain := req.Host
|
||||
if h, _, err := net.SplitHostPort(domain); err == nil {
|
||||
domain = h
|
||||
}
|
||||
if s.CookieDomain != "" {
|
||||
if !strings.HasSuffix(domain, s.CookieDomain) {
|
||||
log.Warn().Str("cookie-domain", s.CookieDomain).Msg("using configured cookie domain")
|
||||
}
|
||||
domain = s.CookieDomain
|
||||
}
|
||||
|
||||
return &http.Cookie{
|
||||
Name: name,
|
||||
Value: value,
|
||||
Path: "/",
|
||||
Domain: domain,
|
||||
HttpOnly: s.CookieHTTPOnly,
|
||||
Secure: s.CookieSecure,
|
||||
Expires: now.Add(expiration),
|
||||
}
|
||||
}
|
||||
|
||||
// makeSessionCookie constructs a session cookie given the request, an expiration time and the current time.
|
||||
func (s *CookieStore) makeSessionCookie(req *http.Request, value string, expiration time.Duration, now time.Time) *http.Cookie {
|
||||
return s.makeCookie(req, s.Name, value, expiration, now)
|
||||
}
|
||||
|
||||
// makeCSRFCookie creates a CSRF cookie given the request, an expiration time, and the current time.
|
||||
func (s *CookieStore) makeCSRFCookie(req *http.Request, value string, expiration time.Duration, now time.Time) *http.Cookie {
|
||||
return s.makeCookie(req, s.CSRFCookieName, value, expiration, now)
|
||||
}
|
||||
|
||||
// ClearCSRF clears the CSRF cookie from the request
|
||||
func (s *CookieStore) ClearCSRF(rw http.ResponseWriter, req *http.Request) {
|
||||
http.SetCookie(rw, s.makeCSRFCookie(req, "", time.Hour*-1, time.Now()))
|
||||
}
|
||||
|
||||
// SetCSRF sets the CSRFCookie creates a CSRF cookie in a given request
|
||||
func (s *CookieStore) SetCSRF(rw http.ResponseWriter, req *http.Request, val string) {
|
||||
http.SetCookie(rw, s.makeCSRFCookie(req, val, s.CookieExpire, time.Now()))
|
||||
}
|
||||
|
||||
// GetCSRF gets the CSRFCookie creates a CSRF cookie in a given request
|
||||
func (s *CookieStore) GetCSRF(req *http.Request) (*http.Cookie, error) {
|
||||
return req.Cookie(s.CSRFCookieName)
|
||||
}
|
||||
|
||||
// ClearSession clears the session cookie from a request
|
||||
func (s *CookieStore) ClearSession(rw http.ResponseWriter, req *http.Request) {
|
||||
http.SetCookie(rw, s.makeSessionCookie(req, "", time.Hour*-1, time.Now()))
|
||||
}
|
||||
|
||||
func (s *CookieStore) setSessionCookie(rw http.ResponseWriter, req *http.Request, val string) {
|
||||
http.SetCookie(rw, s.makeSessionCookie(req, val, s.CookieExpire, time.Now()))
|
||||
}
|
||||
|
||||
// LoadSession returns a SessionState from the cookie in the request.
|
||||
func (s *CookieStore) LoadSession(req *http.Request) (*SessionState, error) {
|
||||
c, err := req.Cookie(s.Name)
|
||||
if err != nil {
|
||||
// always http.ErrNoCookie
|
||||
return nil, err
|
||||
}
|
||||
session, err := UnmarshalSession(c.Value, s.CookieCipher)
|
||||
if err != nil {
|
||||
log.Error().Err(err).Str("remote-host", req.Host).Msg("error unmarshaling session")
|
||||
return nil, ErrInvalidSession
|
||||
}
|
||||
return session, nil
|
||||
}
|
||||
|
||||
// SaveSession saves a session state to a request sessions.
|
||||
func (s *CookieStore) SaveSession(rw http.ResponseWriter, req *http.Request, sessionState *SessionState) error {
|
||||
value, err := MarshalSession(sessionState, s.CookieCipher)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
s.setSessionCookie(rw, req, value)
|
||||
return nil
|
||||
}
|
348
internal/sessions/cookie_store_test.go
Normal file
348
internal/sessions/cookie_store_test.go
Normal file
|
@ -0,0 +1,348 @@
|
|||
package sessions // import "github.com/pomerium/pomerium/internal/sessions"
|
||||
|
||||
import (
|
||||
"encoding/base64"
|
||||
"fmt"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/pomerium/pomerium/internal/testutil"
|
||||
)
|
||||
|
||||
var testEncodedCookieSecret, _ = base64.StdEncoding.DecodeString("qICChm3wdjbjcWymm7PefwtPP6/PZv+udkFEubTeE38=")
|
||||
|
||||
func TestCreateMiscreantCookieCipher(t *testing.T) {
|
||||
testCases := []struct {
|
||||
name string
|
||||
cookieSecret []byte
|
||||
expectedError bool
|
||||
}{
|
||||
{
|
||||
name: "normal case with base64 encoded secret",
|
||||
cookieSecret: testEncodedCookieSecret,
|
||||
},
|
||||
|
||||
{
|
||||
name: "error when not base64 encoded",
|
||||
cookieSecret: []byte("abcd"),
|
||||
expectedError: true,
|
||||
},
|
||||
}
|
||||
for _, tc := range testCases {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
_, err := NewCookieStore("cookieName", CreateMiscreantCookieCipher(tc.cookieSecret))
|
||||
if !tc.expectedError {
|
||||
testutil.Ok(t, err)
|
||||
} else {
|
||||
testutil.NotEqual(t, err, nil)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestNewSession(t *testing.T) {
|
||||
testCases := []struct {
|
||||
name string
|
||||
optFuncs []func(*CookieStore) error
|
||||
expectedError bool
|
||||
expectedSession *CookieStore
|
||||
}{
|
||||
{
|
||||
name: "default with no opt funcs set",
|
||||
expectedSession: &CookieStore{
|
||||
Name: "cookieName",
|
||||
CookieSecure: true,
|
||||
CookieHTTPOnly: true,
|
||||
CookieExpire: 168 * time.Hour,
|
||||
CSRFCookieName: "cookieName_csrf",
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "opt func with an error returns an error",
|
||||
optFuncs: []func(*CookieStore) error{func(*CookieStore) error { return fmt.Errorf("error") }},
|
||||
expectedError: true,
|
||||
},
|
||||
{
|
||||
name: "opt func overrides default values",
|
||||
optFuncs: []func(*CookieStore) error{func(s *CookieStore) error {
|
||||
s.CookieExpire = time.Hour
|
||||
return nil
|
||||
}},
|
||||
expectedSession: &CookieStore{
|
||||
Name: "cookieName",
|
||||
CookieSecure: true,
|
||||
CookieHTTPOnly: true,
|
||||
CookieExpire: time.Hour,
|
||||
CSRFCookieName: "cookieName_csrf",
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
for _, tc := range testCases {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
session, err := NewCookieStore("cookieName", tc.optFuncs...)
|
||||
if tc.expectedError {
|
||||
testutil.NotEqual(t, err, nil)
|
||||
} else {
|
||||
testutil.Ok(t, err)
|
||||
}
|
||||
testutil.Equal(t, tc.expectedSession, session)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestMakeSessionCookie(t *testing.T) {
|
||||
now := time.Now()
|
||||
cookieValue := "cookieValue"
|
||||
expiration := time.Hour
|
||||
cookieName := "cookieName"
|
||||
testCases := []struct {
|
||||
name string
|
||||
optFuncs []func(*CookieStore) error
|
||||
expectedCookie *http.Cookie
|
||||
}{
|
||||
{
|
||||
name: "default cookie domain",
|
||||
expectedCookie: &http.Cookie{
|
||||
Name: cookieName,
|
||||
Value: cookieValue,
|
||||
Path: "/",
|
||||
Domain: "www.example.com",
|
||||
HttpOnly: true,
|
||||
Secure: true,
|
||||
Expires: now.Add(expiration),
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "custom cookie domain set",
|
||||
optFuncs: []func(*CookieStore) error{
|
||||
func(s *CookieStore) error {
|
||||
s.CookieDomain = "buzzfeed.com"
|
||||
return nil
|
||||
},
|
||||
},
|
||||
expectedCookie: &http.Cookie{
|
||||
Name: cookieName,
|
||||
Value: cookieValue,
|
||||
Path: "/",
|
||||
Domain: "buzzfeed.com",
|
||||
HttpOnly: true,
|
||||
Secure: true,
|
||||
Expires: now.Add(expiration),
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
for _, tc := range testCases {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
session, err := NewCookieStore(cookieName, tc.optFuncs...)
|
||||
testutil.Ok(t, err)
|
||||
req := httptest.NewRequest("GET", "http://www.example.com", nil)
|
||||
cookie := session.makeSessionCookie(req, cookieValue, expiration, now)
|
||||
testutil.Equal(t, cookie, tc.expectedCookie)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestMakeSessionCSRFCookie(t *testing.T) {
|
||||
now := time.Now()
|
||||
cookieValue := "cookieValue"
|
||||
expiration := time.Hour
|
||||
cookieName := "cookieName"
|
||||
csrfName := "cookieName_csrf"
|
||||
|
||||
testCases := []struct {
|
||||
name string
|
||||
optFuncs []func(*CookieStore) error
|
||||
expectedCookie *http.Cookie
|
||||
}{
|
||||
{
|
||||
name: "default cookie domain",
|
||||
expectedCookie: &http.Cookie{
|
||||
Name: csrfName,
|
||||
Value: cookieValue,
|
||||
Path: "/",
|
||||
Domain: "www.example.com",
|
||||
HttpOnly: true,
|
||||
Secure: true,
|
||||
Expires: now.Add(expiration),
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "custom cookie domain set",
|
||||
optFuncs: []func(*CookieStore) error{
|
||||
func(s *CookieStore) error {
|
||||
s.CookieDomain = "buzzfeed.com"
|
||||
return nil
|
||||
},
|
||||
},
|
||||
expectedCookie: &http.Cookie{
|
||||
Name: csrfName,
|
||||
Value: cookieValue,
|
||||
Path: "/",
|
||||
Domain: "buzzfeed.com",
|
||||
HttpOnly: true,
|
||||
Secure: true,
|
||||
Expires: now.Add(expiration),
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
for _, tc := range testCases {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
session, err := NewCookieStore(cookieName, tc.optFuncs...)
|
||||
testutil.Ok(t, err)
|
||||
req := httptest.NewRequest("GET", "http://www.example.com", nil)
|
||||
cookie := session.makeCSRFCookie(req, cookieValue, expiration, now)
|
||||
testutil.Equal(t, tc.expectedCookie, cookie)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestSetSessionCookie(t *testing.T) {
|
||||
cookieValue := "cookieValue"
|
||||
cookieName := "cookieName"
|
||||
|
||||
t.Run("set session cookie test", func(t *testing.T) {
|
||||
session, err := NewCookieStore(cookieName)
|
||||
testutil.Ok(t, err)
|
||||
req := httptest.NewRequest("GET", "http://www.example.com", nil)
|
||||
rw := httptest.NewRecorder()
|
||||
session.setSessionCookie(rw, req, cookieValue)
|
||||
var found bool
|
||||
for _, cookie := range rw.Result().Cookies() {
|
||||
if cookie.Name == cookieName {
|
||||
found = true
|
||||
testutil.Equal(t, cookieValue, cookie.Value)
|
||||
testutil.Assert(t, cookie.Expires.After(time.Now()), "cookie expires after now")
|
||||
}
|
||||
}
|
||||
testutil.Assert(t, found, "cookie in header")
|
||||
})
|
||||
}
|
||||
func TestSetCSRFSessionCookie(t *testing.T) {
|
||||
cookieValue := "cookieValue"
|
||||
cookieName := "cookieName"
|
||||
|
||||
t.Run("set csrf cookie test", func(t *testing.T) {
|
||||
session, err := NewCookieStore(cookieName)
|
||||
testutil.Ok(t, err)
|
||||
req := httptest.NewRequest("GET", "http://www.example.com", nil)
|
||||
rw := httptest.NewRecorder()
|
||||
session.SetCSRF(rw, req, cookieValue)
|
||||
var found bool
|
||||
for _, cookie := range rw.Result().Cookies() {
|
||||
if cookie.Name == fmt.Sprintf("%s_csrf", cookieName) {
|
||||
found = true
|
||||
testutil.Equal(t, cookieValue, cookie.Value)
|
||||
testutil.Assert(t, cookie.Expires.After(time.Now()), "cookie expires after now")
|
||||
}
|
||||
}
|
||||
testutil.Assert(t, found, "cookie in header")
|
||||
})
|
||||
}
|
||||
|
||||
func TestClearSessionCookie(t *testing.T) {
|
||||
cookieValue := "cookieValue"
|
||||
cookieName := "cookieName"
|
||||
|
||||
t.Run("set session cookie test", func(t *testing.T) {
|
||||
session, err := NewCookieStore(cookieName)
|
||||
testutil.Ok(t, err)
|
||||
req := httptest.NewRequest("GET", "http://www.example.com", nil)
|
||||
req.AddCookie(session.makeSessionCookie(req, cookieValue, time.Hour, time.Now()))
|
||||
|
||||
rw := httptest.NewRecorder()
|
||||
session.ClearSession(rw, req)
|
||||
var found bool
|
||||
for _, cookie := range rw.Result().Cookies() {
|
||||
if cookie.Name == cookieName {
|
||||
found = true
|
||||
testutil.Equal(t, "", cookie.Value)
|
||||
testutil.Assert(t, cookie.Expires.Before(time.Now()), "cookie expires before now")
|
||||
}
|
||||
}
|
||||
testutil.Assert(t, found, "cookie in header")
|
||||
})
|
||||
}
|
||||
|
||||
func TestClearCSRFSessionCookie(t *testing.T) {
|
||||
cookieValue := "cookieValue"
|
||||
cookieName := "cookieName"
|
||||
|
||||
t.Run("clear csrf cookie test", func(t *testing.T) {
|
||||
session, err := NewCookieStore(cookieName)
|
||||
testutil.Ok(t, err)
|
||||
req := httptest.NewRequest("GET", "http://www.example.com", nil)
|
||||
req.AddCookie(session.makeCSRFCookie(req, cookieValue, time.Hour, time.Now()))
|
||||
|
||||
rw := httptest.NewRecorder()
|
||||
session.ClearCSRF(rw, req)
|
||||
var found bool
|
||||
for _, cookie := range rw.Result().Cookies() {
|
||||
if cookie.Name == fmt.Sprintf("%s_csrf", cookieName) {
|
||||
found = true
|
||||
testutil.Equal(t, "", cookie.Value)
|
||||
testutil.Assert(t, cookie.Expires.Before(time.Now()), "cookie expires before now")
|
||||
}
|
||||
}
|
||||
testutil.Assert(t, found, "cookie in header")
|
||||
})
|
||||
}
|
||||
|
||||
func TestLoadCookiedSession(t *testing.T) {
|
||||
cookieName := "cookieName"
|
||||
|
||||
testCases := []struct {
|
||||
name string
|
||||
optFuncs []func(*CookieStore) error
|
||||
setupCookies func(*testing.T, *http.Request, *CookieStore, *SessionState)
|
||||
expectedError error
|
||||
sessionState *SessionState
|
||||
}{
|
||||
{
|
||||
name: "no cookie set returns an error",
|
||||
setupCookies: func(*testing.T, *http.Request, *CookieStore, *SessionState) {},
|
||||
expectedError: http.ErrNoCookie,
|
||||
},
|
||||
{
|
||||
name: "cookie set with cipher set",
|
||||
optFuncs: []func(*CookieStore) error{CreateMiscreantCookieCipher(testEncodedCookieSecret)},
|
||||
setupCookies: func(t *testing.T, req *http.Request, s *CookieStore, sessionState *SessionState) {
|
||||
value, err := MarshalSession(sessionState, s.CookieCipher)
|
||||
testutil.Ok(t, err)
|
||||
req.AddCookie(s.makeSessionCookie(req, value, time.Hour, time.Now()))
|
||||
},
|
||||
sessionState: &SessionState{
|
||||
Email: "example@email.com",
|
||||
RefreshToken: "abccdddd",
|
||||
AccessToken: "access",
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "cookie set with invalid value cipher set",
|
||||
optFuncs: []func(*CookieStore) error{CreateMiscreantCookieCipher(testEncodedCookieSecret)},
|
||||
setupCookies: func(t *testing.T, req *http.Request, s *CookieStore, sessionState *SessionState) {
|
||||
value := "574b776a7c934d6b9fc42ec63a389f79"
|
||||
req.AddCookie(s.makeSessionCookie(req, value, time.Hour, time.Now()))
|
||||
},
|
||||
expectedError: ErrInvalidSession,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tc := range testCases {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
session, err := NewCookieStore(cookieName, tc.optFuncs...)
|
||||
testutil.Ok(t, err)
|
||||
req := httptest.NewRequest("GET", "https://www.example.com", nil)
|
||||
tc.setupCookies(t, req, session, tc.sessionState)
|
||||
s, err := session.LoadSession(req)
|
||||
|
||||
testutil.Equal(t, tc.expectedError, err)
|
||||
testutil.Equal(t, tc.sessionState, s)
|
||||
|
||||
})
|
||||
}
|
||||
}
|
50
internal/sessions/mock_store.go
Normal file
50
internal/sessions/mock_store.go
Normal file
|
@ -0,0 +1,50 @@
|
|||
package sessions // import "github.com/pomerium/pomerium/internal/sessions"
|
||||
|
||||
import (
|
||||
"net/http"
|
||||
)
|
||||
|
||||
// MockCSRFStore is a mock implementation of the CSRF store interface
|
||||
type MockCSRFStore struct {
|
||||
ResponseCSRF string
|
||||
Cookie *http.Cookie
|
||||
GetError error
|
||||
}
|
||||
|
||||
// SetCSRF sets the ResponseCSRF string to a val
|
||||
func (ms *MockCSRFStore) SetCSRF(rw http.ResponseWriter, req *http.Request, val string) {
|
||||
ms.ResponseCSRF = val
|
||||
}
|
||||
|
||||
// ClearCSRF clears the ResponseCSRF string
|
||||
func (ms *MockCSRFStore) ClearCSRF(http.ResponseWriter, *http.Request) {
|
||||
ms.ResponseCSRF = ""
|
||||
}
|
||||
|
||||
// GetCSRF returns the cookie and error
|
||||
func (ms *MockCSRFStore) GetCSRF(*http.Request) (*http.Cookie, error) {
|
||||
return ms.Cookie, ms.GetError
|
||||
}
|
||||
|
||||
// MockSessionStore is a mock implementation of the SessionStore interface
|
||||
type MockSessionStore struct {
|
||||
ResponseSession string
|
||||
Session *SessionState
|
||||
SaveError error
|
||||
LoadError error
|
||||
}
|
||||
|
||||
// ClearSession clears the ResponseSession
|
||||
func (ms *MockSessionStore) ClearSession(http.ResponseWriter, *http.Request) {
|
||||
ms.ResponseSession = ""
|
||||
}
|
||||
|
||||
// LoadSession returns the session and a error
|
||||
func (ms *MockSessionStore) LoadSession(*http.Request) (*SessionState, error) {
|
||||
return ms.Session, ms.LoadError
|
||||
}
|
||||
|
||||
// SaveSession returns a save error.
|
||||
func (ms *MockSessionStore) SaveSession(http.ResponseWriter, *http.Request, *SessionState) error {
|
||||
return ms.SaveError
|
||||
}
|
70
internal/sessions/session_state.go
Normal file
70
internal/sessions/session_state.go
Normal file
|
@ -0,0 +1,70 @@
|
|||
package sessions // import "github.com/pomerium/pomerium/internal/sessions"
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"time"
|
||||
|
||||
"github.com/pomerium/pomerium/internal/aead"
|
||||
)
|
||||
|
||||
var (
|
||||
// ErrLifetimeExpired is an error for the lifetime deadline expiring
|
||||
ErrLifetimeExpired = errors.New("user lifetime expired")
|
||||
)
|
||||
|
||||
// SessionState is our object that keeps track of a user's session state
|
||||
type SessionState struct {
|
||||
AccessToken string `json:"access_token"`
|
||||
RefreshToken string `json:"refresh_token"`
|
||||
IDToken string `json:"id_token"` // https://openid.net/specs/openid-connect-core-1_0.html#TokenResponse
|
||||
|
||||
RefreshDeadline time.Time `json:"refresh_deadline"`
|
||||
LifetimeDeadline time.Time `json:"lifetime_deadline"`
|
||||
ValidDeadline time.Time `json:"valid_deadline"`
|
||||
GracePeriodStart time.Time `json:"grace_period_start"`
|
||||
|
||||
Email string `json:"email"`
|
||||
User string `json:"user"`
|
||||
Groups []string `json:"groups"`
|
||||
}
|
||||
|
||||
// LifetimePeriodExpired returns true if the lifetime has expired
|
||||
func (s *SessionState) LifetimePeriodExpired() bool {
|
||||
return isExpired(s.LifetimeDeadline)
|
||||
}
|
||||
|
||||
// RefreshPeriodExpired returns true if the refresh period has expired
|
||||
func (s *SessionState) RefreshPeriodExpired() bool {
|
||||
return isExpired(s.RefreshDeadline)
|
||||
}
|
||||
|
||||
// ValidationPeriodExpired returns true if the validation period has expired
|
||||
func (s *SessionState) ValidationPeriodExpired() bool {
|
||||
return isExpired(s.ValidDeadline)
|
||||
}
|
||||
|
||||
func isExpired(t time.Time) bool {
|
||||
return t.Before(time.Now())
|
||||
}
|
||||
|
||||
// MarshalSession marshals the session state as JSON, encrypts the JSON using the
|
||||
// given cipher, and base64-encodes the result
|
||||
func MarshalSession(s *SessionState, c aead.Cipher) (string, error) {
|
||||
return c.Marshal(s)
|
||||
}
|
||||
|
||||
// UnmarshalSession takes the marshaled string, base64-decodes into a byte slice, decrypts the
|
||||
// byte slice using the pased cipher, and unmarshals the resulting JSON into a session state struct
|
||||
func UnmarshalSession(value string, c aead.Cipher) (*SessionState, error) {
|
||||
s := &SessionState{}
|
||||
err := c.Unmarshal(value, s)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return s, nil
|
||||
}
|
||||
|
||||
// ExtendDeadline returns the time extended by a given duration
|
||||
func ExtendDeadline(ttl time.Duration) time.Time {
|
||||
return time.Now().Add(ttl).Truncate(time.Second)
|
||||
}
|
71
internal/sessions/session_state_test.go
Normal file
71
internal/sessions/session_state_test.go
Normal file
|
@ -0,0 +1,71 @@
|
|||
package sessions // import "github.com/pomerium/pomerium/internal/sessions"
|
||||
|
||||
import (
|
||||
"reflect"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/pomerium/pomerium/internal/aead"
|
||||
)
|
||||
|
||||
func TestSessionStateSerialization(t *testing.T) {
|
||||
secret := aead.GenerateKey()
|
||||
c, err := aead.NewMiscreantCipher([]byte(secret))
|
||||
if err != nil {
|
||||
t.Fatalf("expected to be able to create cipher: %v", err)
|
||||
}
|
||||
|
||||
want := &SessionState{
|
||||
AccessToken: "token1234",
|
||||
RefreshToken: "refresh4321",
|
||||
|
||||
LifetimeDeadline: time.Now().Add(1 * time.Hour).Truncate(time.Second).UTC(),
|
||||
RefreshDeadline: time.Now().Add(1 * time.Hour).Truncate(time.Second).UTC(),
|
||||
ValidDeadline: time.Now().Add(1 * time.Minute).Truncate(time.Second).UTC(),
|
||||
|
||||
Email: "user@domain.com",
|
||||
User: "user",
|
||||
}
|
||||
|
||||
ciphertext, err := MarshalSession(want, c)
|
||||
if err != nil {
|
||||
t.Fatalf("expected to be encode session: %v", err)
|
||||
}
|
||||
|
||||
got, err := UnmarshalSession(ciphertext, c)
|
||||
if err != nil {
|
||||
t.Fatalf("expected to be decode session: %v", err)
|
||||
}
|
||||
|
||||
if !reflect.DeepEqual(want, got) {
|
||||
t.Logf("want: %#v", want)
|
||||
t.Logf(" got: %#v", got)
|
||||
t.Errorf("encoding and decoding session resulted in unexpected output")
|
||||
}
|
||||
}
|
||||
|
||||
func TestSessionStateExpirations(t *testing.T) {
|
||||
session := &SessionState{
|
||||
AccessToken: "token1234",
|
||||
RefreshToken: "refresh4321",
|
||||
|
||||
LifetimeDeadline: time.Now().Add(-1 * time.Hour),
|
||||
RefreshDeadline: time.Now().Add(-1 * time.Hour),
|
||||
ValidDeadline: time.Now().Add(-1 * time.Minute),
|
||||
|
||||
Email: "user@domain.com",
|
||||
User: "user",
|
||||
}
|
||||
|
||||
if !session.LifetimePeriodExpired() {
|
||||
t.Errorf("expcted lifetime period to be expired")
|
||||
}
|
||||
|
||||
if !session.RefreshPeriodExpired() {
|
||||
t.Errorf("expcted lifetime period to be expired")
|
||||
}
|
||||
|
||||
if !session.ValidationPeriodExpired() {
|
||||
t.Errorf("expcted lifetime period to be expired")
|
||||
}
|
||||
}
|
75
internal/singleflight/singleflight.go
Normal file
75
internal/singleflight/singleflight.go
Normal file
|
@ -0,0 +1,75 @@
|
|||
// Original Copyright 2013 The Go Authors. All rights reserved.
|
||||
//
|
||||
// Modified by BuzzFeed to return duplicate counts.
|
||||
//
|
||||
// Use of this source code is governed by a BSD-style
|
||||
// license that can be found in the LICENSE file.
|
||||
|
||||
// Package singleflight provides a duplicate function call suppression mechanism.
|
||||
package singleflight // import "github.com/pomerium/pomerium/internal/singleflight"
|
||||
|
||||
import "sync"
|
||||
|
||||
// call is an in-flight or completed singleflight.Do call
|
||||
type call struct {
|
||||
wg sync.WaitGroup
|
||||
|
||||
// These fields are written once before the WaitGroup is done
|
||||
// and are only read after the WaitGroup is done.
|
||||
val interface{}
|
||||
err error
|
||||
|
||||
// These fields are read and written with the singleflight
|
||||
// mutex held before the WaitGroup is done, and are read but
|
||||
// not written after the WaitGroup is done.
|
||||
dups int
|
||||
}
|
||||
|
||||
// Group represents a class of work and forms a namespace in
|
||||
// which units of work can be executed with duplicate suppression.
|
||||
type Group struct {
|
||||
mu sync.Mutex // protects m
|
||||
m map[string]*call // lazily initialized
|
||||
}
|
||||
|
||||
// Result holds the results of Do, so they can be passed
|
||||
// on a channel.
|
||||
type Result struct {
|
||||
Val interface{}
|
||||
Err error
|
||||
Count bool
|
||||
}
|
||||
|
||||
// Do executes and returns the results of the given function, making
|
||||
// sure that only one execution is in-flight for a given key at a
|
||||
// time. If a duplicate comes in, the duplicate caller waits for the
|
||||
// original to complete and receives the same results.
|
||||
// The return value of Count indicates how many tiems v was given to multiple callers.
|
||||
// Count will be zero for requests are shared and only be non-zero for the originating request.
|
||||
func (g *Group) Do(key string, fn func() (interface{}, error)) (v interface{}, count int, err error) {
|
||||
g.mu.Lock()
|
||||
|
||||
if g.m == nil {
|
||||
g.m = make(map[string]*call)
|
||||
}
|
||||
if c, ok := g.m[key]; ok {
|
||||
c.dups++
|
||||
g.mu.Unlock()
|
||||
c.wg.Wait()
|
||||
return c.val, 0, c.err
|
||||
}
|
||||
c := new(call)
|
||||
c.wg.Add(1)
|
||||
g.m[key] = c
|
||||
|
||||
g.mu.Unlock()
|
||||
|
||||
c.val, c.err = fn()
|
||||
c.wg.Done()
|
||||
|
||||
g.mu.Lock()
|
||||
delete(g.m, key)
|
||||
g.mu.Unlock()
|
||||
|
||||
return c.val, c.dups, c.err
|
||||
}
|
87
internal/singleflight/singleflight_test.go
Normal file
87
internal/singleflight/singleflight_test.go
Normal file
|
@ -0,0 +1,87 @@
|
|||
// Copyright 2013 The Go Authors. All rights reserved.
|
||||
// Use of this source code is governed by a BSD-style
|
||||
// license that can be found in the LICENSE file.
|
||||
|
||||
package singleflight // import "github.com/pomerium/pomerium/internal/singleflight"
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"fmt"
|
||||
"sync"
|
||||
"sync/atomic"
|
||||
"testing"
|
||||
"time"
|
||||
)
|
||||
|
||||
func TestDo(t *testing.T) {
|
||||
var g Group
|
||||
v, _, err := g.Do("key", func() (interface{}, error) {
|
||||
return "bar", nil
|
||||
})
|
||||
if got, want := fmt.Sprintf("%v (%T)", v, v), "bar (string)"; got != want {
|
||||
t.Errorf("Do = %v; want %v", got, want)
|
||||
}
|
||||
if err != nil {
|
||||
t.Errorf("Do error = %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestDoErr(t *testing.T) {
|
||||
var g Group
|
||||
someErr := errors.New("Some error")
|
||||
v, _, err := g.Do("key", func() (interface{}, error) {
|
||||
return nil, someErr
|
||||
})
|
||||
if err != someErr {
|
||||
t.Errorf("Do error = %v; want someErr %v", err, someErr)
|
||||
}
|
||||
if v != nil {
|
||||
t.Errorf("unexpected non-nil value %#v", v)
|
||||
}
|
||||
}
|
||||
|
||||
func TestDoDupSuppress(t *testing.T) {
|
||||
var g Group
|
||||
var wg1, wg2 sync.WaitGroup
|
||||
c := make(chan string, 1)
|
||||
var calls int32
|
||||
fn := func() (interface{}, error) {
|
||||
if atomic.AddInt32(&calls, 1) == 1 {
|
||||
// First invocation.
|
||||
wg1.Done()
|
||||
}
|
||||
v := <-c
|
||||
c <- v // pump; make available for any future calls
|
||||
|
||||
time.Sleep(10 * time.Millisecond) // let more goroutines enter Do
|
||||
|
||||
return v, nil
|
||||
}
|
||||
|
||||
const n = 10
|
||||
wg1.Add(1)
|
||||
for i := 0; i < n; i++ {
|
||||
wg1.Add(1)
|
||||
wg2.Add(1)
|
||||
go func() {
|
||||
defer wg2.Done()
|
||||
wg1.Done()
|
||||
v, _, err := g.Do("key", fn)
|
||||
if err != nil {
|
||||
t.Errorf("Do error: %v", err)
|
||||
return
|
||||
}
|
||||
if s, _ := v.(string); s != "bar" {
|
||||
t.Errorf("Do = %T %v; want %q", v, v, "bar")
|
||||
}
|
||||
}()
|
||||
}
|
||||
wg1.Wait()
|
||||
// At least one goroutine is in fn now and all of them have at
|
||||
// least reached the line before the Do.
|
||||
c <- "bar"
|
||||
wg2.Wait()
|
||||
if got := atomic.LoadInt32(&calls); got <= 0 || got >= n {
|
||||
t.Errorf("number of calls = %d; want over 0 and less than %d", got, n)
|
||||
}
|
||||
}
|
199
internal/templates/templates.go
Normal file
199
internal/templates/templates.go
Normal file
|
@ -0,0 +1,199 @@
|
|||
package templates // import "github.com/pomerium/pomerium/internal/templates"
|
||||
|
||||
import (
|
||||
"html/template"
|
||||
)
|
||||
|
||||
// New loads html and style resources directly. Panics on failure.
|
||||
func New() *template.Template {
|
||||
t := template.New("authenticate-templates")
|
||||
template.Must(t.Parse(`
|
||||
{{define "header.html"}}
|
||||
<meta name="viewport" content="width=device-width, initial-scale=1, maximum-scale=1, user-scalable=no">
|
||||
<style>
|
||||
* {
|
||||
margin: 0;
|
||||
padding: 0;
|
||||
}
|
||||
body {
|
||||
font-family: "Helvetica Neue",Helvetica,Arial,sans-serif;
|
||||
font-size: 1em;
|
||||
line-height: 1.42857143;
|
||||
color: #333;
|
||||
background: #f0f0f0;
|
||||
}
|
||||
p {
|
||||
margin: 1.5em 0;
|
||||
}
|
||||
p:first-child {
|
||||
margin-top: 0;
|
||||
}
|
||||
p:last-child {
|
||||
margin-bottom: 0;
|
||||
}
|
||||
.container {
|
||||
max-width: 40em;
|
||||
display: block;
|
||||
margin: 10% auto;
|
||||
text-align: center;
|
||||
}
|
||||
.content, .message, button {
|
||||
border: 1px solid rgba(0,0,0,.125);
|
||||
border-bottom-width: 4px;
|
||||
border-radius: 4px;
|
||||
}
|
||||
.content, .message {
|
||||
background-color: #fff;
|
||||
padding: 2rem;
|
||||
margin: 1rem 0;
|
||||
}
|
||||
.error, .message {
|
||||
border-bottom-color: #c00;
|
||||
}
|
||||
.message {
|
||||
padding: 1.5rem 2rem 1.3rem;
|
||||
}
|
||||
header {
|
||||
border-bottom: 1px solid rgba(0,0,0,.075);
|
||||
margin: -2rem 0 2rem;
|
||||
padding: 2rem 0 1.8rem;
|
||||
}
|
||||
header h1 {
|
||||
font-size: 1.5em;
|
||||
font-weight: normal;
|
||||
}
|
||||
.error header {
|
||||
color: #c00;
|
||||
}
|
||||
.details {
|
||||
font-size: .85rem;
|
||||
color: #999;
|
||||
}
|
||||
button {
|
||||
color: #fff;
|
||||
background-color: #3B8686;
|
||||
cursor: pointer;
|
||||
font-size: 1.5rem;
|
||||
font-weight: bold;
|
||||
padding: 1rem 2.5rem;
|
||||
text-shadow: 0 3px 1px rgba(0,0,0,.2);
|
||||
outline: none;
|
||||
}
|
||||
button:active {
|
||||
border-top-width: 4px;
|
||||
border-bottom-width: 1px;
|
||||
text-shadow: none;
|
||||
}
|
||||
footer {
|
||||
font-size: 0.75em;
|
||||
color: #999;
|
||||
text-align: right;
|
||||
margin: 1rem;
|
||||
}
|
||||
</style>
|
||||
{{end}}`))
|
||||
|
||||
t = template.Must(t.Parse(`{{define "footer.html"}}Secured by <b>pomerium</b> {{end}}`))
|
||||
|
||||
t = template.Must(t.Parse(`
|
||||
{{define "sign_in_message.html"}}
|
||||
{{if eq (len .EmailDomains) 1}}
|
||||
{{if eq (index .EmailDomains 0) "@*"}}
|
||||
<p>You may sign in with any {{.ProviderName}} account.</p>
|
||||
{{else}}
|
||||
<p>You may sign in with your <b>{{index .EmailDomains 0}}</b> {{.ProviderName}} account.</p>
|
||||
{{end}}
|
||||
{{else if gt (len .EmailDomains) 1}}
|
||||
<p>
|
||||
You may sign in with any of these {{.ProviderName}} accounts:<br>
|
||||
{{range $i, $e := .EmailDomains}}{{if $i}}, {{end}}<b>{{$e}}</b>{{end}}
|
||||
</p>
|
||||
{{end}}
|
||||
{{end}}`))
|
||||
|
||||
t = template.Must(t.Parse(`
|
||||
{{define "sign_in.html"}}
|
||||
<!DOCTYPE html>
|
||||
<html lang="en" charset="utf-8">
|
||||
<head>
|
||||
<title>Sign In</title>
|
||||
{{template "header.html"}}
|
||||
</head>
|
||||
<body>
|
||||
<div class="container">
|
||||
<div class="content">
|
||||
<header>
|
||||
<h1>Sign in to <b>{{.Destination}}</b></h1>
|
||||
</header>
|
||||
|
||||
{{template "sign_in_message.html" .}}
|
||||
|
||||
<form method="GET" action="/start">
|
||||
<input type="hidden" name="redirect_uri" value="{{.Redirect}}">
|
||||
<button type="submit" class="btn">Sign in with {{.ProviderName}}</button>
|
||||
</form>
|
||||
</div>
|
||||
|
||||
<footer>{{template "footer.html"}} </br> {{.Version}} </footer>
|
||||
</div>
|
||||
</body>
|
||||
</html>
|
||||
{{end}}`))
|
||||
|
||||
template.Must(t.Parse(`
|
||||
{{define "error.html"}}
|
||||
<!DOCTYPE html>
|
||||
<html lang="en" charset="utf-8">
|
||||
<head>
|
||||
<title>Error</title>
|
||||
{{template "header.html"}}
|
||||
</head>
|
||||
<body>
|
||||
<div class="container">
|
||||
<div class="content error">
|
||||
<header>
|
||||
<h1>{{.Title}}</h1>
|
||||
</header>
|
||||
<p>
|
||||
{{.Message}}<br>
|
||||
<span class="details">HTTP {{.Code}}</span>
|
||||
</p>
|
||||
</div>
|
||||
<footer>{{template "footer.html"}} </br> {{.Version}} </footer>
|
||||
</div>
|
||||
</body>
|
||||
</html>{{end}}`))
|
||||
|
||||
t = template.Must(t.Parse(`
|
||||
{{define "sign_out.html"}}
|
||||
<!DOCTYPE html>
|
||||
<html lang="en" charset="utf-8">
|
||||
<head>
|
||||
<title>Sign Out</title>
|
||||
{{template "header.html"}}
|
||||
</head>
|
||||
<body>
|
||||
<div class="container">
|
||||
{{ if .Message }}
|
||||
<div class="message">{{.Message}}</div>
|
||||
{{ end}}
|
||||
<div class="content">
|
||||
<header>
|
||||
<h1>Sign out of <b>{{.Destination}}</b></h1>
|
||||
</header>
|
||||
|
||||
<p>You're currently signed in as <b>{{.Email}}</b>. This will also sign you out of other internal apps.</p>
|
||||
<form method="POST" action="/sign_out">
|
||||
<input type="hidden" name="redirect_uri" value="{{.Redirect}}">
|
||||
<input type="hidden" name="sig" value="{{.Signature}}">
|
||||
<input type="hidden" name="ts" value="{{.Timestamp}}">
|
||||
<button type="submit">Sign out</button>
|
||||
</form>
|
||||
</div>
|
||||
<footer>{{template "footer.html"}} </br> {{.Version}}</footer>
|
||||
</div>
|
||||
</body>
|
||||
</html>
|
||||
{{end}}`))
|
||||
return t
|
||||
}
|
12
internal/templates/templates_test.go
Normal file
12
internal/templates/templates_test.go
Normal file
|
@ -0,0 +1,12 @@
|
|||
package templates // import "github.com/pomerium/pomerium/internal/templates"
|
||||
|
||||
import (
|
||||
"testing"
|
||||
|
||||
"github.com/pomerium/pomerium/internal/testutil"
|
||||
)
|
||||
|
||||
func TestTemplatesCompile(t *testing.T) {
|
||||
templates := New()
|
||||
testutil.NotEqual(t, templates, nil)
|
||||
}
|
46
internal/testutil/testutil.go
Normal file
46
internal/testutil/testutil.go
Normal file
|
@ -0,0 +1,46 @@
|
|||
package testutil // import "github.com/pomerium/pomerium/internal/testutil"
|
||||
|
||||
// testing util functions copied from https://github.com/benbjohnson/testing
|
||||
import (
|
||||
"fmt"
|
||||
"path/filepath"
|
||||
"reflect"
|
||||
"runtime"
|
||||
"testing"
|
||||
)
|
||||
|
||||
// Assert fails the test if the condition is false.
|
||||
func Assert(tb testing.TB, condition bool, msg string, v ...interface{}) {
|
||||
if !condition {
|
||||
_, file, line, _ := runtime.Caller(1)
|
||||
fmt.Printf("\033[31m%s:%d: "+msg+"\033[39m\n\n", append([]interface{}{filepath.Base(file), line}, v...)...)
|
||||
tb.FailNow()
|
||||
}
|
||||
}
|
||||
|
||||
// Ok fails the test if an err is not nil.
|
||||
func Ok(tb testing.TB, err error) {
|
||||
if err != nil {
|
||||
_, file, line, _ := runtime.Caller(1)
|
||||
fmt.Printf("\033[31m%s:%d: unexpected error: %s\033[39m\n\n", filepath.Base(file), line, err.Error())
|
||||
tb.FailNow()
|
||||
}
|
||||
}
|
||||
|
||||
// Equal fails the test if exp is not equal to act.
|
||||
func Equal(tb testing.TB, exp, act interface{}) {
|
||||
if !reflect.DeepEqual(exp, act) {
|
||||
_, file, line, _ := runtime.Caller(1)
|
||||
fmt.Printf("\033[31m%s:%d:\n\n\texp: %#v\n\n\tgot: %#v\033[39m\n\n", filepath.Base(file), line, exp, act)
|
||||
tb.FailNow()
|
||||
}
|
||||
}
|
||||
|
||||
// NotEqual fails the test if exp is equal to act.
|
||||
func NotEqual(tb testing.TB, exp, act interface{}) {
|
||||
if reflect.DeepEqual(exp, act) {
|
||||
_, file, line, _ := runtime.Caller(1)
|
||||
fmt.Printf("\033[31m%s:%d:\n\n\texp: %#v\n\n\tgot: %#v\033[39m\n\n", filepath.Base(file), line, exp, act)
|
||||
tb.FailNow()
|
||||
}
|
||||
}
|
42
internal/version/version.go
Normal file
42
internal/version/version.go
Normal file
|
@ -0,0 +1,42 @@
|
|||
package version // import "github.com/pomerium/pomerium/internal/version"
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"runtime"
|
||||
"strings"
|
||||
)
|
||||
|
||||
var (
|
||||
// ProjectName is the canonical project name set by ldl flags
|
||||
ProjectName = ""
|
||||
// ProjectURL is the canonical project url set by ldl flags
|
||||
ProjectURL = ""
|
||||
// Version specifies Semantic versioning increment (MAJOR.MINOR.PATCH).
|
||||
Version = "v0.0.0"
|
||||
// GitCommit specifies the git commit sha, set by the compiler.
|
||||
GitCommit = ""
|
||||
// BuildMeta specifies release type (dev,rc1,beta,etc)
|
||||
BuildMeta = ""
|
||||
|
||||
runtimeVersion = runtime.Version()
|
||||
)
|
||||
|
||||
// FullVersion returns a version string.
|
||||
func FullVersion() string {
|
||||
var sb strings.Builder
|
||||
sb.Grow(len(Version) + len(GitCommit) + len(BuildMeta) + len("-") + len("+"))
|
||||
sb.WriteString(Version)
|
||||
if BuildMeta != "" {
|
||||
sb.WriteString("-" + BuildMeta)
|
||||
}
|
||||
if GitCommit != "" {
|
||||
sb.WriteString("+" + GitCommit)
|
||||
}
|
||||
return sb.String()
|
||||
}
|
||||
|
||||
// UserAgent returns a user-agent string as specified in RFC 2616:14.43
|
||||
// https://tools.ietf.org/html/rfc2616
|
||||
func UserAgent() string {
|
||||
return fmt.Sprintf("%s/%s (+%s; %s; %s)", ProjectName, Version, ProjectURL, GitCommit, runtimeVersion)
|
||||
}
|
71
internal/version/version_test.go
Normal file
71
internal/version/version_test.go
Normal file
|
@ -0,0 +1,71 @@
|
|||
package version // import "github.com/pomerium/pomerium/internal/version"
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"runtime"
|
||||
"testing"
|
||||
)
|
||||
|
||||
func TestFullVersionVersion(t *testing.T) {
|
||||
tests := []struct {
|
||||
Version string
|
||||
GitCommit string
|
||||
BuildMeta string
|
||||
|
||||
expected string
|
||||
}{
|
||||
{"", "", "", ""},
|
||||
{"1.0.0", "", "", "1.0.0"},
|
||||
{"1.0.0", "314501b", "", "1.0.0+314501b"},
|
||||
{"1.0.0", "314501b", "dev", "1.0.0-dev+314501b"},
|
||||
}
|
||||
for _, tt := range tests {
|
||||
Version = tt.Version
|
||||
GitCommit = tt.GitCommit
|
||||
BuildMeta = tt.BuildMeta
|
||||
|
||||
if got := FullVersion(); got != tt.expected {
|
||||
t.Errorf("expected (%s) got (%s) for Version(%s), GitCommit(%s) BuildMeta(%s)",
|
||||
tt.expected,
|
||||
got,
|
||||
tt.Version,
|
||||
tt.GitCommit,
|
||||
tt.BuildMeta)
|
||||
}
|
||||
}
|
||||
}
|
||||
func BenchmarkFullVersion(b *testing.B) {
|
||||
Version = "1.0.0"
|
||||
GitCommit = "314501b"
|
||||
BuildMeta = "dev"
|
||||
for i := 0; i < b.N; i++ {
|
||||
FullVersion()
|
||||
}
|
||||
}
|
||||
|
||||
func TestUserAgent(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
Version string
|
||||
GitCommit string
|
||||
BuildMeta string
|
||||
ProjectName string
|
||||
ProjectURL string
|
||||
want string
|
||||
}{
|
||||
{"good user agent", "1.0.0", "314501b", "dev", "pomerium", "github.com/pomerium", fmt.Sprintf("pomerium/1.0.0 (+github.com/pomerium; 314501b; %s)", runtime.Version())},
|
||||
}
|
||||
for _, tt := range tests {
|
||||
Version = tt.Version
|
||||
GitCommit = tt.GitCommit
|
||||
BuildMeta = tt.BuildMeta
|
||||
ProjectName = tt.ProjectName
|
||||
ProjectURL = tt.ProjectURL
|
||||
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
if got := UserAgent(); got != tt.want {
|
||||
t.Errorf("UserAgent() = %v, want %v", got, tt.want)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
317
proxy/authenticator/authenticator.go
Normal file
317
proxy/authenticator/authenticator.go
Normal file
|
@ -0,0 +1,317 @@
|
|||
package authenticator // import "github.com/pomerium/pomerium/proxy/authenticator"
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"crypto/hmac"
|
||||
"crypto/sha256"
|
||||
"encoding/base64"
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"fmt"
|
||||
"io"
|
||||
"io/ioutil"
|
||||
"net"
|
||||
"net/http"
|
||||
"net/url"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/pomerium/pomerium/internal/log"
|
||||
"github.com/pomerium/pomerium/internal/sessions"
|
||||
"github.com/pomerium/pomerium/internal/version"
|
||||
)
|
||||
|
||||
var defaultHTTPClient = &http.Client{
|
||||
Timeout: time.Second * 5,
|
||||
Transport: &http.Transport{
|
||||
Dial: (&net.Dialer{
|
||||
Timeout: 2 * time.Second,
|
||||
}).Dial,
|
||||
TLSHandshakeTimeout: 2 * time.Second,
|
||||
},
|
||||
}
|
||||
|
||||
// Errors
|
||||
var (
|
||||
ErrMissingRefreshToken = errors.New("missing refresh token")
|
||||
ErrAuthProviderUnavailable = errors.New("auth provider unavailable")
|
||||
)
|
||||
|
||||
// AuthenticateClient holds the data associated with the AuthenticateClients
|
||||
// necessary to implement a AuthenticateClient interface.
|
||||
type AuthenticateClient struct {
|
||||
AuthenticateServiceURL *url.URL
|
||||
|
||||
//
|
||||
ClientID string
|
||||
ClientSecret string
|
||||
SignInURL *url.URL
|
||||
SignOutURL *url.URL
|
||||
RedeemURL *url.URL
|
||||
RefreshURL *url.URL
|
||||
ProfileURL *url.URL
|
||||
ValidateURL *url.URL
|
||||
|
||||
SessionValidTTL time.Duration
|
||||
SessionLifetimeTTL time.Duration
|
||||
GracePeriodTTL time.Duration
|
||||
}
|
||||
|
||||
// NewAuthenticateClient instantiates a new AuthenticateClient with provider data
|
||||
func NewAuthenticateClient(uri *url.URL, clientID, clientSecret string, sessionValid, sessionLifetime, gracePeriod time.Duration) *AuthenticateClient {
|
||||
return &AuthenticateClient{
|
||||
AuthenticateServiceURL: uri,
|
||||
|
||||
ClientID: clientID,
|
||||
ClientSecret: clientSecret,
|
||||
|
||||
SignInURL: uri.ResolveReference(&url.URL{Path: "/sign_in"}),
|
||||
SignOutURL: uri.ResolveReference(&url.URL{Path: "/sign_out"}),
|
||||
RedeemURL: uri.ResolveReference(&url.URL{Path: "/redeem"}),
|
||||
RefreshURL: uri.ResolveReference(&url.URL{Path: "/refresh"}),
|
||||
ValidateURL: uri.ResolveReference(&url.URL{Path: "/validate"}),
|
||||
ProfileURL: uri.ResolveReference(&url.URL{Path: "/profile"}),
|
||||
|
||||
SessionValidTTL: sessionValid,
|
||||
SessionLifetimeTTL: sessionLifetime,
|
||||
GracePeriodTTL: gracePeriod,
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
func (p *AuthenticateClient) newRequest(method, url string, body io.Reader) (*http.Request, error) {
|
||||
req, err := http.NewRequest(method, url, body)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
req.Header.Set("User-Agent", version.UserAgent())
|
||||
req.Header.Set("Accept", "application/json")
|
||||
req.Host = p.AuthenticateServiceURL.Host
|
||||
return req, nil
|
||||
}
|
||||
|
||||
func isProviderUnavailable(statusCode int) bool {
|
||||
return statusCode == http.StatusTooManyRequests || statusCode == http.StatusServiceUnavailable
|
||||
}
|
||||
|
||||
func extendDeadline(ttl time.Duration) time.Time {
|
||||
return time.Now().Add(ttl).Truncate(time.Second)
|
||||
}
|
||||
|
||||
func (p *AuthenticateClient) withinGracePeriod(s *sessions.SessionState) bool {
|
||||
if s.GracePeriodStart.IsZero() {
|
||||
s.GracePeriodStart = time.Now()
|
||||
}
|
||||
return s.GracePeriodStart.Add(p.GracePeriodTTL).After(time.Now())
|
||||
}
|
||||
|
||||
// Redeem takes a redirectURL and code and redeems the SessionState
|
||||
func (p *AuthenticateClient) Redeem(redirectURL, code string) (*sessions.SessionState, error) {
|
||||
if code == "" {
|
||||
return nil, errors.New("missing code")
|
||||
}
|
||||
|
||||
params := url.Values{}
|
||||
// params that are validates by the authenticate service middleware
|
||||
params.Add("client_id", p.ClientID)
|
||||
params.Add("client_secret", p.ClientSecret)
|
||||
params.Add("code", code)
|
||||
|
||||
req, err := p.newRequest("POST", p.RedeemURL.String(), bytes.NewBufferString(params.Encode()))
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
req.Header.Set("Content-Type", "application/x-www-form-urlencoded")
|
||||
resp, err := defaultHTTPClient.Do(req)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
body, err := ioutil.ReadAll(resp.Body)
|
||||
resp.Body.Close()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if resp.StatusCode != 200 {
|
||||
if isProviderUnavailable(resp.StatusCode) {
|
||||
return nil, ErrAuthProviderUnavailable
|
||||
}
|
||||
return nil, fmt.Errorf("got %d from %q %s", resp.StatusCode, p.RedeemURL.String(), body)
|
||||
}
|
||||
|
||||
var jsonResponse struct {
|
||||
AccessToken string `json:"access_token"`
|
||||
IDToken string `json:"id_token"`
|
||||
RefreshToken string `json:"refresh_token"`
|
||||
ExpiresIn int64 `json:"expires_in"`
|
||||
Email string `json:"email"`
|
||||
}
|
||||
err = json.Unmarshal(body, &jsonResponse)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
user := strings.Split(jsonResponse.Email, "@")[0]
|
||||
return &sessions.SessionState{
|
||||
AccessToken: jsonResponse.AccessToken,
|
||||
RefreshToken: jsonResponse.RefreshToken,
|
||||
IDToken: jsonResponse.IDToken,
|
||||
|
||||
RefreshDeadline: extendDeadline(time.Duration(jsonResponse.ExpiresIn) * time.Second),
|
||||
LifetimeDeadline: extendDeadline(p.SessionLifetimeTTL),
|
||||
ValidDeadline: extendDeadline(p.SessionValidTTL),
|
||||
|
||||
Email: jsonResponse.Email,
|
||||
User: user,
|
||||
}, nil
|
||||
}
|
||||
|
||||
// RefreshSession refreshes the current session
|
||||
func (p *AuthenticateClient) RefreshSession(s *sessions.SessionState) (bool, error) {
|
||||
|
||||
if s.RefreshToken == "" {
|
||||
return false, ErrMissingRefreshToken
|
||||
}
|
||||
|
||||
newToken, duration, err := p.redeemRefreshToken(s.RefreshToken)
|
||||
if err != nil {
|
||||
// When we detect that the auth provider is not explicitly denying
|
||||
// authentication, and is merely unavailable, we refresh and continue
|
||||
// as normal during the "grace period"
|
||||
if err == ErrAuthProviderUnavailable && p.withinGracePeriod(s) {
|
||||
s.RefreshDeadline = extendDeadline(p.SessionValidTTL)
|
||||
return true, nil
|
||||
}
|
||||
return false, err
|
||||
}
|
||||
|
||||
s.AccessToken = newToken
|
||||
s.RefreshDeadline = extendDeadline(duration)
|
||||
s.GracePeriodStart = time.Time{}
|
||||
log.Info().Str("user", s.Email).Msg("proxy/authenticator.RefreshSession")
|
||||
return true, nil
|
||||
}
|
||||
|
||||
func (p *AuthenticateClient) redeemRefreshToken(refreshToken string) (token string, expires time.Duration, err error) {
|
||||
params := url.Values{}
|
||||
params.Add("client_id", p.ClientID)
|
||||
params.Add("client_secret", p.ClientSecret)
|
||||
params.Add("refresh_token", refreshToken)
|
||||
var req *http.Request
|
||||
req, err = p.newRequest("POST", p.RefreshURL.String(), bytes.NewBufferString(params.Encode()))
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
req.Header.Set("Content-Type", "application/x-www-form-urlencoded")
|
||||
resp, err := defaultHTTPClient.Do(req)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
var body []byte
|
||||
body, err = ioutil.ReadAll(resp.Body)
|
||||
resp.Body.Close()
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
|
||||
if resp.StatusCode != http.StatusCreated {
|
||||
if isProviderUnavailable(resp.StatusCode) {
|
||||
err = ErrAuthProviderUnavailable
|
||||
} else {
|
||||
err = fmt.Errorf("got %d from %q %s", resp.StatusCode, p.RefreshURL.String(), body)
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
var data struct {
|
||||
AccessToken string `json:"access_token"`
|
||||
ExpiresIn int64 `json:"expires_in"`
|
||||
}
|
||||
err = json.Unmarshal(body, &data)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
token = data.AccessToken
|
||||
expires = time.Duration(data.ExpiresIn) * time.Second
|
||||
return
|
||||
}
|
||||
|
||||
// ValidateSessionState validates the current sessions state
|
||||
func (p *AuthenticateClient) ValidateSessionState(s *sessions.SessionState) bool {
|
||||
// we validate the user's access token is valid
|
||||
params := url.Values{}
|
||||
params.Add("client_id", p.ClientID)
|
||||
req, err := p.newRequest("GET", fmt.Sprintf("%s?%s", p.ValidateURL.String(), params.Encode()), nil)
|
||||
if err != nil {
|
||||
log.Error().Err(err).Str("user", s.Email).Msg("proxy/authenticator.ValidateSessionState : error validating session state")
|
||||
return false
|
||||
}
|
||||
req.Header.Set("X-Client-Secret", p.ClientSecret)
|
||||
req.Header.Set("X-Access-Token", s.AccessToken)
|
||||
req.Header.Set("X-Id-Token", s.IDToken)
|
||||
|
||||
resp, err := defaultHTTPClient.Do(req)
|
||||
if err != nil {
|
||||
log.Error().Err(err).Str("user", s.Email).Msg("proxy/authenticator.ValidateSessionState : error making request to validate access token")
|
||||
return false
|
||||
}
|
||||
|
||||
if resp.StatusCode != http.StatusOK {
|
||||
// When we detect that the auth provider is not explicitly denying
|
||||
// authentication, and is merely unavailable, we validate and continue
|
||||
// as normal during the "grace period"
|
||||
if isProviderUnavailable(resp.StatusCode) && p.withinGracePeriod(s) {
|
||||
//tags := []string{"action:validate_session", "error:validation_failed"}
|
||||
s.ValidDeadline = extendDeadline(p.SessionValidTTL)
|
||||
return true
|
||||
}
|
||||
log.Info().Str("user", s.Email).Int("status-code", resp.StatusCode).Msg("proxy/authenticator.ValidateSessionState : could not validate user access token")
|
||||
|
||||
return false
|
||||
}
|
||||
|
||||
s.ValidDeadline = extendDeadline(p.SessionValidTTL)
|
||||
s.GracePeriodStart = time.Time{}
|
||||
|
||||
log.Info().Str("user", s.Email).Msg("proxy/authenticator.ValidateSessionState : validated session")
|
||||
|
||||
return true
|
||||
}
|
||||
|
||||
// signRedirectURL signs the redirect url string, given a timestamp, and returns it
|
||||
func (p *AuthenticateClient) signRedirectURL(rawRedirect string, timestamp time.Time) string {
|
||||
h := hmac.New(sha256.New, []byte(p.ClientSecret))
|
||||
h.Write([]byte(rawRedirect))
|
||||
h.Write([]byte(fmt.Sprint(timestamp.Unix())))
|
||||
return base64.URLEncoding.EncodeToString(h.Sum(nil))
|
||||
}
|
||||
|
||||
// GetSignInURL with typical oauth parameters
|
||||
func (p *AuthenticateClient) GetSignInURL(redirectURL *url.URL, state string) *url.URL {
|
||||
a := *p.SignInURL
|
||||
now := time.Now()
|
||||
rawRedirect := redirectURL.String()
|
||||
params, _ := url.ParseQuery(a.RawQuery)
|
||||
params.Set("redirect_uri", rawRedirect)
|
||||
params.Set("client_id", p.ClientID)
|
||||
params.Set("response_type", "code")
|
||||
params.Add("state", state)
|
||||
params.Set("ts", fmt.Sprint(now.Unix()))
|
||||
params.Set("sig", p.signRedirectURL(rawRedirect, now))
|
||||
a.RawQuery = params.Encode()
|
||||
return &a
|
||||
}
|
||||
|
||||
// GetSignOutURL creates and returns the sign out URL, given a redirectURL
|
||||
func (p *AuthenticateClient) GetSignOutURL(redirectURL *url.URL) *url.URL {
|
||||
a := *p.SignOutURL
|
||||
now := time.Now()
|
||||
rawRedirect := redirectURL.String()
|
||||
params, _ := url.ParseQuery(a.RawQuery)
|
||||
params.Add("redirect_uri", rawRedirect)
|
||||
params.Set("ts", fmt.Sprint(now.Unix()))
|
||||
params.Set("sig", p.signRedirectURL(rawRedirect, now))
|
||||
a.RawQuery = params.Encode()
|
||||
return &a
|
||||
}
|
520
proxy/handlers.go
Normal file
520
proxy/handlers.go
Normal file
|
@ -0,0 +1,520 @@
|
|||
package proxy // import "github.com/pomerium/pomerium/proxy"
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"fmt"
|
||||
"net/http"
|
||||
"net/url"
|
||||
"reflect"
|
||||
"strings"
|
||||
|
||||
"github.com/pomerium/pomerium/internal/aead"
|
||||
"github.com/pomerium/pomerium/internal/httputil"
|
||||
"github.com/pomerium/pomerium/internal/log"
|
||||
"github.com/pomerium/pomerium/internal/middleware"
|
||||
"github.com/pomerium/pomerium/internal/sessions"
|
||||
"github.com/pomerium/pomerium/internal/version"
|
||||
)
|
||||
|
||||
const loggingUserHeader = "SSO-Authenticated-User"
|
||||
|
||||
var (
|
||||
//ErrUserNotAuthorized is set when user is not authorized to access a resource
|
||||
ErrUserNotAuthorized = errors.New("user not authorized")
|
||||
)
|
||||
|
||||
var securityHeaders = map[string]string{
|
||||
"X-Content-Type-Options": "nosniff",
|
||||
"X-Frame-Options": "SAMEORIGIN",
|
||||
"X-XSS-Protection": "1; mode=block",
|
||||
}
|
||||
|
||||
// Handler returns a http handler for an Proxy
|
||||
func (p *Proxy) Handler() http.Handler {
|
||||
mux := http.NewServeMux()
|
||||
mux.HandleFunc("/favicon.ico", p.Favicon)
|
||||
mux.HandleFunc("/robots.txt", p.RobotsTxt)
|
||||
mux.HandleFunc("/.pomerium/sign_out", p.SignOut)
|
||||
mux.HandleFunc("/.pomerium/callback", p.OAuthCallback)
|
||||
mux.HandleFunc("/.pomerium/auth", p.AuthenticateOnly)
|
||||
mux.HandleFunc("/", p.Proxy)
|
||||
|
||||
// Global middleware, which will be applied to each request in reverse
|
||||
// order as applied here (i.e., we want to validate the host _first_ when
|
||||
// processing a request)
|
||||
var handler http.Handler = mux
|
||||
// todo(bdd) : investigate if setting non-overridable headers makes sense
|
||||
// handler = p.setResponseHeaderOverrides(handler)
|
||||
handler = middleware.SetHeaders(handler, securityHeaders)
|
||||
handler = middleware.ValidateHost(handler, p.mux)
|
||||
handler = middleware.RequireHTTPS(handler)
|
||||
handler = log.NewLoggingHandler(handler)
|
||||
|
||||
return http.HandlerFunc(func(rw http.ResponseWriter, req *http.Request) {
|
||||
// Skip host validation for /ping requests because they hit the LB directly.
|
||||
if req.URL.Path == "/ping" {
|
||||
p.PingPage(rw, req)
|
||||
return
|
||||
}
|
||||
handler.ServeHTTP(rw, req)
|
||||
})
|
||||
}
|
||||
|
||||
// RobotsTxt sets the User-Agent header in the response to be "Disallow"
|
||||
func (p *Proxy) RobotsTxt(rw http.ResponseWriter, _ *http.Request) {
|
||||
rw.WriteHeader(http.StatusOK)
|
||||
fmt.Fprintf(rw, "User-agent: *\nDisallow: /")
|
||||
}
|
||||
|
||||
// Favicon will proxy the request as usual if the user is already authenticated
|
||||
// but responds with a 404 otherwise, to avoid spurious and confusing
|
||||
// authentication attempts when a browser automatically requests the favicon on
|
||||
// an error page.
|
||||
func (p *Proxy) Favicon(rw http.ResponseWriter, req *http.Request) {
|
||||
err := p.Authenticate(rw, req)
|
||||
if err != nil {
|
||||
rw.WriteHeader(http.StatusNotFound)
|
||||
return
|
||||
}
|
||||
p.Proxy(rw, req)
|
||||
}
|
||||
|
||||
// PingPage send back a 200 OK response.
|
||||
func (p *Proxy) PingPage(rw http.ResponseWriter, _ *http.Request) {
|
||||
rw.WriteHeader(http.StatusOK)
|
||||
fmt.Fprintf(rw, "OK")
|
||||
}
|
||||
|
||||
// SignOut redirects the request to the sign out url.
|
||||
func (p *Proxy) SignOut(rw http.ResponseWriter, req *http.Request) {
|
||||
p.sessionStore.ClearSession(rw, req)
|
||||
|
||||
var scheme string
|
||||
|
||||
// Build redirect URI from request host
|
||||
if req.URL.Scheme == "" {
|
||||
scheme = "https"
|
||||
}
|
||||
|
||||
redirectURL := &url.URL{
|
||||
Scheme: scheme,
|
||||
Host: req.Host,
|
||||
Path: "/",
|
||||
}
|
||||
fullURL := p.authenticateClient.GetSignOutURL(redirectURL)
|
||||
http.Redirect(rw, req, fullURL.String(), http.StatusFound)
|
||||
}
|
||||
|
||||
// XHRError returns a simple error response with an error message to the application if the request is an XML request
|
||||
func (p *Proxy) XHRError(rw http.ResponseWriter, req *http.Request, code int, err error) {
|
||||
jsonError := struct {
|
||||
Error error `json:"error"`
|
||||
}{
|
||||
Error: err,
|
||||
}
|
||||
|
||||
jsonBytes, err := json.Marshal(jsonError)
|
||||
if err != nil {
|
||||
rw.WriteHeader(http.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
requestLog := log.WithRequest(req, "proxy.ErrorPage")
|
||||
requestLog.Error().Err(err).Int("http-status", code).Msg("proxy.XHRError")
|
||||
rw.Header().Set("Content-Type", "application/json")
|
||||
rw.WriteHeader(code)
|
||||
rw.Write(jsonBytes)
|
||||
}
|
||||
|
||||
// ErrorPage renders an error page with a given status code, title, and message.
|
||||
func (p *Proxy) ErrorPage(rw http.ResponseWriter, req *http.Request, code int, title string, message string) {
|
||||
if p.isXHR(req) {
|
||||
p.XHRError(rw, req, code, errors.New(message))
|
||||
return
|
||||
}
|
||||
requestLog := log.WithRequest(req, "proxy.ErrorPage")
|
||||
requestLog.Info().
|
||||
Str("page-title", title).
|
||||
Str("page-message", message).
|
||||
Msg("proxy.ErrorPage")
|
||||
|
||||
rw.WriteHeader(code)
|
||||
t := struct {
|
||||
Code int
|
||||
Title string
|
||||
Message string
|
||||
Version string
|
||||
}{
|
||||
Code: code,
|
||||
Title: title,
|
||||
Message: message,
|
||||
Version: version.FullVersion(),
|
||||
}
|
||||
p.templates.ExecuteTemplate(rw, "error.html", t)
|
||||
}
|
||||
|
||||
func (p *Proxy) isXHR(req *http.Request) bool {
|
||||
return req.Header.Get("X-Requested-With") == "XMLHttpRequest"
|
||||
}
|
||||
|
||||
// OAuthStart begins the authentication flow, encrypting the redirect url
|
||||
// in a request to the provider's sign in endpoint.
|
||||
func (p *Proxy) OAuthStart(rw http.ResponseWriter, req *http.Request) {
|
||||
// The proxy redirects to the authenticator, and provides it with redirectURI (which points
|
||||
// back to the sso proxy).
|
||||
requestLog := log.WithRequest(req, "proxy.OAuthStart")
|
||||
|
||||
if p.isXHR(req) {
|
||||
e := errors.New("cannot continue oauth flow on xhr")
|
||||
requestLog.Error().Err(e).Msg("isXHR")
|
||||
p.XHRError(rw, req, http.StatusUnauthorized, e)
|
||||
return
|
||||
}
|
||||
|
||||
requestURI := req.URL.String()
|
||||
callbackURL := p.GetRedirectURL(req.Host)
|
||||
|
||||
// generate nonce
|
||||
key := aead.GenerateKey()
|
||||
|
||||
// state prevents cross site forgery and maintain state across the client and server
|
||||
state := &StateParameter{
|
||||
SessionID: fmt.Sprintf("%x", key), // nonce
|
||||
RedirectURI: requestURI, // where to redirect the user back to
|
||||
}
|
||||
|
||||
// we encrypt this value to be opaque the browser cookie
|
||||
// this value will be unique since we always use a randomized nonce as part of marshaling
|
||||
encryptedCSRF, err := p.CookieCipher.Marshal(state)
|
||||
if err != nil {
|
||||
requestLog.Error().Err(err).Msg("failed to marshal csrf")
|
||||
p.ErrorPage(rw, req, http.StatusInternalServerError, "Internal Error", err.Error())
|
||||
return
|
||||
}
|
||||
p.csrfStore.SetCSRF(rw, req, encryptedCSRF)
|
||||
|
||||
// we encrypt this value to be opaque the uri query value
|
||||
// this value will be unique since we always use a randomized nonce as part of marshaling
|
||||
encryptedState, err := p.CookieCipher.Marshal(state)
|
||||
if err != nil {
|
||||
requestLog.Error().Err(err).Msg("failed to encrypt cookie")
|
||||
p.ErrorPage(rw, req, http.StatusInternalServerError, "Internal Error", err.Error())
|
||||
return
|
||||
}
|
||||
|
||||
signinURL := p.authenticateClient.GetSignInURL(callbackURL, encryptedState)
|
||||
requestLog.Info().Msg("redirecting to begin auth flow")
|
||||
http.Redirect(rw, req, signinURL.String(), http.StatusFound)
|
||||
}
|
||||
|
||||
// OAuthCallback validates the cookie sent back from the provider, then validates
|
||||
// the user information, and if authorized, redirects the user back to the original
|
||||
// application.
|
||||
func (p *Proxy) OAuthCallback(rw http.ResponseWriter, req *http.Request) {
|
||||
// We receive the callback from the SSO Authenticator. This request will either contain an
|
||||
// error, or it will contain a `code`; the code can be used to fetch an access token, and
|
||||
// other metadata, from the authenticator.
|
||||
requestLog := log.WithRequest(req, "proxy.OAuthCallback")
|
||||
// finish the oauth cycle
|
||||
err := req.ParseForm()
|
||||
if err != nil {
|
||||
requestLog.Error().Err(err).Msg("failed parsing request form")
|
||||
p.ErrorPage(rw, req, http.StatusInternalServerError, "Internal Error", err.Error())
|
||||
return
|
||||
}
|
||||
errorString := req.Form.Get("error")
|
||||
if errorString != "" {
|
||||
p.ErrorPage(rw, req, http.StatusForbidden, "Permission Denied", errorString)
|
||||
return
|
||||
}
|
||||
|
||||
// We begin the process of redeeming the code for an access token.
|
||||
session, err := p.redeemCode(req.Host, req.Form.Get("code"))
|
||||
if err != nil {
|
||||
requestLog.Error().Err(err).Msg("error redeeming authorization code")
|
||||
p.ErrorPage(rw, req, http.StatusInternalServerError, "Internal Error", "Internal Error")
|
||||
return
|
||||
}
|
||||
|
||||
encryptedState := req.Form.Get("state")
|
||||
stateParameter := &StateParameter{}
|
||||
err = p.CookieCipher.Unmarshal(encryptedState, stateParameter)
|
||||
if err != nil {
|
||||
requestLog.Error().Err(err).Msg("could not unmarshal state")
|
||||
p.ErrorPage(rw, req, http.StatusInternalServerError, "Internal Error", "Internal Error")
|
||||
return
|
||||
}
|
||||
|
||||
c, err := req.Cookie(p.CSRFCookieName)
|
||||
if err != nil {
|
||||
requestLog.Error().Err(err).Msg("failed parsing csrf cookie")
|
||||
p.ErrorPage(rw, req, http.StatusBadRequest, "Bad Request", err.Error())
|
||||
return
|
||||
}
|
||||
p.csrfStore.ClearCSRF(rw, req)
|
||||
|
||||
encryptedCSRF := c.Value
|
||||
csrfParameter := &StateParameter{}
|
||||
err = p.CookieCipher.Unmarshal(encryptedCSRF, csrfParameter)
|
||||
if err != nil {
|
||||
requestLog.Error().Err(err).Msg("couldn't unmarshal CSRF")
|
||||
p.ErrorPage(rw, req, http.StatusInternalServerError, "Internal Error", "Internal Error")
|
||||
return
|
||||
}
|
||||
|
||||
if encryptedState == encryptedCSRF {
|
||||
requestLog.Error().Msg("encrypted state and CSRF should not be equal")
|
||||
p.ErrorPage(rw, req, http.StatusBadRequest, "Bad Request", "Bad Request")
|
||||
return
|
||||
}
|
||||
|
||||
if !reflect.DeepEqual(stateParameter, csrfParameter) {
|
||||
requestLog.Error().Msg("state and CSRF should be equal")
|
||||
p.ErrorPage(rw, req, http.StatusBadRequest, "Bad Request", "Bad Request")
|
||||
return
|
||||
}
|
||||
|
||||
// We validate the user information, and check that this user has proper authorization
|
||||
// for the resources requested. This can be set via the email address or any groups.
|
||||
//
|
||||
// set cookie, or deny
|
||||
if !p.EmailValidator(session.Email) {
|
||||
requestLog.Error().Str("user", session.Email).Msg("permission denied: unauthorized")
|
||||
p.ErrorPage(rw, req, http.StatusForbidden, "Permission Denied", "Invalid Account")
|
||||
return
|
||||
}
|
||||
|
||||
// We store the session in a cookie and redirect the user back to the application
|
||||
err = p.sessionStore.SaveSession(rw, req, session)
|
||||
if err != nil {
|
||||
requestLog.Error().Msg("error saving session")
|
||||
p.ErrorPage(rw, req, http.StatusInternalServerError, "Internal Error", "Internal Error")
|
||||
return
|
||||
}
|
||||
|
||||
// This is the redirect back to the original requested application
|
||||
http.Redirect(rw, req, stateParameter.RedirectURI, http.StatusFound)
|
||||
}
|
||||
|
||||
// AuthenticateOnly calls the Authenticate handler.
|
||||
func (p *Proxy) AuthenticateOnly(rw http.ResponseWriter, req *http.Request) {
|
||||
err := p.Authenticate(rw, req)
|
||||
if err != nil {
|
||||
http.Error(rw, "unauthorized request", http.StatusUnauthorized)
|
||||
}
|
||||
rw.WriteHeader(http.StatusAccepted)
|
||||
}
|
||||
|
||||
// Proxy authenticates a request, either proxying the request if it is authenticated, or starting the authentication process if not.
|
||||
func (p *Proxy) Proxy(rw http.ResponseWriter, req *http.Request) {
|
||||
// Attempts to validate the user and their cookie.
|
||||
// start := time.Now()
|
||||
var err error
|
||||
err = p.Authenticate(rw, req)
|
||||
// If the authentication is not successful we proceed to start the OAuth Flow with
|
||||
// OAuthStart. If authentication is successful, we proceed to proxy to the configured
|
||||
// upstream.
|
||||
requestLog := log.WithRequest(req, "proxy.Proxy")
|
||||
if err != nil {
|
||||
switch err {
|
||||
case http.ErrNoCookie:
|
||||
// No cookie is set, start the oauth flow
|
||||
p.OAuthStart(rw, req)
|
||||
return
|
||||
case ErrUserNotAuthorized:
|
||||
// We know the user is not authorized for the request, we show them a forbidden page
|
||||
p.ErrorPage(rw, req, http.StatusForbidden, "Forbidden", "You're not authorized to view this page")
|
||||
return
|
||||
case sessions.ErrLifetimeExpired:
|
||||
// User's lifetime expired, we trigger the start of the oauth flow
|
||||
p.OAuthStart(rw, req)
|
||||
return
|
||||
case sessions.ErrInvalidSession:
|
||||
// The user session is invalid and we can't decode it.
|
||||
// This can happen for a variety of reasons but the most common non-malicious
|
||||
// case occurs when the session encoding schema changes. We manage this ux
|
||||
// by triggering the start of the oauth flow.
|
||||
p.OAuthStart(rw, req)
|
||||
return
|
||||
default:
|
||||
requestLog.Error().Err(err).Msg("unknown error")
|
||||
// We don't know exactly what happened, but authenticating the user failed, show an error
|
||||
p.ErrorPage(rw, req, http.StatusInternalServerError, "Internal Error", "An unexpected error occurred")
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
// We have validated the users request and now proxy their request to the provided upstream.
|
||||
route, ok := p.router(req)
|
||||
if !ok {
|
||||
httputil.ErrorResponse(rw, req, "Unknown host to route", http.StatusNotFound)
|
||||
return
|
||||
}
|
||||
|
||||
// overhead := time.Now().Sub(start)
|
||||
route.ServeHTTP(rw, req)
|
||||
}
|
||||
|
||||
// Authenticate authenticates a request by checking for a session cookie, and validating its expiration,
|
||||
// clearing the session cookie if it's invalid and returning an error if necessary..
|
||||
func (p *Proxy) Authenticate(rw http.ResponseWriter, req *http.Request) (err error) {
|
||||
|
||||
// Clear the session cookie if anything goes wrong.
|
||||
defer func() {
|
||||
if err != nil {
|
||||
p.sessionStore.ClearSession(rw, req)
|
||||
}
|
||||
}()
|
||||
requestLog := log.WithRequest(req, "proxy.Authenticate")
|
||||
|
||||
session, err := p.sessionStore.LoadSession(req)
|
||||
if err != nil {
|
||||
// We loaded a cookie but it wasn't valid, clear it, and reject the request
|
||||
requestLog.Error().Err(err).Msg("error authenticating user")
|
||||
return err
|
||||
}
|
||||
|
||||
// Lifetime period is the entire duration in which the session is valid.
|
||||
// This should be set to something like 14 to 30 days.
|
||||
if session.LifetimePeriodExpired() {
|
||||
requestLog.Warn().Str("user", session.Email).Msg("session lifetime has expired")
|
||||
return sessions.ErrLifetimeExpired
|
||||
} else if session.RefreshPeriodExpired() {
|
||||
// Refresh period is the period in which the access token is valid. This is ultimately
|
||||
// controlled by the upstream provider and tends to be around 1 hour.
|
||||
ok, err := p.authenticateClient.RefreshSession(session)
|
||||
|
||||
// We failed to refresh the session successfully
|
||||
// clear the cookie and reject the request
|
||||
if err != nil {
|
||||
requestLog.Error().Err(err).Str("user", session.Email).Msg("refreshing session failed")
|
||||
return err
|
||||
}
|
||||
|
||||
if !ok {
|
||||
// User is not authorized after refresh
|
||||
// clear the cookie and reject the request
|
||||
requestLog.Error().Str("user", session.Email).Msg("not authorized after refreshing session")
|
||||
return ErrUserNotAuthorized
|
||||
}
|
||||
|
||||
err = p.sessionStore.SaveSession(rw, req, session)
|
||||
if err != nil {
|
||||
// We refreshed the session successfully, but failed to save it.
|
||||
//
|
||||
// This could be from failing to encode the session properly.
|
||||
// But, we clear the session cookie and reject the request!
|
||||
requestLog.Error().Err(err).Str("user", session.Email).Msg("could not save refresh session")
|
||||
return err
|
||||
}
|
||||
} else if session.ValidationPeriodExpired() {
|
||||
// Validation period has expired, this is the shortest interval we use to
|
||||
// check for valid requests. This should be set to something like a minute.
|
||||
// This calls up the provider chain to validate this user is still active
|
||||
// and hasn't been de-authorized.
|
||||
ok := p.authenticateClient.ValidateSessionState(session)
|
||||
if !ok {
|
||||
// This user is now no longer authorized, or we failed to
|
||||
// validate the user.
|
||||
// Clear the cookie and reject the request
|
||||
requestLog.Error().Str("user", session.Email).Msg("no longer authorized after validation period")
|
||||
return ErrUserNotAuthorized
|
||||
}
|
||||
|
||||
err = p.sessionStore.SaveSession(rw, req, session)
|
||||
if err != nil {
|
||||
// We validated the session successfully, but failed to save it.
|
||||
|
||||
// This could be from failing to encode the session properly.
|
||||
// But, we clear the session cookie and reject the request!
|
||||
requestLog.Error().Err(err).Str("user", session.Email).Msg("could not save validated session")
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
if !p.EmailValidator(session.Email) {
|
||||
requestLog.Error().Str("user", session.Email).Msg("email failed to validate, unauthorized")
|
||||
return ErrUserNotAuthorized
|
||||
}
|
||||
|
||||
req.Header.Set("X-Forwarded-User", session.User)
|
||||
|
||||
if p.PassAccessToken && session.AccessToken != "" {
|
||||
req.Header.Set("X-Forwarded-Access-Token", session.AccessToken)
|
||||
}
|
||||
|
||||
req.Header.Set("X-Forwarded-Email", session.Email)
|
||||
req.Header.Set("X-Forwarded-Groups", strings.Join(session.Groups, ","))
|
||||
|
||||
// stash authenticated user so that it can be logged later (see func logRequest)
|
||||
rw.Header().Set(loggingUserHeader, session.Email)
|
||||
|
||||
// This user has been OK'd. Allow the request!
|
||||
return nil
|
||||
}
|
||||
|
||||
// upstreamTransport is used to ensure that upstreams cannot override the
|
||||
// security headers applied by sso_proxy
|
||||
type upstreamTransport struct {
|
||||
transport *http.Transport
|
||||
}
|
||||
|
||||
// RoundTrip round trips the request and deletes security headers before returning the response.
|
||||
func (t *upstreamTransport) RoundTrip(req *http.Request) (*http.Response, error) {
|
||||
resp, err := t.transport.RoundTrip(req)
|
||||
if err != nil {
|
||||
log.Error().Err(err).Msg("proxy.RoundTrip")
|
||||
return nil, err
|
||||
}
|
||||
for key := range securityHeaders {
|
||||
resp.Header.Del(key)
|
||||
}
|
||||
return resp, err
|
||||
}
|
||||
|
||||
// Handle constructs a route from the given host string and matches it to the provided http.Handler and UpstreamConfig
|
||||
func (p *Proxy) Handle(host string, handler http.Handler) {
|
||||
p.mux[host] = &handler
|
||||
}
|
||||
|
||||
// router attempts to find a route for a request. If a route is successfully matched,
|
||||
// it returns the route information and a bool value of `true`. If a route can not be matched,
|
||||
//a nil value for the route and false bool value is returned.
|
||||
func (p *Proxy) router(req *http.Request) (http.Handler, bool) {
|
||||
route, ok := p.mux[req.Host]
|
||||
if ok {
|
||||
return *route, true
|
||||
}
|
||||
|
||||
return nil, false
|
||||
}
|
||||
|
||||
// GetRedirectURL returns the redirect url for a given Proxy,
|
||||
// setting the scheme to be https if CookieSecure is true.
|
||||
func (p *Proxy) GetRedirectURL(host string) *url.URL {
|
||||
// TODO: Ensure that we only allow valid upstream hosts in redirect URIs
|
||||
u := p.redirectURL
|
||||
// Build redirect URI from request host
|
||||
if u.Scheme == "" {
|
||||
u.Scheme = "https"
|
||||
}
|
||||
u.Host = host
|
||||
return u
|
||||
}
|
||||
|
||||
func (p *Proxy) redeemCode(host, code string) (*sessions.SessionState, error) {
|
||||
if code == "" {
|
||||
return nil, errors.New("missing code")
|
||||
}
|
||||
redirectURL := p.GetRedirectURL(host)
|
||||
s, err := p.authenticateClient.Redeem(redirectURL.String(), code)
|
||||
if err != nil {
|
||||
return s, err
|
||||
}
|
||||
|
||||
if s.Email == "" {
|
||||
return s, errors.New("invalid email address")
|
||||
}
|
||||
|
||||
return s, nil
|
||||
}
|
331
proxy/proxy.go
Executable file
331
proxy/proxy.go
Executable file
|
@ -0,0 +1,331 @@
|
|||
package proxy // import "github.com/pomerium/pomerium/proxy"
|
||||
|
||||
import (
|
||||
"encoding/base64"
|
||||
"errors"
|
||||
"fmt"
|
||||
"html/template"
|
||||
"net"
|
||||
"net/http"
|
||||
"net/http/httputil"
|
||||
"net/url"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/pomerium/envconfig"
|
||||
"github.com/pomerium/pomerium/internal/aead"
|
||||
"github.com/pomerium/pomerium/internal/log"
|
||||
"github.com/pomerium/pomerium/internal/sessions"
|
||||
"github.com/pomerium/pomerium/internal/templates"
|
||||
"github.com/pomerium/pomerium/proxy/authenticator"
|
||||
)
|
||||
|
||||
// Options represents the configuration options for the proxy service.
|
||||
type Options struct {
|
||||
// AuthenticateServiceURL specifies the url to the pomerium authenticate http service.
|
||||
AuthenticateServiceURL *url.URL `envconfig:"PROVIDER_URL"`
|
||||
|
||||
// EmailDomains is a string slice of valid domains to proxy
|
||||
EmailDomains []string `envconfig:"EMAIL_DOMAIN"`
|
||||
// todo(bdd): ClientID and ClientSecret are used are a hacky pre shared key
|
||||
// prefer certificates and mutual tls
|
||||
ClientID string `envconfig:"PROXY_CLIENT_ID"`
|
||||
ClientSecret string `envconfig:"PROXY_CLIENT_SECRET"`
|
||||
|
||||
DefaultUpstreamTimeout time.Duration `envconfig:"DEFAULT_UPSTREAM_TIMEOUT"`
|
||||
|
||||
CookieName string `envconfig:"COOKIE_NAME"`
|
||||
CookieSecret string `envconfig:"COOKIE_SECRET"`
|
||||
CookieDomain string `envconfig:"COOKIE_DOMAIN"`
|
||||
CookieExpire time.Duration `envconfig:"COOKIE_EXPIRE"`
|
||||
CookieSecure bool `envconfig:"COOKIE_SECURE" `
|
||||
CookieHTTPOnly bool `envconfig:"COOKIE_HTTP_ONLY"`
|
||||
|
||||
PassAccessToken bool `envconfig:"PASS_ACCESS_TOKEN"`
|
||||
|
||||
// session details
|
||||
SessionValidTTL time.Duration `envconfig:"SESSION_VALID_TTL"`
|
||||
SessionLifetimeTTL time.Duration `envconfig:"SESSION_LIFETIME_TTL"`
|
||||
GracePeriodTTL time.Duration `envconfig:"GRACE_PERIOD_TTL"`
|
||||
|
||||
Routes map[string]string `envconfig:"ROUTES"`
|
||||
}
|
||||
|
||||
// NewOptions returns a new options struct
|
||||
var defaultOptions = &Options{
|
||||
CookieName: "_pomerium_proxy",
|
||||
CookieSecure: true,
|
||||
CookieHTTPOnly: true,
|
||||
CookieExpire: time.Duration(168) * time.Hour,
|
||||
DefaultUpstreamTimeout: time.Duration(10) * time.Second,
|
||||
SessionLifetimeTTL: time.Duration(720) * time.Hour,
|
||||
SessionValidTTL: time.Duration(1) * time.Minute,
|
||||
GracePeriodTTL: time.Duration(3) * time.Hour,
|
||||
PassAccessToken: false,
|
||||
}
|
||||
|
||||
// OptionsFromEnvConfig builds the authentication service's configuration
|
||||
// options from provided environmental variables
|
||||
func OptionsFromEnvConfig() (*Options, error) {
|
||||
o := defaultOptions
|
||||
if err := envconfig.Process("", o); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return o, nil
|
||||
}
|
||||
|
||||
// Validate checks that proper configuration settings are set to create
|
||||
// a proper Proxy instance
|
||||
func (o *Options) Validate() error {
|
||||
if len(o.Routes) == 0 {
|
||||
return errors.New("missing setting: routes")
|
||||
}
|
||||
for to, from := range o.Routes {
|
||||
if _, err := urlParse(to); err != nil {
|
||||
return fmt.Errorf("could not parse origin %s as url : %q", to, err)
|
||||
}
|
||||
if _, err := urlParse(from); err != nil {
|
||||
return fmt.Errorf("could not parse destination %s as url : %q", to, err)
|
||||
}
|
||||
}
|
||||
if o.AuthenticateServiceURL == nil {
|
||||
return errors.New("missing setting: provider-url")
|
||||
}
|
||||
if o.CookieSecret == "" {
|
||||
return errors.New("missing setting: cookie-secret")
|
||||
}
|
||||
if o.ClientID == "" {
|
||||
return errors.New("missing setting: client-id")
|
||||
}
|
||||
if o.ClientSecret == "" {
|
||||
return errors.New("missing setting: client-secret")
|
||||
}
|
||||
if len(o.EmailDomains) == 0 {
|
||||
return errors.New("missing setting: email-domain")
|
||||
}
|
||||
|
||||
decodedCookieSecret, err := base64.StdEncoding.DecodeString(o.CookieSecret)
|
||||
if err != nil {
|
||||
return errors.New("cookie secret is invalid (e.g. `head -c33 /dev/urandom | base64`) ")
|
||||
}
|
||||
validCookieSecretLength := false
|
||||
for _, i := range []int{32, 64} {
|
||||
if len(decodedCookieSecret) == i {
|
||||
validCookieSecretLength = true
|
||||
}
|
||||
}
|
||||
if !validCookieSecretLength {
|
||||
return fmt.Errorf("cookie secret is invalid, must be 32 or 64 bytes but got %d bytes (e.g. `head -c33 /dev/urandom | base64`) ", len(decodedCookieSecret))
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// Proxy stores all the information associated with proxying a request.
|
||||
type Proxy struct {
|
||||
CookieCipher aead.Cipher
|
||||
CookieDomain string
|
||||
CookieExpire time.Duration
|
||||
CookieHTTPOnly bool
|
||||
CookieName string
|
||||
CookieSecure bool
|
||||
CookieSeed string
|
||||
CSRFCookieName string
|
||||
EmailValidator func(string) bool
|
||||
|
||||
PassAccessToken bool
|
||||
|
||||
// services
|
||||
authenticateClient *authenticator.AuthenticateClient
|
||||
// session
|
||||
csrfStore sessions.CSRFStore
|
||||
sessionStore sessions.SessionStore
|
||||
cipher aead.Cipher
|
||||
|
||||
redirectURL *url.URL // the url to receive requests at
|
||||
templates *template.Template
|
||||
mux map[string]*http.Handler
|
||||
}
|
||||
|
||||
// StateParameter holds the redirect id along with the session id.
|
||||
type StateParameter struct {
|
||||
SessionID string `json:"session_id"`
|
||||
RedirectURI string `json:"redirect_uri"`
|
||||
}
|
||||
|
||||
// NewProxy takes a Proxy service from options and a validation function.
|
||||
// Function returns an error if options fail to validate.
|
||||
func NewProxy(opts *Options, optFuncs ...func(*Proxy) error) (*Proxy, error) {
|
||||
if opts == nil {
|
||||
return nil, errors.New("options cannot be nil")
|
||||
}
|
||||
if err := opts.Validate(); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
// error explicitly handled by validate
|
||||
decodedSecret, _ := base64.StdEncoding.DecodeString(opts.CookieSecret)
|
||||
cipher, err := aead.NewMiscreantCipher(decodedSecret)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("cookie-secret error: %s", err.Error())
|
||||
}
|
||||
|
||||
cookieStore, err := sessions.NewCookieStore(opts.CookieName,
|
||||
sessions.CreateMiscreantCookieCipher(decodedSecret),
|
||||
func(c *sessions.CookieStore) error {
|
||||
c.CookieDomain = opts.CookieDomain
|
||||
c.CookieHTTPOnly = opts.CookieHTTPOnly
|
||||
c.CookieExpire = opts.CookieExpire
|
||||
c.CookieSecure = opts.CookieSecure
|
||||
return nil
|
||||
})
|
||||
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
authClient := authenticator.NewAuthenticateClient(
|
||||
opts.AuthenticateServiceURL,
|
||||
// todo(bdd): fields below can be dropped as Client data provides redudent auth
|
||||
opts.ClientID,
|
||||
opts.ClientSecret,
|
||||
// todo(bdd): fields below should be passed as function args
|
||||
opts.SessionLifetimeTTL,
|
||||
opts.SessionValidTTL,
|
||||
opts.GracePeriodTTL,
|
||||
)
|
||||
|
||||
p := &Proxy{
|
||||
CookieCipher: cipher,
|
||||
CookieDomain: opts.CookieDomain,
|
||||
CookieExpire: opts.CookieExpire,
|
||||
CookieHTTPOnly: opts.CookieHTTPOnly,
|
||||
CookieName: opts.CookieName,
|
||||
CookieSecure: opts.CookieSecure,
|
||||
CookieSeed: string(decodedSecret),
|
||||
CSRFCookieName: fmt.Sprintf("%v_%v", opts.CookieName, "csrf"),
|
||||
|
||||
// these fields make up the routing mechanism
|
||||
mux: make(map[string]*http.Handler),
|
||||
// session state
|
||||
csrfStore: cookieStore,
|
||||
sessionStore: cookieStore,
|
||||
cipher: cipher,
|
||||
|
||||
authenticateClient: authClient,
|
||||
redirectURL: &url.URL{Path: "/.pomerium/callback"},
|
||||
templates: templates.New(),
|
||||
PassAccessToken: opts.PassAccessToken,
|
||||
}
|
||||
|
||||
for _, optFunc := range optFuncs {
|
||||
err := optFunc(p)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
}
|
||||
|
||||
for from, to := range opts.Routes {
|
||||
fromURL, _ := urlParse(from)
|
||||
toURL, _ := urlParse(to)
|
||||
reverseProxy := NewReverseProxy(toURL)
|
||||
handler := NewReverseProxyHandler(opts, reverseProxy, toURL.String())
|
||||
p.Handle(fromURL.Host, handler)
|
||||
log.Info().Str("from", fromURL.Host).Str("to", toURL.String()).Msg("proxy.NewProxy : route")
|
||||
}
|
||||
|
||||
log.Info().
|
||||
Str("CookieName", p.CookieName).
|
||||
Str("redirectURL", p.redirectURL.String()).
|
||||
Str("CSRFCookieName", p.CSRFCookieName).
|
||||
Bool("CookieSecure", p.CookieSecure).
|
||||
Str("CookieDomain", p.CookieDomain).
|
||||
Bool("CookieHTTPOnly", p.CookieHTTPOnly).
|
||||
Dur("CookieExpire", opts.CookieExpire).
|
||||
Msg("proxy.NewProxy")
|
||||
|
||||
return p, nil
|
||||
}
|
||||
|
||||
// UpstreamProxy stores information necessary for proxying the request back to the upstream.
|
||||
type UpstreamProxy struct {
|
||||
name string
|
||||
cookieName string
|
||||
handler http.Handler
|
||||
}
|
||||
|
||||
var defaultUpstreamTransport = &http.Transport{
|
||||
Proxy: http.ProxyFromEnvironment,
|
||||
DialContext: (&net.Dialer{
|
||||
Timeout: 30 * time.Second,
|
||||
KeepAlive: 30 * time.Second,
|
||||
DualStack: true,
|
||||
}).DialContext,
|
||||
MaxIdleConns: 100,
|
||||
IdleConnTimeout: 90 * time.Second,
|
||||
TLSHandshakeTimeout: 10 * time.Second,
|
||||
ExpectContinueTimeout: 1 * time.Second,
|
||||
}
|
||||
|
||||
// deleteSSOCookieHeader deletes the session cookie from the request header string.
|
||||
func deleteSSOCookieHeader(req *http.Request, cookieName string) {
|
||||
headers := []string{}
|
||||
for _, cookie := range req.Cookies() {
|
||||
if cookie.Name != cookieName {
|
||||
headers = append(headers, cookie.String())
|
||||
}
|
||||
}
|
||||
req.Header.Set("Cookie", strings.Join(headers, ";"))
|
||||
}
|
||||
|
||||
// ServeHTTP signs the http request and deletes cookie headers
|
||||
// before calling the upstream's ServeHTTP function.
|
||||
func (u *UpstreamProxy) ServeHTTP(w http.ResponseWriter, r *http.Request) {
|
||||
requestLog := log.WithRequest(r, "proxy.ServeHTTP")
|
||||
deleteSSOCookieHeader(r, u.cookieName)
|
||||
start := time.Now()
|
||||
u.handler.ServeHTTP(w, r)
|
||||
duration := time.Since(start)
|
||||
|
||||
requestLog.Debug().Dur("duration", duration).Msg("proxy-request")
|
||||
}
|
||||
|
||||
// NewReverseProxy creates a reverse proxy to a specified url.
|
||||
// It adds an X-Forwarded-Host header that is the request's host.
|
||||
//
|
||||
// todo(bdd): when would we ever want to preserve host?
|
||||
func NewReverseProxy(to *url.URL) *httputil.ReverseProxy {
|
||||
proxy := httputil.NewSingleHostReverseProxy(to)
|
||||
proxy.Transport = defaultUpstreamTransport
|
||||
director := proxy.Director
|
||||
proxy.Director = func(req *http.Request) {
|
||||
req.Header.Add("X-Forwarded-Host", req.Host)
|
||||
director(req)
|
||||
req.Host = to.Host
|
||||
}
|
||||
return proxy
|
||||
}
|
||||
|
||||
// NewReverseProxyHandler applies handler specific options to a given
|
||||
// route.
|
||||
func NewReverseProxyHandler(opts *Options, reverseProxy *httputil.ReverseProxy, serviceName string) http.Handler {
|
||||
upstreamProxy := &UpstreamProxy{
|
||||
name: serviceName,
|
||||
handler: reverseProxy,
|
||||
cookieName: opts.CookieName,
|
||||
}
|
||||
|
||||
timeout := opts.DefaultUpstreamTimeout
|
||||
timeoutMsg := fmt.Sprintf("%s failed to respond within the %s timeout period", serviceName, timeout)
|
||||
return http.TimeoutHandler(upstreamProxy, timeout, timeoutMsg)
|
||||
}
|
||||
|
||||
// urlParse adds a scheme if none-exists, addressesing a quirk in how
|
||||
// one may expect url.Parse to function when a "naked" domain is sent.
|
||||
//
|
||||
// see: https://github.com/golang/go/issues/12585
|
||||
// see: https://golang.org/pkg/net/url/#Parse
|
||||
func urlParse(uri string) (*url.URL, error) {
|
||||
if !strings.Contains(uri, "://") {
|
||||
uri = fmt.Sprintf("https://%s", uri)
|
||||
}
|
||||
return url.Parse(uri)
|
||||
}
|
220
proxy/proxy_test.go
Normal file
220
proxy/proxy_test.go
Normal file
|
@ -0,0 +1,220 @@
|
|||
package proxy // import "github.com/pomerium/pomerium/proxy"
|
||||
|
||||
import (
|
||||
"io/ioutil"
|
||||
"net"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"net/url"
|
||||
"os"
|
||||
"reflect"
|
||||
"testing"
|
||||
)
|
||||
|
||||
func TestOptionsFromEnvConfig(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
want *Options
|
||||
envKey string
|
||||
envValue string
|
||||
wantErr bool
|
||||
}{
|
||||
{"good default, no env settings", defaultOptions, "", "", false},
|
||||
{"bad url", nil, "PROVIDER_URL", "%.rjlw", true},
|
||||
{"good duration", defaultOptions, "SESSION_VALID_TTL", "1m", false},
|
||||
{"bad duration", nil, "SESSION_VALID_TTL", "1sm", true},
|
||||
}
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
if tt.envKey != "" {
|
||||
os.Setenv(tt.envKey, tt.envValue)
|
||||
defer os.Unsetenv(tt.envKey)
|
||||
}
|
||||
got, err := OptionsFromEnvConfig()
|
||||
if (err != nil) != tt.wantErr {
|
||||
t.Errorf("OptionsFromEnvConfig() error = %v, wantErr %v", err, tt.wantErr)
|
||||
return
|
||||
}
|
||||
if !reflect.DeepEqual(got, tt.want) {
|
||||
t.Errorf("OptionsFromEnvConfig() = %v, want %v", got, tt.want)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func Test_urlParse(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
uri string
|
||||
want *url.URL
|
||||
wantErr bool
|
||||
}{
|
||||
{"good url without schema", "accounts.google.com", &url.URL{Scheme: "https", Host: "accounts.google.com"}, false},
|
||||
{"good url with schema", "https://accounts.google.com", &url.URL{Scheme: "https", Host: "accounts.google.com"}, false},
|
||||
{"bad url, malformed", "https://accounts.google.^", nil, true},
|
||||
}
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
got, err := urlParse(tt.uri)
|
||||
if (err != nil) != tt.wantErr {
|
||||
t.Errorf("urlParse() error = %v, wantErr %v", err, tt.wantErr)
|
||||
return
|
||||
}
|
||||
if !reflect.DeepEqual(got, tt.want) {
|
||||
t.Errorf("urlParse() = %v, want %v", got, tt.want)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestNewReverseProxy(t *testing.T) {
|
||||
backend := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
w.WriteHeader(200)
|
||||
hostname, _, _ := net.SplitHostPort(r.Host)
|
||||
w.Write([]byte(hostname))
|
||||
}))
|
||||
defer backend.Close()
|
||||
|
||||
backendURL, _ := url.Parse(backend.URL)
|
||||
backendHostname, backendPort, _ := net.SplitHostPort(backendURL.Host)
|
||||
backendHost := net.JoinHostPort(backendHostname, backendPort)
|
||||
proxyURL, _ := url.Parse(backendURL.Scheme + "://" + backendHost + "/")
|
||||
|
||||
proxyHandler := NewReverseProxy(proxyURL)
|
||||
frontend := httptest.NewServer(proxyHandler)
|
||||
defer frontend.Close()
|
||||
|
||||
getReq, _ := http.NewRequest("GET", frontend.URL, nil)
|
||||
res, _ := http.DefaultClient.Do(getReq)
|
||||
bodyBytes, _ := ioutil.ReadAll(res.Body)
|
||||
if g, e := string(bodyBytes), backendHostname; g != e {
|
||||
t.Errorf("got body %q; expected %q", g, e)
|
||||
}
|
||||
}
|
||||
|
||||
func TestNewReverseProxyHandler(t *testing.T) {
|
||||
backend := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
w.WriteHeader(200)
|
||||
hostname, _, _ := net.SplitHostPort(r.Host)
|
||||
w.Write([]byte(hostname))
|
||||
}))
|
||||
defer backend.Close()
|
||||
|
||||
backendURL, _ := url.Parse(backend.URL)
|
||||
backendHostname, backendPort, _ := net.SplitHostPort(backendURL.Host)
|
||||
backendHost := net.JoinHostPort(backendHostname, backendPort)
|
||||
proxyURL, _ := url.Parse(backendURL.Scheme + "://" + backendHost + "/")
|
||||
|
||||
proxyHandler := NewReverseProxy(proxyURL)
|
||||
opts := defaultOptions
|
||||
handle := NewReverseProxyHandler(opts, proxyHandler, "name")
|
||||
|
||||
frontend := httptest.NewServer(handle)
|
||||
|
||||
defer frontend.Close()
|
||||
|
||||
getReq, _ := http.NewRequest("GET", frontend.URL, nil)
|
||||
|
||||
res, _ := http.DefaultClient.Do(getReq)
|
||||
bodyBytes, _ := ioutil.ReadAll(res.Body)
|
||||
if g, e := string(bodyBytes), backendHostname; g != e {
|
||||
t.Errorf("got body %q; expected %q", g, e)
|
||||
}
|
||||
}
|
||||
|
||||
func testOptions() *Options {
|
||||
authurl, _ := url.Parse("https://sso-auth.corp.beyondperimeter.com")
|
||||
return &Options{
|
||||
Routes: map[string]string{"corp.example.com": "example.com"},
|
||||
AuthenticateServiceURL: authurl,
|
||||
ClientID: "yksYDhIM7PZTvdFP3Mi3sYt2JXooTi7y0oIClBR46fs=",
|
||||
ClientSecret: "80ldlrU2d7w+wVpKNfevk6fmb8otEx6CqOfshj2LwhQ=",
|
||||
EmailDomains: []string{"*"},
|
||||
CookieSecret: "OromP1gurwGWjQPYb1nNgSxtbVB5NnLzX6z5WOKr0Yw=",
|
||||
}
|
||||
}
|
||||
|
||||
func TestOptions_Validate(t *testing.T) {
|
||||
good := testOptions()
|
||||
badFromRoute := testOptions()
|
||||
badFromRoute.Routes = map[string]string{"example.com": "^"}
|
||||
badToRoute := testOptions()
|
||||
badToRoute.Routes = map[string]string{"^": "example.com"}
|
||||
badAuthURL := testOptions()
|
||||
badAuthURL.AuthenticateServiceURL = nil
|
||||
emptyCookieSecret := testOptions()
|
||||
emptyCookieSecret.CookieSecret = ""
|
||||
invalidCookieSecret := testOptions()
|
||||
invalidCookieSecret.CookieSecret = "OromP1gurwGWjQPYb1nNgSxtbVB5NnLzX6z5WOKr0Yw^"
|
||||
shortCookieLength := testOptions()
|
||||
shortCookieLength.CookieSecret = "gN3xnvfsAwfCXxnJorGLKUG4l2wC8sS8nfLMhcStPg==" //head -c31 /dev/urandom | base64
|
||||
|
||||
badClientID := testOptions()
|
||||
badClientID.ClientID = ""
|
||||
badClientSecret := testOptions()
|
||||
badClientSecret.ClientSecret = ""
|
||||
badEmailDomain := testOptions()
|
||||
badEmailDomain.EmailDomains = nil
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
o *Options
|
||||
wantErr bool
|
||||
}{
|
||||
{"good - minimum options", good, false},
|
||||
|
||||
{"bad - nil options", &Options{}, true},
|
||||
{"bad - from route", badFromRoute, true},
|
||||
{"bad - to route", badToRoute, true},
|
||||
{"bad - auth service url", badAuthURL, true},
|
||||
{"bad - no cookie secret", emptyCookieSecret, true},
|
||||
{"bad - invalid cookie secret", invalidCookieSecret, true},
|
||||
{"bad - short cookie secret", shortCookieLength, true},
|
||||
{"bad - no client id", badClientID, true},
|
||||
{"bad - no client secret", badClientSecret, true},
|
||||
{"bad - no email domain", badEmailDomain, true},
|
||||
}
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
o := tt.o
|
||||
if err := o.Validate(); (err != nil) != tt.wantErr {
|
||||
t.Errorf("Options.Validate() error = %v, wantErr %v", err, tt.wantErr)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestNewProxy(t *testing.T) {
|
||||
good := testOptions()
|
||||
shortCookieLength := testOptions()
|
||||
shortCookieLength.CookieSecret = "gN3xnvfsAwfCXxnJorGLKUG4l2wC8sS8nfLMhcStPg==" //head -c31 /dev/urandom | base64
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
opts *Options
|
||||
optFuncs []func(*Proxy) error
|
||||
wantProxy bool
|
||||
numMuxes int
|
||||
wantErr bool
|
||||
}{
|
||||
{"good - minimum options", good, nil, true, 1, false},
|
||||
{"bad - empty options", &Options{}, nil, false, 0, true},
|
||||
{"bad - nil options", nil, nil, false, 0, true},
|
||||
{"bad - short secret/validate sanity check", shortCookieLength, nil, false, 0, true},
|
||||
}
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
got, err := NewProxy(tt.opts, tt.optFuncs...)
|
||||
if (err != nil) != tt.wantErr {
|
||||
t.Errorf("NewProxy() error = %v, wantErr %v", err, tt.wantErr)
|
||||
return
|
||||
}
|
||||
if got == nil && tt.wantProxy == true {
|
||||
t.Errorf("NewProxy() expected valid proxy struct")
|
||||
}
|
||||
if got != nil && len(got.mux) != tt.numMuxes {
|
||||
t.Errorf("NewProxy() = num muxes %v, want %v", got, tt.numMuxes)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
8
proxy/testdata/public_key.pub
vendored
Normal file
8
proxy/testdata/public_key.pub
vendored
Normal file
|
@ -0,0 +1,8 @@
|
|||
-----BEGIN RSA PUBLIC KEY-----
|
||||
MIIBCgKCAQEAst/CEAh/EMnjRbNcwNF7iMqp03En2GYNJz3wfiv/6Rcu7SDgMJke
|
||||
rYfDcpK8RYREAxyjQpi17eI/FRQx0GbRo1AR0ZgF2VvDTkNBCNb3Pw6bdPbFONCU
|
||||
JV2WXi/vf+4gMRH0hN00K9ZOz18MaY5va7C0p+xaC5713KNJnOvndo48X+HDICSG
|
||||
kCyjne/NylEMy1RLwUCdOSZ6SNsTI0tKt95bTEzBhd0GUDfYuG2SoJyLaJisUyW3
|
||||
8X7TtdRUzSwe6IPeLFppU4QGOf1DI2WlmCdYPPfllCfiqVWMibBzwQZGkBvjWGs3
|
||||
Cw8iKMKcydVlZCJ8rLIaU6sE/lD1eGrfowIDAQAB
|
||||
-----END RSA PUBLIC KEY-----
|
10
proxy/testdata/upstream_configs.yml
vendored
Normal file
10
proxy/testdata/upstream_configs.yml
vendored
Normal file
|
@ -0,0 +1,10 @@
|
|||
- service: foo
|
||||
default:
|
||||
from: foo.{{cluster}}.{{root_domain}}
|
||||
to: foo-internal.{{cluster}}.{{root_domain}}
|
||||
options:
|
||||
allowed_groups:
|
||||
- dev
|
||||
dev:
|
||||
from: foo.{{root_domain}}
|
||||
|
Loading…
Add table
Add a link
Reference in a new issue