mirror of
https://github.com/pomerium/pomerium.git
synced 2025-05-10 15:47:36 +02:00
initial release
This commit is contained in:
commit
d56c889224
62 changed files with 8229 additions and 0 deletions
96
.gitignore
vendored
Normal file
96
.gitignore
vendored
Normal file
|
@ -0,0 +1,96 @@
|
||||||
|
|
||||||
|
pem
|
||||||
|
env
|
||||||
|
coverage.txt
|
||||||
|
*.pem
|
||||||
|
# Compiled Object files, Static and Dynamic libs (Shared Objects)
|
||||||
|
*.o
|
||||||
|
*.a
|
||||||
|
*.so
|
||||||
|
|
||||||
|
# Folders
|
||||||
|
_obj
|
||||||
|
_test
|
||||||
|
.cover
|
||||||
|
|
||||||
|
# Architecture specific extensions/prefixes
|
||||||
|
*.[568vq]
|
||||||
|
[568vq].out
|
||||||
|
|
||||||
|
*.cgo1.go
|
||||||
|
*.cgo2.c
|
||||||
|
_cgo_defun.c
|
||||||
|
_cgo_gotypes.go
|
||||||
|
_cgo_export.*
|
||||||
|
|
||||||
|
_testmain.go
|
||||||
|
|
||||||
|
*.exe
|
||||||
|
*.test
|
||||||
|
*.prof
|
||||||
|
|
||||||
|
# Other dirs
|
||||||
|
/bin/
|
||||||
|
/pkg/
|
||||||
|
|
||||||
|
# Without this, the *.[568vq] above ignores this folder.
|
||||||
|
!**/graphrbac/1.6
|
||||||
|
|
||||||
|
# Ruby
|
||||||
|
website/vendor
|
||||||
|
website/.bundle
|
||||||
|
website/build
|
||||||
|
website/tmp
|
||||||
|
|
||||||
|
# Vagrant
|
||||||
|
.vagrant/
|
||||||
|
Vagrantfile
|
||||||
|
|
||||||
|
# Configs
|
||||||
|
*.hcl
|
||||||
|
!command/agent/config/test-fixtures/config.hcl
|
||||||
|
!command/agent/config/test-fixtures/config-embedded-type.hcl
|
||||||
|
|
||||||
|
.DS_Store
|
||||||
|
.idea
|
||||||
|
.vscode
|
||||||
|
|
||||||
|
dist/*
|
||||||
|
|
||||||
|
tags
|
||||||
|
|
||||||
|
# Editor backups
|
||||||
|
*~
|
||||||
|
*.sw[a-z]
|
||||||
|
|
||||||
|
# IntelliJ IDEA project files
|
||||||
|
.idea
|
||||||
|
*.ipr
|
||||||
|
*.iml
|
||||||
|
|
||||||
|
# compiled output
|
||||||
|
ui/dist
|
||||||
|
ui/tmp
|
||||||
|
ui/root
|
||||||
|
http/bindata_assetfs.go
|
||||||
|
|
||||||
|
# dependencies
|
||||||
|
ui/node_modules
|
||||||
|
ui/bower_components
|
||||||
|
|
||||||
|
# misc
|
||||||
|
ui/.DS_Store
|
||||||
|
ui/.sass-cache
|
||||||
|
ui/connect.lock
|
||||||
|
ui/coverage/*
|
||||||
|
ui/libpeerconnection.log
|
||||||
|
ui/npm-debug.log
|
||||||
|
ui/testem.log
|
||||||
|
|
||||||
|
# used for JS acceptance tests
|
||||||
|
ui/tests/helpers/vault-keys.js
|
||||||
|
ui/vault-ui-integration-server.pid
|
||||||
|
|
||||||
|
# for building static assets
|
||||||
|
node_modules
|
||||||
|
package-lock.json
|
18
.travis.yml
Normal file
18
.travis.yml
Normal file
|
@ -0,0 +1,18 @@
|
||||||
|
---
|
||||||
|
language: go
|
||||||
|
go:
|
||||||
|
- 1.x
|
||||||
|
- tip
|
||||||
|
matrix:
|
||||||
|
allow_failures:
|
||||||
|
- go: tip
|
||||||
|
fast_finish: true
|
||||||
|
install:
|
||||||
|
- go get github.com/golang/lint/golint
|
||||||
|
- go get honnef.co/go/tools/cmd/staticcheck
|
||||||
|
script:
|
||||||
|
- env GO111MODULE=on make all
|
||||||
|
- env GO111MODULE=on make cover
|
||||||
|
- env GO111MODULE=on make release
|
||||||
|
# after_success:
|
||||||
|
# - bash <(curl -s https://codecov.io/bash)
|
88
3RD-PARTY
Normal file
88
3RD-PARTY
Normal file
|
@ -0,0 +1,88 @@
|
||||||
|
Third Party Licenses
|
||||||
|
|
||||||
|
Go
|
||||||
|
SPDX-License-Identifier: BSD-3-Clause
|
||||||
|
https://golang.org/LICENSE
|
||||||
|
|
||||||
|
Copyright (c) 2009 The Go Authors. All rights reserved.
|
||||||
|
|
||||||
|
Redistribution and use in source and binary forms, with or without
|
||||||
|
modification, are permitted provided that the following conditions are
|
||||||
|
met:
|
||||||
|
|
||||||
|
* Redistributions of source code must retain the above copyright
|
||||||
|
notice, this list of conditions and the following disclaimer.
|
||||||
|
* Redistributions in binary form must reproduce the above
|
||||||
|
copyright notice, this list of conditions and the following disclaimer
|
||||||
|
in the documentation and/or other materials provided with the
|
||||||
|
distribution.
|
||||||
|
* Neither the name of Google Inc. nor the names of its
|
||||||
|
contributors may be used to endorse or promote products derived from
|
||||||
|
this software without specific prior written permission.
|
||||||
|
|
||||||
|
THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
|
||||||
|
"AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
|
||||||
|
LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
|
||||||
|
A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT
|
||||||
|
OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL,
|
||||||
|
SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT
|
||||||
|
LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE,
|
||||||
|
DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY
|
||||||
|
THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
|
||||||
|
(INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
||||||
|
OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
||||||
|
|
||||||
|
Upspin
|
||||||
|
SPDX-License-Identifier: BSD-3-Clause
|
||||||
|
https://github.com/upspin/upspin/blob/master/LICENSE
|
||||||
|
|
||||||
|
Copyright (c) 2016 The Upspin Authors. All rights reserved.
|
||||||
|
|
||||||
|
Redistribution and use in source and binary forms, with or without
|
||||||
|
modification, are permitted provided that the following conditions are
|
||||||
|
met:
|
||||||
|
|
||||||
|
* Redistributions of source code must retain the above copyright
|
||||||
|
notice, this list of conditions and the following disclaimer.
|
||||||
|
* Redistributions in binary form must reproduce the above
|
||||||
|
copyright notice, this list of conditions and the following disclaimer
|
||||||
|
in the documentation and/or other materials provided with the
|
||||||
|
distribution.
|
||||||
|
* Neither the name of Google Inc. nor the names of its
|
||||||
|
contributors may be used to endorse or promote products derived from
|
||||||
|
this software without specific prior written permission.
|
||||||
|
|
||||||
|
THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
|
||||||
|
"AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
|
||||||
|
LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
|
||||||
|
A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT
|
||||||
|
OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL,
|
||||||
|
SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT
|
||||||
|
LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE,
|
||||||
|
DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY
|
||||||
|
THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
|
||||||
|
(INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
||||||
|
OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
||||||
|
|
||||||
|
buzzfeed/sso (fork of bitly/oauth2_proxy)
|
||||||
|
SPDX-License-Identifier: MIT
|
||||||
|
https://github.com/buzzfeed/sso/blob/master/LICENSE
|
||||||
|
https://github.com/bitly/oauth2_proxy/blob/master/LICENSE
|
||||||
|
|
||||||
|
Permission is hereby granted, free of charge, to any person obtaining a copy
|
||||||
|
of this software and associated documentation files (the "Software"), to deal
|
||||||
|
in the Software without restriction, including without limitation the rights
|
||||||
|
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
|
||||||
|
copies of the Software, and to permit persons to whom the Software is
|
||||||
|
furnished to do so, subject to the following conditions:
|
||||||
|
|
||||||
|
The above copyright notice and this permission notice shall be included in
|
||||||
|
all copies or substantial portions of the Software.
|
||||||
|
|
||||||
|
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
|
||||||
|
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
|
||||||
|
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
|
||||||
|
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
|
||||||
|
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
|
||||||
|
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
|
||||||
|
THE SOFTWARE.
|
14
Dockerfile
Normal file
14
Dockerfile
Normal file
|
@ -0,0 +1,14 @@
|
||||||
|
FROM golang:alpine as build
|
||||||
|
RUN apk --update --no-cache add ca-certificates git make
|
||||||
|
ENV CGO_ENABLED=0
|
||||||
|
ENV GO111MODULE=on
|
||||||
|
|
||||||
|
WORKDIR /go/src/github.com/pomerium/pomerium
|
||||||
|
COPY . .
|
||||||
|
|
||||||
|
RUN make
|
||||||
|
|
||||||
|
FROM scratch
|
||||||
|
COPY --from=build /etc/ssl/certs/ca-certificates.crt /etc/ssl/certs/
|
||||||
|
WORKDIR /pomerium
|
||||||
|
COPY --from=build /bin/* /bin/
|
201
LICENSE
Normal file
201
LICENSE
Normal file
|
@ -0,0 +1,201 @@
|
||||||
|
Apache License
|
||||||
|
Version 2.0, January 2004
|
||||||
|
http://www.apache.org/licenses/
|
||||||
|
|
||||||
|
TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
|
||||||
|
|
||||||
|
1. Definitions.
|
||||||
|
|
||||||
|
"License" shall mean the terms and conditions for use, reproduction,
|
||||||
|
and distribution as defined by Sections 1 through 9 of this document.
|
||||||
|
|
||||||
|
"Licensor" shall mean the copyright owner or entity authorized by
|
||||||
|
the copyright owner that is granting the License.
|
||||||
|
|
||||||
|
"Legal Entity" shall mean the union of the acting entity and all
|
||||||
|
other entities that control, are controlled by, or are under common
|
||||||
|
control with that entity. For the purposes of this definition,
|
||||||
|
"control" means (i) the power, direct or indirect, to cause the
|
||||||
|
direction or management of such entity, whether by contract or
|
||||||
|
otherwise, or (ii) ownership of fifty percent (50%) or more of the
|
||||||
|
outstanding shares, or (iii) beneficial ownership of such entity.
|
||||||
|
|
||||||
|
"You" (or "Your") shall mean an individual or Legal Entity
|
||||||
|
exercising permissions granted by this License.
|
||||||
|
|
||||||
|
"Source" form shall mean the preferred form for making modifications,
|
||||||
|
including but not limited to software source code, documentation
|
||||||
|
source, and configuration files.
|
||||||
|
|
||||||
|
"Object" form shall mean any form resulting from mechanical
|
||||||
|
transformation or translation of a Source form, including but
|
||||||
|
not limited to compiled object code, generated documentation,
|
||||||
|
and conversions to other media types.
|
||||||
|
|
||||||
|
"Work" shall mean the work of authorship, whether in Source or
|
||||||
|
Object form, made available under the License, as indicated by a
|
||||||
|
copyright notice that is included in or attached to the work
|
||||||
|
(an example is provided in the Appendix below).
|
||||||
|
|
||||||
|
"Derivative Works" shall mean any work, whether in Source or Object
|
||||||
|
form, that is based on (or derived from) the Work and for which the
|
||||||
|
editorial revisions, annotations, elaborations, or other modifications
|
||||||
|
represent, as a whole, an original work of authorship. For the purposes
|
||||||
|
of this License, Derivative Works shall not include works that remain
|
||||||
|
separable from, or merely link (or bind by name) to the interfaces of,
|
||||||
|
the Work and Derivative Works thereof.
|
||||||
|
|
||||||
|
"Contribution" shall mean any work of authorship, including
|
||||||
|
the original version of the Work and any modifications or additions
|
||||||
|
to that Work or Derivative Works thereof, that is intentionally
|
||||||
|
submitted to Licensor for inclusion in the Work by the copyright owner
|
||||||
|
or by an individual or Legal Entity authorized to submit on behalf of
|
||||||
|
the copyright owner. For the purposes of this definition, "submitted"
|
||||||
|
means any form of electronic, verbal, or written communication sent
|
||||||
|
to the Licensor or its representatives, including but not limited to
|
||||||
|
communication on electronic mailing lists, source code control systems,
|
||||||
|
and issue tracking systems that are managed by, or on behalf of, the
|
||||||
|
Licensor for the purpose of discussing and improving the Work, but
|
||||||
|
excluding communication that is conspicuously marked or otherwise
|
||||||
|
designated in writing by the copyright owner as "Not a Contribution."
|
||||||
|
|
||||||
|
"Contributor" shall mean Licensor and any individual or Legal Entity
|
||||||
|
on behalf of whom a Contribution has been received by Licensor and
|
||||||
|
subsequently incorporated within the Work.
|
||||||
|
|
||||||
|
2. Grant of Copyright License. Subject to the terms and conditions of
|
||||||
|
this License, each Contributor hereby grants to You a perpetual,
|
||||||
|
worldwide, non-exclusive, no-charge, royalty-free, irrevocable
|
||||||
|
copyright license to reproduce, prepare Derivative Works of,
|
||||||
|
publicly display, publicly perform, sublicense, and distribute the
|
||||||
|
Work and such Derivative Works in Source or Object form.
|
||||||
|
|
||||||
|
3. Grant of Patent License. Subject to the terms and conditions of
|
||||||
|
this License, each Contributor hereby grants to You a perpetual,
|
||||||
|
worldwide, non-exclusive, no-charge, royalty-free, irrevocable
|
||||||
|
(except as stated in this section) patent license to make, have made,
|
||||||
|
use, offer to sell, sell, import, and otherwise transfer the Work,
|
||||||
|
where such license applies only to those patent claims licensable
|
||||||
|
by such Contributor that are necessarily infringed by their
|
||||||
|
Contribution(s) alone or by combination of their Contribution(s)
|
||||||
|
with the Work to which such Contribution(s) was submitted. If You
|
||||||
|
institute patent litigation against any entity (including a
|
||||||
|
cross-claim or counterclaim in a lawsuit) alleging that the Work
|
||||||
|
or a Contribution incorporated within the Work constitutes direct
|
||||||
|
or contributory patent infringement, then any patent licenses
|
||||||
|
granted to You under this License for that Work shall terminate
|
||||||
|
as of the date such litigation is filed.
|
||||||
|
|
||||||
|
4. Redistribution. You may reproduce and distribute copies of the
|
||||||
|
Work or Derivative Works thereof in any medium, with or without
|
||||||
|
modifications, and in Source or Object form, provided that You
|
||||||
|
meet the following conditions:
|
||||||
|
|
||||||
|
(a) You must give any other recipients of the Work or
|
||||||
|
Derivative Works a copy of this License; and
|
||||||
|
|
||||||
|
(b) You must cause any modified files to carry prominent notices
|
||||||
|
stating that You changed the files; and
|
||||||
|
|
||||||
|
(c) You must retain, in the Source form of any Derivative Works
|
||||||
|
that You distribute, all copyright, patent, trademark, and
|
||||||
|
attribution notices from the Source form of the Work,
|
||||||
|
excluding those notices that do not pertain to any part of
|
||||||
|
the Derivative Works; and
|
||||||
|
|
||||||
|
(d) If the Work includes a "NOTICE" text file as part of its
|
||||||
|
distribution, then any Derivative Works that You distribute must
|
||||||
|
include a readable copy of the attribution notices contained
|
||||||
|
within such NOTICE file, excluding those notices that do not
|
||||||
|
pertain to any part of the Derivative Works, in at least one
|
||||||
|
of the following places: within a NOTICE text file distributed
|
||||||
|
as part of the Derivative Works; within the Source form or
|
||||||
|
documentation, if provided along with the Derivative Works; or,
|
||||||
|
within a display generated by the Derivative Works, if and
|
||||||
|
wherever such third-party notices normally appear. The contents
|
||||||
|
of the NOTICE file are for informational purposes only and
|
||||||
|
do not modify the License. You may add Your own attribution
|
||||||
|
notices within Derivative Works that You distribute, alongside
|
||||||
|
or as an addendum to the NOTICE text from the Work, provided
|
||||||
|
that such additional attribution notices cannot be construed
|
||||||
|
as modifying the License.
|
||||||
|
|
||||||
|
You may add Your own copyright statement to Your modifications and
|
||||||
|
may provide additional or different license terms and conditions
|
||||||
|
for use, reproduction, or distribution of Your modifications, or
|
||||||
|
for any such Derivative Works as a whole, provided Your use,
|
||||||
|
reproduction, and distribution of the Work otherwise complies with
|
||||||
|
the conditions stated in this License.
|
||||||
|
|
||||||
|
5. Submission of Contributions. Unless You explicitly state otherwise,
|
||||||
|
any Contribution intentionally submitted for inclusion in the Work
|
||||||
|
by You to the Licensor shall be under the terms and conditions of
|
||||||
|
this License, without any additional terms or conditions.
|
||||||
|
Notwithstanding the above, nothing herein shall supersede or modify
|
||||||
|
the terms of any separate license agreement you may have executed
|
||||||
|
with Licensor regarding such Contributions.
|
||||||
|
|
||||||
|
6. Trademarks. This License does not grant permission to use the trade
|
||||||
|
names, trademarks, service marks, or product names of the Licensor,
|
||||||
|
except as required for reasonable and customary use in describing the
|
||||||
|
origin of the Work and reproducing the content of the NOTICE file.
|
||||||
|
|
||||||
|
7. Disclaimer of Warranty. Unless required by applicable law or
|
||||||
|
agreed to in writing, Licensor provides the Work (and each
|
||||||
|
Contributor provides its Contributions) on an "AS IS" BASIS,
|
||||||
|
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
|
||||||
|
implied, including, without limitation, any warranties or conditions
|
||||||
|
of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
|
||||||
|
PARTICULAR PURPOSE. You are solely responsible for determining the
|
||||||
|
appropriateness of using or redistributing the Work and assume any
|
||||||
|
risks associated with Your exercise of permissions under this License.
|
||||||
|
|
||||||
|
8. Limitation of Liability. In no event and under no legal theory,
|
||||||
|
whether in tort (including negligence), contract, or otherwise,
|
||||||
|
unless required by applicable law (such as deliberate and grossly
|
||||||
|
negligent acts) or agreed to in writing, shall any Contributor be
|
||||||
|
liable to You for damages, including any direct, indirect, special,
|
||||||
|
incidental, or consequential damages of any character arising as a
|
||||||
|
result of this License or out of the use or inability to use the
|
||||||
|
Work (including but not limited to damages for loss of goodwill,
|
||||||
|
work stoppage, computer failure or malfunction, or any and all
|
||||||
|
other commercial damages or losses), even if such Contributor
|
||||||
|
has been advised of the possibility of such damages.
|
||||||
|
|
||||||
|
9. Accepting Warranty or Additional Liability. While redistributing
|
||||||
|
the Work or Derivative Works thereof, You may choose to offer,
|
||||||
|
and charge a fee for, acceptance of support, warranty, indemnity,
|
||||||
|
or other liability obligations and/or rights consistent with this
|
||||||
|
License. However, in accepting such obligations, You may act only
|
||||||
|
on Your own behalf and on Your sole responsibility, not on behalf
|
||||||
|
of any other Contributor, and only if You agree to indemnify,
|
||||||
|
defend, and hold each Contributor harmless for any liability
|
||||||
|
incurred by, or claims asserted against, such Contributor by reason
|
||||||
|
of your accepting any such warranty or additional liability.
|
||||||
|
|
||||||
|
END OF TERMS AND CONDITIONS
|
||||||
|
|
||||||
|
APPENDIX: How to apply the Apache License to your work.
|
||||||
|
|
||||||
|
To apply the Apache License to your work, attach the following
|
||||||
|
boilerplate notice, with the fields enclosed by brackets "[]"
|
||||||
|
replaced with your own identifying information. (Don't include
|
||||||
|
the brackets!) The text should be enclosed in the appropriate
|
||||||
|
comment syntax for the file format. We also recommend that a
|
||||||
|
file or class name and description of purpose be included on the
|
||||||
|
same "printed page" as the copyright notice for easier
|
||||||
|
identification within third-party archives.
|
||||||
|
|
||||||
|
Copyright 2019 Pomerium
|
||||||
|
|
||||||
|
Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
you may not use this file except in compliance with the License.
|
||||||
|
You may obtain a copy of the License at
|
||||||
|
|
||||||
|
http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
|
||||||
|
Unless required by applicable law or agreed to in writing, software
|
||||||
|
distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
See the License for the specific language governing permissions and
|
||||||
|
limitations under the License.
|
113
Makefile
Normal file
113
Makefile
Normal file
|
@ -0,0 +1,113 @@
|
||||||
|
# Setup name variables for the package/tool
|
||||||
|
PREFIX?=$(shell pwd)
|
||||||
|
|
||||||
|
NAME := pomerium
|
||||||
|
PKG := github.com/pomerium/$(NAME)
|
||||||
|
|
||||||
|
BUILDDIR := ${PREFIX}/dist
|
||||||
|
BINDIR := ${PREFIX}/bin
|
||||||
|
GO111MODULE=on
|
||||||
|
CGO_ENABLED := 0
|
||||||
|
# Set any default go build tags
|
||||||
|
BUILDTAGS :=
|
||||||
|
|
||||||
|
# Populate version variables
|
||||||
|
# Add to compile time flags
|
||||||
|
VERSION := $(shell cat VERSION)
|
||||||
|
GITCOMMIT := $(shell git rev-parse --short HEAD)
|
||||||
|
GITUNTRACKEDCHANGES := $(shell git status --porcelain --untracked-files=no)
|
||||||
|
BUILDMETA:=
|
||||||
|
ifneq ($(GITUNTRACKEDCHANGES),)
|
||||||
|
BUILDMETA := dirty
|
||||||
|
endif
|
||||||
|
CTIMEVAR=-X $(PKG)/internal/version.GitCommit=$(GITCOMMIT) \
|
||||||
|
-X $(PKG)/internal/version.Version=$(VERSION) \
|
||||||
|
-X $(PKG)/internal/version.BuildMeta=$(BUILDMETA) \
|
||||||
|
-X $(PKG)/internal/version.ProjectName=$(NAME) \
|
||||||
|
-X $(PKG)/internal/version.ProjectURL=$(PKG)
|
||||||
|
GO_LDFLAGS=-ldflags "-w $(CTIMEVAR)"
|
||||||
|
GOOSARCHES = linux/amd64 darwin/amd64 windows/amd64
|
||||||
|
|
||||||
|
|
||||||
|
.PHONY: all
|
||||||
|
all: clean build fmt lint vet test ## Runs a clean, build, fmt, lint, test, and vet.
|
||||||
|
|
||||||
|
|
||||||
|
.PHONY: build
|
||||||
|
build: ## Builds dynamic executables and/or packages.
|
||||||
|
@echo "==> $@"
|
||||||
|
@CGO_ENABLED=0 GO111MODULE=on go build -tags "$(BUILDTAGS)" ${GO_LDFLAGS} -o $(BINDIR)/$(NAME) ./cmd/"$(NAME)"
|
||||||
|
|
||||||
|
.PHONY: fmt
|
||||||
|
fmt: ## Verifies all files have been `gofmt`ed.
|
||||||
|
@echo "==> $@"
|
||||||
|
@gofmt -s -l . | grep -v '.pb.go:' | grep -v vendor | tee /dev/stderr
|
||||||
|
|
||||||
|
.PHONY: lint
|
||||||
|
lint: ## Verifies `golint` passes.
|
||||||
|
@echo "==> $@"
|
||||||
|
@golint ./... | grep -v '.pb.go:' | grep -v vendor | tee /dev/stderr
|
||||||
|
|
||||||
|
.PHONY: staticcheck
|
||||||
|
staticcheck: ## Verifies `staticcheck` passes
|
||||||
|
@echo "+ $@"
|
||||||
|
@staticcheck $(shell go list ./... | grep -v vendor) | grep -v '.pb.go:' | tee /dev/stderr
|
||||||
|
|
||||||
|
.PHONY: vet
|
||||||
|
vet: ## Verifies `go vet` passes.
|
||||||
|
@echo "==> $@"
|
||||||
|
@go vet $(shell go list ./... | grep -v vendor) | grep -v '.pb.go:' | tee /dev/stderr
|
||||||
|
|
||||||
|
.PHONY: test
|
||||||
|
test: ## Runs the go tests.
|
||||||
|
@echo "==> $@"
|
||||||
|
@go test -tags "$(BUILDTAGS)" $(shell go list ./... | grep -v vendor)
|
||||||
|
|
||||||
|
|
||||||
|
.PHONY: cover
|
||||||
|
cover: ## Runs go test with coverage
|
||||||
|
@echo "" > coverage.txt
|
||||||
|
@for d in $(shell go list ./... | grep -v vendor); do \
|
||||||
|
go test -race -coverprofile=profile.out -covermode=atomic "$$d"; \
|
||||||
|
if [ -f profile.out ]; then \
|
||||||
|
cat profile.out >> coverage.txt; \
|
||||||
|
rm profile.out; \
|
||||||
|
fi; \
|
||||||
|
done;
|
||||||
|
|
||||||
|
.PHONY: clean
|
||||||
|
clean: ## Cleanup any build binaries or packages.
|
||||||
|
@echo "==> $@"
|
||||||
|
$(RM) -r $(BINDIR)
|
||||||
|
$(RM) -r $(BUILDDIR)
|
||||||
|
|
||||||
|
define buildpretty
|
||||||
|
mkdir -p $(BUILDDIR)/$(1)/$(2);
|
||||||
|
GOOS=$(1) GOARCH=$(2) CGO_ENABLED=0 GO111MODULE=on go build \
|
||||||
|
-o $(BUILDDIR)/$(1)/$(2)/$(NAME) \
|
||||||
|
${GO_LDFLAGS_STATIC} ./cmd/$(NAME);
|
||||||
|
md5sum $(BUILDDIR)/$(1)/$(2)/$(NAME) > $(BUILDDIR)/$(1)/$(2)/$(NAME).md5;
|
||||||
|
sha256sum $(BUILDDIR)/$(1)/$(2)/$(NAME) > $(BUILDDIR)/$(1)/$(2)/$(NAME).sha256;
|
||||||
|
endef
|
||||||
|
|
||||||
|
.PHONY: cross
|
||||||
|
cross: ## Builds the cross-compiled binaries, creating a clean directory structure (eg. GOOS/GOARCH/binary)
|
||||||
|
@echo "+ $@"
|
||||||
|
$(foreach GOOSARCH,$(GOOSARCHES), $(call buildpretty,$(subst /,,$(dir $(GOOSARCH))),$(notdir $(GOOSARCH))))
|
||||||
|
|
||||||
|
define buildrelease
|
||||||
|
GOOS=$(1) GOARCH=$(2) CGO_ENABLED=0 GO111MODULE=on go build \
|
||||||
|
-o $(BUILDDIR)/$(NAME)-$(1)-$(2) \
|
||||||
|
${GO_LDFLAGS_STATIC} ./cmd/$(NAME);
|
||||||
|
md5sum $(BUILDDIR)/$(NAME)-$(1)-$(2) > $(BUILDDIR)/$(NAME)-$(1)-$(2).md5;
|
||||||
|
sha256sum $(BUILDDIR)/$(NAME)-$(1)-$(2) > $(BUILDDIR)/$(NAME)-$(1)-$(2).sha256;
|
||||||
|
endef
|
||||||
|
|
||||||
|
.PHONY: release
|
||||||
|
release: ## Builds the cross-compiled binaries, naming them in such a way for release (eg. binary-GOOS-GOARCH)
|
||||||
|
@echo "+ $@"
|
||||||
|
$(foreach GOOSARCH,$(GOOSARCHES), $(call buildrelease,$(subst /,,$(dir $(GOOSARCH))),$(notdir $(GOOSARCH))))
|
||||||
|
|
||||||
|
.PHONY: help
|
||||||
|
help:
|
||||||
|
@grep -E '^[a-zA-Z_-]+:.*?## .*$$' $(MAKEFILE_LIST) | sort | awk 'BEGIN {FS = ":.*?## "}; {printf "\033[36m%-30s\033[0m %s\n", $$1, $$2}'
|
35
README.md
Normal file
35
README.md
Normal file
|
@ -0,0 +1,35 @@
|
||||||
|
<img height="200" src="./docs/logo.png" alt="logo" align="right" >
|
||||||
|
|
||||||
|
# Pomerium : identity-aware access proxy
|
||||||
|
[](https://travis-ci.org/pomerium/pomerium)
|
||||||
|
[](https://goreportcard.com/report/github.com/pomerium/pomerium)
|
||||||
|
[](https://github.com/pomerium/pomerium/blob/master/LICENSE)
|
||||||
|
|
||||||
|
Pomerium is a tool for managing secure access to internal applications and resources.
|
||||||
|
|
||||||
|
Use Pomerium to:
|
||||||
|
|
||||||
|
- provide a unified ingress gateway to internal corporate applications.
|
||||||
|
- enforce dynamic access policies based on context, identity, and device state.
|
||||||
|
- aggregate logging and telemetry data.
|
||||||
|
|
||||||
|
To learn more about zero-trust / BeyondCorp, check out [awesome-zero-trust].
|
||||||
|
|
||||||
|
## Getting started
|
||||||
|
|
||||||
|
For instructions on getting started with Pomerium, see our getting started docs.
|
||||||
|
|
||||||
|
## To start developing Pomerium
|
||||||
|
|
||||||
|
Assuming you have a working [Go environment].
|
||||||
|
|
||||||
|
```sh
|
||||||
|
$ go get -d github.com/pomerium/pomerium
|
||||||
|
$ cd $GOPATH/src/github.com/pomerium/pomerium
|
||||||
|
$ make
|
||||||
|
$ source ./env # see env.example
|
||||||
|
$ ./bin/pomerium -debug
|
||||||
|
```
|
||||||
|
|
||||||
|
[awesome-zero-trust]: https://github.com/pomerium/awesome-zero-trust
|
||||||
|
[Go environment]: https://golang.org/doc/install
|
1
VERSION
Normal file
1
VERSION
Normal file
|
@ -0,0 +1 @@
|
||||||
|
v0.0.1
|
254
authenticate/authenticate.go
Normal file
254
authenticate/authenticate.go
Normal file
|
@ -0,0 +1,254 @@
|
||||||
|
package authenticate // import "github.com/pomerium/pomerium/authenticate"
|
||||||
|
|
||||||
|
import (
|
||||||
|
"encoding/base64"
|
||||||
|
"errors"
|
||||||
|
"fmt"
|
||||||
|
"html/template"
|
||||||
|
"net/url"
|
||||||
|
"strings"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"github.com/pomerium/envconfig"
|
||||||
|
|
||||||
|
"github.com/pomerium/pomerium/authenticate/providers"
|
||||||
|
"github.com/pomerium/pomerium/internal/aead"
|
||||||
|
"github.com/pomerium/pomerium/internal/sessions"
|
||||||
|
"github.com/pomerium/pomerium/internal/templates"
|
||||||
|
)
|
||||||
|
|
||||||
|
// Options permits the configuration of the authentication service
|
||||||
|
type Options struct {
|
||||||
|
// e.g.
|
||||||
|
Host string `envconfig:"HOST"`
|
||||||
|
//
|
||||||
|
ProxyClientID string `envconfig:"PROXY_CLIENT_ID"`
|
||||||
|
ProxyClientSecret string `envconfig:"PROXY_CLIENT_SECRET"`
|
||||||
|
|
||||||
|
// Coarse authorization based on user email domain
|
||||||
|
EmailDomains []string `envconfig:"SSO_EMAIL_DOMAIN"`
|
||||||
|
ProxyRootDomains []string `envconfig:"PROXY_ROOT_DOMAIN"`
|
||||||
|
|
||||||
|
// Session/Cookie management
|
||||||
|
CookieName string
|
||||||
|
CookieSecret string `envconfig:"COOKIE_SECRET"`
|
||||||
|
CookieDomain string `envconfig:"COOKIE_DOMAIN"`
|
||||||
|
CookieExpire time.Duration `envconfig:"COOKIE_EXPIRE" default:"168h"`
|
||||||
|
CookieRefresh time.Duration `envconfig:"COOKIE_REFRESH" default:"1h"`
|
||||||
|
CookieSecure bool `envconfig:"COOKIE_SECURE" default:"true"`
|
||||||
|
CookieHTTPOnly bool `envconfig:"COOKIE_HTTP_ONLY" default:"true"`
|
||||||
|
|
||||||
|
AuthCodeSecret string `envconfig:"AUTH_CODE_SECRET"`
|
||||||
|
|
||||||
|
SessionLifetimeTTL time.Duration `envconfig:"SESSION_LIFETIME_TTL" default:"720h"`
|
||||||
|
|
||||||
|
// Authentication provider configuration vars
|
||||||
|
RedirectURL *url.URL `envconfig:"IDP_REDIRECT_URL" ` // e.g. auth.example.com/oauth/callback
|
||||||
|
ClientID string `envconfig:"IDP_CLIENT_ID"` // IdP ClientID
|
||||||
|
ClientSecret string `envconfig:"IDP_CLIENT_SECRET"` // IdP Secret
|
||||||
|
Provider string `envconfig:"IDP_PROVIDER"` //Provider name e.g. "oidc","okta","google",etc
|
||||||
|
ProviderURL *url.URL `envconfig:"IDP_PROVIDER_URL"`
|
||||||
|
Scopes []string `envconfig:"IDP_SCOPE" default:"openid,email,profile"`
|
||||||
|
|
||||||
|
// todo(bdd) : can delete?`
|
||||||
|
ApprovalPrompt string `envconfig:"IDP_APPROVAL_PROMPT" default:"consent"`
|
||||||
|
RequestLogging bool `envconfig:"REQUEST_LOGGING" default:"true"`
|
||||||
|
RequestTimeout time.Duration `envconfig:"REQUEST_TIMEOUT" default:"2s"`
|
||||||
|
}
|
||||||
|
|
||||||
|
var defaultOptions = &Options{
|
||||||
|
EmailDomains: []string{"*"},
|
||||||
|
CookieName: "_pomerium_authenticate",
|
||||||
|
CookieSecure: true,
|
||||||
|
CookieHTTPOnly: true,
|
||||||
|
CookieExpire: time.Duration(168) * time.Hour,
|
||||||
|
CookieRefresh: time.Duration(1) * time.Hour,
|
||||||
|
RequestTimeout: time.Duration(2) * time.Second,
|
||||||
|
SessionLifetimeTTL: time.Duration(720) * time.Hour,
|
||||||
|
|
||||||
|
ApprovalPrompt: "consent",
|
||||||
|
Scopes: []string{"openid", "email", "profile"},
|
||||||
|
}
|
||||||
|
|
||||||
|
// OptionsFromEnvConfig builds the authentication service's configuration
|
||||||
|
// options from provided environmental variables
|
||||||
|
func OptionsFromEnvConfig() (*Options, error) {
|
||||||
|
o := defaultOptions
|
||||||
|
if err := envconfig.Process("", o); err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
return o, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// Validate checks to see if configuration values are valid for authentication service.
|
||||||
|
// The checks do not modify the internal state of the Option structure. Function returns
|
||||||
|
// on first error found.
|
||||||
|
func (o *Options) Validate() error {
|
||||||
|
if o.ProviderURL == nil {
|
||||||
|
return errors.New("missing setting: identity provider url")
|
||||||
|
}
|
||||||
|
if o.RedirectURL == nil {
|
||||||
|
return errors.New("missing setting: identity provider redirect url")
|
||||||
|
}
|
||||||
|
redirectPath := "/oauth2/callback"
|
||||||
|
if o.RedirectURL.Path != redirectPath {
|
||||||
|
return fmt.Errorf("setting redirect-url was %s path should be %s", o.RedirectURL.Path, redirectPath)
|
||||||
|
}
|
||||||
|
if o.ClientID == "" {
|
||||||
|
return errors.New("missing setting: client id")
|
||||||
|
}
|
||||||
|
if o.ClientSecret == "" {
|
||||||
|
return errors.New("missing setting: client secret")
|
||||||
|
}
|
||||||
|
if len(o.EmailDomains) == 0 {
|
||||||
|
return errors.New("missing setting email domain")
|
||||||
|
}
|
||||||
|
if len(o.ProxyRootDomains) == 0 {
|
||||||
|
return errors.New("missing setting: proxy root domain")
|
||||||
|
}
|
||||||
|
if o.ProxyClientID == "" {
|
||||||
|
return errors.New("missing setting: proxy client id")
|
||||||
|
}
|
||||||
|
if o.ProxyClientSecret == "" {
|
||||||
|
return errors.New("missing setting: proxy client secret")
|
||||||
|
}
|
||||||
|
|
||||||
|
decodedCookieSecret, err := base64.StdEncoding.DecodeString(o.CookieSecret)
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("authenticate options: cookie secret invalid"+
|
||||||
|
"must be a base64-encoded, 256 bit key e.g. `head -c32 /dev/urandom | base64`"+
|
||||||
|
"got %q", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
validCookieSecretLength := false
|
||||||
|
for _, i := range []int{32, 64} {
|
||||||
|
if len(decodedCookieSecret) == i {
|
||||||
|
validCookieSecretLength = true
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if !validCookieSecretLength {
|
||||||
|
return fmt.Errorf("authenticate options: invalid cookie secret strength want 32 to 64 bytes, got %d bytes", len(decodedCookieSecret))
|
||||||
|
}
|
||||||
|
|
||||||
|
if o.CookieRefresh >= o.CookieExpire {
|
||||||
|
return fmt.Errorf("cookie_refresh (%s) must be less than cookie_expire (%s)", o.CookieRefresh.String(), o.CookieExpire.String())
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// Authenticator stores all the information associated with proxying the request.
|
||||||
|
type Authenticator struct {
|
||||||
|
Validator func(string) bool
|
||||||
|
|
||||||
|
EmailDomains []string
|
||||||
|
ProxyRootDomains []string
|
||||||
|
Host string
|
||||||
|
CookieSecure bool
|
||||||
|
|
||||||
|
ProxyClientID string
|
||||||
|
ProxyClientSecret string
|
||||||
|
|
||||||
|
SessionLifetimeTTL time.Duration
|
||||||
|
|
||||||
|
decodedCookieSecret []byte
|
||||||
|
templates *template.Template
|
||||||
|
// sesion related
|
||||||
|
csrfStore sessions.CSRFStore
|
||||||
|
sessionStore sessions.SessionStore
|
||||||
|
cipher aead.Cipher
|
||||||
|
|
||||||
|
redirectURL *url.URL
|
||||||
|
provider providers.Provider
|
||||||
|
}
|
||||||
|
|
||||||
|
// NewAuthenticator creates a Authenticator struct and applies the optional functions slice to the struct.
|
||||||
|
func NewAuthenticator(opts *Options, optionFuncs ...func(*Authenticator) error) (*Authenticator, error) {
|
||||||
|
if opts == nil {
|
||||||
|
return nil, errors.New("options cannot be nil")
|
||||||
|
}
|
||||||
|
if err := opts.Validate(); err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
decodedAuthCodeSecret, err := base64.StdEncoding.DecodeString(opts.AuthCodeSecret)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
cipher, err := aead.NewMiscreantCipher([]byte(decodedAuthCodeSecret))
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
decodedCookieSecret, err := base64.StdEncoding.DecodeString(opts.CookieSecret)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
cookieStore, err := sessions.NewCookieStore(opts.CookieName,
|
||||||
|
sessions.CreateMiscreantCookieCipher(decodedCookieSecret),
|
||||||
|
func(c *sessions.CookieStore) error {
|
||||||
|
c.CookieDomain = opts.CookieDomain
|
||||||
|
c.CookieHTTPOnly = opts.CookieHTTPOnly
|
||||||
|
c.CookieExpire = opts.CookieExpire
|
||||||
|
c.CookieSecure = opts.CookieSecure
|
||||||
|
return nil
|
||||||
|
})
|
||||||
|
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
p := &Authenticator{
|
||||||
|
ProxyClientID: opts.ProxyClientID,
|
||||||
|
ProxyClientSecret: opts.ProxyClientSecret,
|
||||||
|
EmailDomains: opts.EmailDomains,
|
||||||
|
ProxyRootDomains: dotPrependDomains(opts.ProxyRootDomains),
|
||||||
|
CookieSecure: opts.CookieSecure,
|
||||||
|
redirectURL: opts.RedirectURL,
|
||||||
|
templates: templates.New(),
|
||||||
|
csrfStore: cookieStore,
|
||||||
|
sessionStore: cookieStore,
|
||||||
|
cipher: cipher,
|
||||||
|
}
|
||||||
|
// p.ServeMux = p.Handler()
|
||||||
|
p.provider, err = newProvider(opts)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
// apply the option functions
|
||||||
|
for _, optFunc := range optionFuncs {
|
||||||
|
err := optFunc(p)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return p, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func newProvider(opts *Options) (providers.Provider, error) {
|
||||||
|
pd := &providers.ProviderData{
|
||||||
|
RedirectURL: opts.RedirectURL,
|
||||||
|
ProviderName: opts.Provider,
|
||||||
|
ClientID: opts.ClientID,
|
||||||
|
ClientSecret: opts.ClientSecret,
|
||||||
|
ApprovalPrompt: opts.ApprovalPrompt,
|
||||||
|
SessionLifetimeTTL: opts.SessionLifetimeTTL,
|
||||||
|
ProviderURL: opts.ProviderURL,
|
||||||
|
Scopes: opts.Scopes,
|
||||||
|
}
|
||||||
|
np, err := providers.New(opts.Provider, pd)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
return providers.NewSingleFlightProvider(np), nil
|
||||||
|
|
||||||
|
}
|
||||||
|
|
||||||
|
func dotPrependDomains(d []string) []string {
|
||||||
|
for i := range d {
|
||||||
|
if d[i] != "" && !strings.HasPrefix(d[i], ".") {
|
||||||
|
d[i] = fmt.Sprintf(".%s", d[i])
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return d
|
||||||
|
}
|
1097
authenticate/authenticate_test.go
Normal file
1097
authenticate/authenticate_test.go
Normal file
File diff suppressed because it is too large
Load diff
329
authenticate/circuit/breaker.go
Normal file
329
authenticate/circuit/breaker.go
Normal file
|
@ -0,0 +1,329 @@
|
||||||
|
// Package circuit implements the Circuit Breaker pattern.
|
||||||
|
// https://docs.microsoft.com/en-us/azure/architecture/patterns/circuit-breaker
|
||||||
|
package circuit // import "github.com/pomerium/pomerium/internal/circuit"
|
||||||
|
|
||||||
|
import (
|
||||||
|
"fmt"
|
||||||
|
"math"
|
||||||
|
"math/rand"
|
||||||
|
"sync"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"github.com/benbjohnson/clock"
|
||||||
|
)
|
||||||
|
|
||||||
|
// State is a type that represents a state of Breaker.
|
||||||
|
type State int
|
||||||
|
|
||||||
|
// These constants are states of Breaker.
|
||||||
|
const (
|
||||||
|
StateClosed State = iota
|
||||||
|
StateHalfOpen
|
||||||
|
StateOpen
|
||||||
|
)
|
||||||
|
|
||||||
|
type (
|
||||||
|
// ShouldTripFunc is a function that takes in a Counts and returns true if the circuit breaker should be tripped.
|
||||||
|
ShouldTripFunc func(Counts) bool
|
||||||
|
// ShouldResetFunc is a function that takes in a Counts and returns true if the circuit breaker should be reset.
|
||||||
|
ShouldResetFunc func(Counts) bool
|
||||||
|
// BackoffDurationFunc is a function that takes in a Counts and returns the backoff duration
|
||||||
|
BackoffDurationFunc func(Counts) time.Duration
|
||||||
|
|
||||||
|
// StateChangeHook is a function that represents a state change.
|
||||||
|
StateChangeHook func(prev, to State)
|
||||||
|
// BackoffHook is a function that represents backoff.
|
||||||
|
BackoffHook func(duration time.Duration, reset time.Time)
|
||||||
|
)
|
||||||
|
|
||||||
|
var (
|
||||||
|
// DefaultShouldTripFunc is a default ShouldTripFunc.
|
||||||
|
DefaultShouldTripFunc = func(counts Counts) bool {
|
||||||
|
// Trip into Open after three consecutive failures
|
||||||
|
return counts.ConsecutiveFailures >= 3
|
||||||
|
}
|
||||||
|
// DefaultShouldResetFunc is a default ShouldResetFunc.
|
||||||
|
DefaultShouldResetFunc = func(counts Counts) bool {
|
||||||
|
// Reset after three consecutive successes
|
||||||
|
return counts.ConsecutiveSuccesses >= 3
|
||||||
|
}
|
||||||
|
// DefaultBackoffDurationFunc is an exponential backoff function
|
||||||
|
DefaultBackoffDurationFunc = ExponentialBackoffDuration(time.Duration(100)*time.Second, time.Duration(500)*time.Millisecond)
|
||||||
|
)
|
||||||
|
|
||||||
|
// ErrOpenState is returned when the b state is open
|
||||||
|
type ErrOpenState struct{}
|
||||||
|
|
||||||
|
func (e *ErrOpenState) Error() string { return "circuit breaker is open" }
|
||||||
|
|
||||||
|
// ExponentialBackoffDuration returns a function that uses exponential backoff and full jitter
|
||||||
|
func ExponentialBackoffDuration(maxBackoff, baseTimeout time.Duration) func(Counts) time.Duration {
|
||||||
|
return func(counts Counts) time.Duration {
|
||||||
|
// Full Jitter from https://aws.amazon.com/blogs/architecture/exponential-backoff-and-jitter/
|
||||||
|
// sleep = random_between(0, min(cap, base * 2 ** attempt))
|
||||||
|
backoff := math.Min(float64(maxBackoff), float64(baseTimeout)*math.Exp2(float64(counts.ConsecutiveFailures)))
|
||||||
|
jittered := rand.Float64() * backoff
|
||||||
|
return time.Duration(jittered)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// String implements stringer interface.
|
||||||
|
func (s State) String() string {
|
||||||
|
switch s {
|
||||||
|
case StateClosed:
|
||||||
|
return "closed"
|
||||||
|
case StateHalfOpen:
|
||||||
|
return "half-open"
|
||||||
|
case StateOpen:
|
||||||
|
return "open"
|
||||||
|
default:
|
||||||
|
return fmt.Sprintf("unknown state: %d", s)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Counts holds the numbers of requests and their successes/failures.
|
||||||
|
type Counts struct {
|
||||||
|
CurrentRequests int
|
||||||
|
ConsecutiveSuccesses int
|
||||||
|
ConsecutiveFailures int
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *Counts) onRequest() {
|
||||||
|
c.CurrentRequests++
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *Counts) afterRequest() {
|
||||||
|
c.CurrentRequests--
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *Counts) onSuccess() {
|
||||||
|
c.ConsecutiveSuccesses++
|
||||||
|
c.ConsecutiveFailures = 0
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *Counts) onFailure() {
|
||||||
|
c.ConsecutiveFailures++
|
||||||
|
c.ConsecutiveSuccesses = 0
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *Counts) clear() {
|
||||||
|
c.ConsecutiveSuccesses = 0
|
||||||
|
c.ConsecutiveFailures = 0
|
||||||
|
}
|
||||||
|
|
||||||
|
// Options configures Breaker:
|
||||||
|
//
|
||||||
|
// HalfOpenConcurrentRequests specifies how many concurrent requests to allow while
|
||||||
|
// the circuit is in the half-open state
|
||||||
|
//
|
||||||
|
// ShouldTripFunc specifies when the circuit should trip from the closed state to
|
||||||
|
// the open state. It takes a Counts struct and returns a bool.
|
||||||
|
//
|
||||||
|
// ShouldResetFunc specifies when the circuit should be reset from the half-open state
|
||||||
|
// to the closed state and allow all requests. It takes a Counts struct and returns a bool.
|
||||||
|
//
|
||||||
|
// BackoffDurationFunc specifies how long to set the backoff duration. It takes a
|
||||||
|
// counts struct and returns a time.Duration
|
||||||
|
//
|
||||||
|
// OnStateChange is called whenever the state of the Breaker changes.
|
||||||
|
//
|
||||||
|
// OnBackoff is called whenever a backoff is set with the backoff duration and reset time
|
||||||
|
//
|
||||||
|
// TestClock is used to mock the clock during tests
|
||||||
|
type Options struct {
|
||||||
|
HalfOpenConcurrentRequests int
|
||||||
|
|
||||||
|
ShouldTripFunc ShouldTripFunc
|
||||||
|
ShouldResetFunc ShouldResetFunc
|
||||||
|
BackoffDurationFunc BackoffDurationFunc
|
||||||
|
|
||||||
|
// hooks
|
||||||
|
OnStateChange StateChangeHook
|
||||||
|
OnBackoff BackoffHook
|
||||||
|
|
||||||
|
// used in tests
|
||||||
|
TestClock clock.Clock
|
||||||
|
}
|
||||||
|
|
||||||
|
// Breaker is a state machine to prevent sending requests that are likely to fail.
|
||||||
|
type Breaker struct {
|
||||||
|
halfOpenRequests int
|
||||||
|
|
||||||
|
shouldTripFunc ShouldTripFunc
|
||||||
|
shouldResetFunc ShouldResetFunc
|
||||||
|
backoffDurationFunc BackoffDurationFunc
|
||||||
|
|
||||||
|
// hooks
|
||||||
|
onStateChange StateChangeHook
|
||||||
|
onBackoff BackoffHook
|
||||||
|
|
||||||
|
// used primarly for mocking tests
|
||||||
|
clock clock.Clock
|
||||||
|
|
||||||
|
mutex sync.Mutex
|
||||||
|
state State
|
||||||
|
counts Counts
|
||||||
|
backoffExpires time.Time
|
||||||
|
generation int
|
||||||
|
}
|
||||||
|
|
||||||
|
// NewBreaker returns a new Breaker configured with the given Settings.
|
||||||
|
func NewBreaker(opts *Options) *Breaker {
|
||||||
|
b := new(Breaker)
|
||||||
|
|
||||||
|
if opts == nil {
|
||||||
|
opts = &Options{}
|
||||||
|
}
|
||||||
|
|
||||||
|
// set hooks
|
||||||
|
b.onStateChange = opts.OnStateChange
|
||||||
|
b.onBackoff = opts.OnBackoff
|
||||||
|
|
||||||
|
b.halfOpenRequests = 1
|
||||||
|
if opts.HalfOpenConcurrentRequests > 0 {
|
||||||
|
b.halfOpenRequests = opts.HalfOpenConcurrentRequests
|
||||||
|
}
|
||||||
|
|
||||||
|
b.backoffDurationFunc = DefaultBackoffDurationFunc
|
||||||
|
if opts.BackoffDurationFunc != nil {
|
||||||
|
b.backoffDurationFunc = opts.BackoffDurationFunc
|
||||||
|
}
|
||||||
|
|
||||||
|
b.shouldTripFunc = DefaultShouldTripFunc
|
||||||
|
if opts.ShouldTripFunc != nil {
|
||||||
|
b.shouldTripFunc = opts.ShouldTripFunc
|
||||||
|
}
|
||||||
|
|
||||||
|
b.shouldResetFunc = DefaultShouldResetFunc
|
||||||
|
if opts.ShouldResetFunc != nil {
|
||||||
|
b.shouldResetFunc = opts.ShouldResetFunc
|
||||||
|
}
|
||||||
|
|
||||||
|
b.clock = clock.New()
|
||||||
|
if opts.TestClock != nil {
|
||||||
|
b.clock = opts.TestClock
|
||||||
|
}
|
||||||
|
|
||||||
|
b.setState(StateClosed)
|
||||||
|
|
||||||
|
return b
|
||||||
|
}
|
||||||
|
|
||||||
|
// Call runs the given function if the Breaker allows the call.
|
||||||
|
// Call returns an error instantly if the Breaker rejects the request.
|
||||||
|
// Otherwise, Call returns the result of the request.
|
||||||
|
func (b *Breaker) Call(f func() (interface{}, error)) (interface{}, error) {
|
||||||
|
generation, err := b.beforeRequest()
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
result, err := f()
|
||||||
|
b.afterRequest(err == nil, generation)
|
||||||
|
return result, err
|
||||||
|
}
|
||||||
|
|
||||||
|
func (b *Breaker) beforeRequest() (int, error) {
|
||||||
|
b.mutex.Lock()
|
||||||
|
defer b.mutex.Unlock()
|
||||||
|
|
||||||
|
state, generation := b.currentState()
|
||||||
|
|
||||||
|
switch state {
|
||||||
|
case StateOpen:
|
||||||
|
return generation, &ErrOpenState{}
|
||||||
|
case StateHalfOpen:
|
||||||
|
if b.counts.CurrentRequests >= b.halfOpenRequests {
|
||||||
|
return generation, &ErrOpenState{}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
b.counts.onRequest()
|
||||||
|
return generation, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (b *Breaker) afterRequest(success bool, prevGeneration int) {
|
||||||
|
b.mutex.Lock()
|
||||||
|
defer b.mutex.Unlock()
|
||||||
|
|
||||||
|
b.counts.afterRequest()
|
||||||
|
|
||||||
|
state, generation := b.currentState()
|
||||||
|
if prevGeneration != generation {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
if success {
|
||||||
|
b.onSuccess(state)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
b.onFailure(state)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (b *Breaker) onSuccess(state State) {
|
||||||
|
b.counts.onSuccess()
|
||||||
|
switch state {
|
||||||
|
case StateHalfOpen:
|
||||||
|
if b.shouldResetFunc(b.counts) {
|
||||||
|
b.setState(StateClosed)
|
||||||
|
b.counts.clear()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (b *Breaker) onFailure(state State) {
|
||||||
|
b.counts.onFailure()
|
||||||
|
|
||||||
|
switch state {
|
||||||
|
case StateClosed:
|
||||||
|
if b.shouldTripFunc(b.counts) {
|
||||||
|
b.setState(StateOpen)
|
||||||
|
b.counts.clear()
|
||||||
|
b.setBackoff()
|
||||||
|
}
|
||||||
|
case StateOpen:
|
||||||
|
b.setBackoff()
|
||||||
|
case StateHalfOpen:
|
||||||
|
b.setState(StateOpen)
|
||||||
|
b.setBackoff()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (b *Breaker) setBackoff() {
|
||||||
|
backoffDuration := b.backoffDurationFunc(b.counts)
|
||||||
|
backoffExpires := b.clock.Now().Add(backoffDuration)
|
||||||
|
b.backoffExpires = backoffExpires
|
||||||
|
if b.onBackoff != nil {
|
||||||
|
b.onBackoff(backoffDuration, backoffExpires)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (b *Breaker) currentState() (State, int) {
|
||||||
|
switch b.state {
|
||||||
|
case StateOpen:
|
||||||
|
if b.clock.Now().After(b.backoffExpires) {
|
||||||
|
b.setState(StateHalfOpen)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return b.state, b.generation
|
||||||
|
}
|
||||||
|
|
||||||
|
func (b *Breaker) newGeneration() {
|
||||||
|
b.generation++
|
||||||
|
}
|
||||||
|
|
||||||
|
func (b *Breaker) setState(state State) {
|
||||||
|
if b.state == state {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
b.newGeneration()
|
||||||
|
|
||||||
|
prev := b.state
|
||||||
|
b.state = state
|
||||||
|
|
||||||
|
if b.onStateChange != nil {
|
||||||
|
b.onStateChange(prev, state)
|
||||||
|
}
|
||||||
|
}
|
187
authenticate/circuit/breaker_test.go
Normal file
187
authenticate/circuit/breaker_test.go
Normal file
|
@ -0,0 +1,187 @@
|
||||||
|
package circuit // import "github.com/pomerium/pomerium/internal/circuit"
|
||||||
|
|
||||||
|
import (
|
||||||
|
"errors"
|
||||||
|
"sync"
|
||||||
|
"testing"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"github.com/benbjohnson/clock"
|
||||||
|
)
|
||||||
|
|
||||||
|
var errFailed = errors.New("failed")
|
||||||
|
|
||||||
|
func fail() (interface{}, error) {
|
||||||
|
return nil, errFailed
|
||||||
|
}
|
||||||
|
|
||||||
|
func succeed() (interface{}, error) {
|
||||||
|
return nil, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestCircuitBreaker(t *testing.T) {
|
||||||
|
mock := clock.NewMock()
|
||||||
|
threshold := 3
|
||||||
|
timeout := time.Duration(2) * time.Second
|
||||||
|
trip := func(c Counts) bool { return c.ConsecutiveFailures > threshold }
|
||||||
|
reset := func(c Counts) bool { return c.ConsecutiveSuccesses > threshold }
|
||||||
|
backoff := func(c Counts) time.Duration { return timeout }
|
||||||
|
stateChange := func(p, c State) { t.Logf("state change from %s to %s\n", p, c) }
|
||||||
|
cb := NewBreaker(&Options{
|
||||||
|
TestClock: mock,
|
||||||
|
ShouldTripFunc: trip,
|
||||||
|
ShouldResetFunc: reset,
|
||||||
|
BackoffDurationFunc: backoff,
|
||||||
|
OnStateChange: stateChange,
|
||||||
|
})
|
||||||
|
state, _ := cb.currentState()
|
||||||
|
if state != StateClosed {
|
||||||
|
t.Fatalf("expected state to start %s, got %s", StateClosed, state)
|
||||||
|
}
|
||||||
|
|
||||||
|
for i := 0; i <= threshold; i++ {
|
||||||
|
_, err := cb.Call(fail)
|
||||||
|
if err == nil {
|
||||||
|
t.Fatalf("expected to error, got nil")
|
||||||
|
}
|
||||||
|
state, _ := cb.currentState()
|
||||||
|
t.Logf("iteration %#v", i)
|
||||||
|
if i == threshold {
|
||||||
|
// we expect this to be the case to trip the circuit
|
||||||
|
if state != StateOpen {
|
||||||
|
t.Fatalf("expected state to be %s, got %s", StateOpen, state)
|
||||||
|
}
|
||||||
|
} else if state != StateClosed {
|
||||||
|
// this is a normal failure case
|
||||||
|
t.Fatalf("expected state to be %s, got %s", StateClosed, state)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
_, err := cb.Call(fail)
|
||||||
|
switch err.(type) {
|
||||||
|
case *ErrOpenState:
|
||||||
|
// this is the expected case
|
||||||
|
break
|
||||||
|
default:
|
||||||
|
t.Errorf("%#v", cb.counts)
|
||||||
|
t.Fatalf("expected to get open state failure, got %s", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// we advance time by the timeout and a hair
|
||||||
|
mock.Add(timeout + time.Duration(1)*time.Millisecond)
|
||||||
|
state, _ = cb.currentState()
|
||||||
|
if state != StateHalfOpen {
|
||||||
|
t.Fatalf("expected state to be %s, got %s", StateHalfOpen, state)
|
||||||
|
}
|
||||||
|
|
||||||
|
for i := 0; i <= threshold; i++ {
|
||||||
|
_, err := cb.Call(succeed)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("expected to get no error, got %s", err)
|
||||||
|
}
|
||||||
|
state, _ := cb.currentState()
|
||||||
|
t.Logf("iteration %#v", i)
|
||||||
|
if i == threshold {
|
||||||
|
// we expect this to be the case that ressets the circuit
|
||||||
|
if state != StateClosed {
|
||||||
|
t.Fatalf("expected state to be %s, got %s", StateClosed, state)
|
||||||
|
}
|
||||||
|
} else if state != StateHalfOpen {
|
||||||
|
t.Fatalf("expected state to be %s, got %s", StateHalfOpen, state)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
state, _ = cb.currentState()
|
||||||
|
if state != StateClosed {
|
||||||
|
t.Fatalf("expected state to be %s, got %s", StateClosed, state)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestExponentialBackOffFunc(t *testing.T) {
|
||||||
|
baseTimeout := time.Duration(1) * time.Millisecond
|
||||||
|
// Note Expected is an upper range case
|
||||||
|
cases := []struct {
|
||||||
|
FailureCount int
|
||||||
|
Expected time.Duration
|
||||||
|
}{
|
||||||
|
{
|
||||||
|
FailureCount: 0,
|
||||||
|
Expected: time.Duration(1) * time.Millisecond,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
FailureCount: 1,
|
||||||
|
Expected: time.Duration(2) * time.Millisecond,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
FailureCount: 2,
|
||||||
|
Expected: time.Duration(4) * time.Millisecond,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
FailureCount: 3,
|
||||||
|
Expected: time.Duration(8) * time.Millisecond,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
FailureCount: 4,
|
||||||
|
Expected: time.Duration(16) * time.Millisecond,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
FailureCount: 5,
|
||||||
|
Expected: time.Duration(32) * time.Millisecond,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
FailureCount: 6,
|
||||||
|
Expected: time.Duration(64) * time.Millisecond,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
FailureCount: 7,
|
||||||
|
Expected: time.Duration(128) * time.Millisecond,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
FailureCount: 8,
|
||||||
|
Expected: time.Duration(256) * time.Millisecond,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
FailureCount: 9,
|
||||||
|
Expected: time.Duration(512) * time.Millisecond,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
FailureCount: 10,
|
||||||
|
Expected: time.Duration(1024) * time.Millisecond,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
f := ExponentialBackoffDuration(time.Duration(1)*time.Hour, baseTimeout)
|
||||||
|
for _, tc := range cases {
|
||||||
|
got := f(Counts{ConsecutiveFailures: tc.FailureCount})
|
||||||
|
t.Logf("got backoff %#v", got)
|
||||||
|
if got > tc.Expected {
|
||||||
|
t.Errorf("got %#v but expected less than %#v", got, tc.Expected)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestCircuitBreakerClosedParallel(t *testing.T) {
|
||||||
|
cb := NewBreaker(nil)
|
||||||
|
numReqs := 10000
|
||||||
|
wg := &sync.WaitGroup{}
|
||||||
|
routine := func(wg *sync.WaitGroup) {
|
||||||
|
for i := 0; i < numReqs; i++ {
|
||||||
|
cb.Call(succeed)
|
||||||
|
}
|
||||||
|
wg.Done()
|
||||||
|
}
|
||||||
|
|
||||||
|
numRoutines := 10
|
||||||
|
for i := 0; i < numRoutines; i++ {
|
||||||
|
wg.Add(1)
|
||||||
|
go routine(wg)
|
||||||
|
}
|
||||||
|
|
||||||
|
total := numReqs * numRoutines
|
||||||
|
|
||||||
|
wg.Wait()
|
||||||
|
|
||||||
|
if cb.counts.ConsecutiveSuccesses != total {
|
||||||
|
t.Fatalf("expected to get total requests %d, got %d", total, cb.counts.ConsecutiveSuccesses)
|
||||||
|
}
|
||||||
|
}
|
629
authenticate/handlers.go
Normal file
629
authenticate/handlers.go
Normal file
|
@ -0,0 +1,629 @@
|
||||||
|
package authenticate // import "github.com/pomerium/pomerium/authenticate"
|
||||||
|
|
||||||
|
import (
|
||||||
|
"encoding/base64"
|
||||||
|
"encoding/json"
|
||||||
|
"fmt"
|
||||||
|
"net/http"
|
||||||
|
"net/url"
|
||||||
|
"strings"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"github.com/pomerium/pomerium/internal/aead"
|
||||||
|
"github.com/pomerium/pomerium/internal/httputil"
|
||||||
|
"github.com/pomerium/pomerium/internal/log"
|
||||||
|
m "github.com/pomerium/pomerium/internal/middleware"
|
||||||
|
"github.com/pomerium/pomerium/internal/sessions"
|
||||||
|
"github.com/pomerium/pomerium/internal/version"
|
||||||
|
)
|
||||||
|
|
||||||
|
var securityHeaders = map[string]string{
|
||||||
|
"Strict-Transport-Security": "max-age=31536000",
|
||||||
|
"X-Frame-Options": "DENY",
|
||||||
|
"X-Content-Type-Options": "nosniff",
|
||||||
|
"X-XSS-Protection": "1; mode=block",
|
||||||
|
"Content-Security-Policy": "default-src 'none'; style-src 'self' 'sha256-pSTVzZsFAqd2U3QYu+BoBDtuJWaPM/+qMy/dBRrhb5Y='; img-src 'self';",
|
||||||
|
"Referrer-Policy": "Same-origin",
|
||||||
|
}
|
||||||
|
|
||||||
|
// Handler returns the Http.Handlers for authentication, callback, and refresh
|
||||||
|
func (p *Authenticator) Handler() http.Handler {
|
||||||
|
mux := http.NewServeMux()
|
||||||
|
// we setup global endpoints that should respond to any hostname
|
||||||
|
mux.HandleFunc("/ping", m.WithMethods(p.PingPage, "GET"))
|
||||||
|
|
||||||
|
serviceMux := http.NewServeMux()
|
||||||
|
// standard rest and healthcheck endpoints
|
||||||
|
serviceMux.HandleFunc("/ping", m.WithMethods(p.PingPage, "GET"))
|
||||||
|
serviceMux.HandleFunc("/robots.txt", m.WithMethods(p.RobotsTxt, "GET"))
|
||||||
|
// Identity Provider (IdP) endpoints and callbacks
|
||||||
|
serviceMux.HandleFunc("/start", m.WithMethods(p.OAuthStart, "GET"))
|
||||||
|
serviceMux.HandleFunc("/oauth2/callback", m.WithMethods(p.OAuthCallback, "GET"))
|
||||||
|
// authenticator-server endpoints, todo(bdd): make gRPC
|
||||||
|
serviceMux.HandleFunc("/sign_in", m.WithMethods(m.ValidateClientID(p.validateSignature(p.SignIn), p.ProxyClientID), "GET"))
|
||||||
|
serviceMux.HandleFunc("/sign_out", m.WithMethods(p.validateSignature(p.SignOut), "GET", "POST"))
|
||||||
|
serviceMux.HandleFunc("/profile", m.WithMethods(p.validateExisting(p.GetProfile), "GET"))
|
||||||
|
serviceMux.HandleFunc("/validate", m.WithMethods(p.validateExisting(p.ValidateToken), "GET"))
|
||||||
|
serviceMux.HandleFunc("/redeem", m.WithMethods(p.validateExisting(p.Redeem), "POST"))
|
||||||
|
serviceMux.HandleFunc("/refresh", m.WithMethods(p.validateExisting(p.Refresh), "POST"))
|
||||||
|
|
||||||
|
// NOTE: we have to include trailing slash for the router to match the host header
|
||||||
|
host := p.Host
|
||||||
|
if !strings.HasSuffix(host, "/") {
|
||||||
|
host = fmt.Sprintf("%s/", host)
|
||||||
|
}
|
||||||
|
mux.Handle(host, serviceMux) // setup our service mux to only handle our required host header
|
||||||
|
|
||||||
|
return m.SetHeaders(mux, securityHeaders)
|
||||||
|
}
|
||||||
|
|
||||||
|
// validateSignature wraps a common collection of middlewares to validate signatures
|
||||||
|
func (p *Authenticator) validateSignature(f http.HandlerFunc) http.HandlerFunc {
|
||||||
|
return validateRedirectURI(validateSignature(f, p.ProxyClientSecret), p.ProxyRootDomains)
|
||||||
|
|
||||||
|
}
|
||||||
|
|
||||||
|
// validateSignature wraps a common collection of middlewares to validate
|
||||||
|
// a (presumably) existing user session
|
||||||
|
func (p *Authenticator) validateExisting(f http.HandlerFunc) http.HandlerFunc {
|
||||||
|
return m.ValidateClientID(m.ValidateClientSecret(f, p.ProxyClientSecret), p.ProxyClientID)
|
||||||
|
}
|
||||||
|
|
||||||
|
// RobotsTxt handles the /robots.txt route.
|
||||||
|
func (p *Authenticator) RobotsTxt(rw http.ResponseWriter, req *http.Request) {
|
||||||
|
rw.WriteHeader(http.StatusOK)
|
||||||
|
fmt.Fprintf(rw, "User-agent: *\nDisallow: /")
|
||||||
|
}
|
||||||
|
|
||||||
|
// PingPage handles the /ping route
|
||||||
|
func (p *Authenticator) PingPage(rw http.ResponseWriter, req *http.Request) {
|
||||||
|
rw.WriteHeader(http.StatusOK)
|
||||||
|
fmt.Fprintf(rw, "OK")
|
||||||
|
}
|
||||||
|
|
||||||
|
// SignInPage directs the user to the sign in page
|
||||||
|
func (p *Authenticator) SignInPage(rw http.ResponseWriter, req *http.Request, code int) {
|
||||||
|
requestLog := log.WithRequest(req, "authenticate.SignInPage")
|
||||||
|
rw.WriteHeader(code)
|
||||||
|
redirectURL := p.redirectURL.ResolveReference(req.URL)
|
||||||
|
// validateRedirectURI middleware already ensures that this is a valid URL
|
||||||
|
destinationURL, _ := url.Parse(redirectURL.Query().Get("redirect_uri"))
|
||||||
|
t := struct {
|
||||||
|
ProviderName string
|
||||||
|
EmailDomains []string
|
||||||
|
Redirect string
|
||||||
|
Destination string
|
||||||
|
Version string
|
||||||
|
}{
|
||||||
|
ProviderName: p.provider.Data().ProviderName,
|
||||||
|
EmailDomains: p.EmailDomains,
|
||||||
|
Redirect: redirectURL.String(),
|
||||||
|
Destination: destinationURL.Host,
|
||||||
|
Version: version.FullVersion(),
|
||||||
|
}
|
||||||
|
requestLog.Info().
|
||||||
|
Str("ProviderName", p.provider.Data().ProviderName).
|
||||||
|
Str("Redirect", redirectURL.String()).
|
||||||
|
Str("Destination", destinationURL.Host).
|
||||||
|
Str("EmailDomains", strings.Join(p.EmailDomains, ", ")).
|
||||||
|
Msg("authenticate.SignInPage")
|
||||||
|
p.templates.ExecuteTemplate(rw, "sign_in.html", t)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (p *Authenticator) authenticate(rw http.ResponseWriter, req *http.Request) (*sessions.SessionState, error) {
|
||||||
|
requestLog := log.WithRequest(req, "authenticate.authenticate")
|
||||||
|
|
||||||
|
session, err := p.sessionStore.LoadSession(req)
|
||||||
|
if err != nil {
|
||||||
|
log.Error().Err(err).Msg("authenticate.authenticate")
|
||||||
|
p.sessionStore.ClearSession(rw, req)
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
// ensure sessions lifetime has not expired
|
||||||
|
if session.LifetimePeriodExpired() {
|
||||||
|
requestLog.Warn().Msg("lifetime expired")
|
||||||
|
p.sessionStore.ClearSession(rw, req)
|
||||||
|
return nil, sessions.ErrLifetimeExpired
|
||||||
|
}
|
||||||
|
// check if session refresh period is up
|
||||||
|
if session.RefreshPeriodExpired() {
|
||||||
|
ok, err := p.provider.RefreshSessionIfNeeded(session)
|
||||||
|
if err != nil {
|
||||||
|
requestLog.Error().Err(err).Msg("failed to refresh session")
|
||||||
|
p.sessionStore.ClearSession(rw, req)
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
if !ok {
|
||||||
|
requestLog.Error().Msg("user unauthorized after refresh")
|
||||||
|
p.sessionStore.ClearSession(rw, req)
|
||||||
|
return nil, httputil.ErrUserNotAuthorized
|
||||||
|
}
|
||||||
|
// update refresh'd session in cookie
|
||||||
|
err = p.sessionStore.SaveSession(rw, req, session)
|
||||||
|
if err != nil {
|
||||||
|
// We refreshed the session successfully, but failed to save it.
|
||||||
|
// This could be from failing to encode the session properly.
|
||||||
|
// But, we clear the session cookie and reject the request
|
||||||
|
requestLog.Error().Err(err).Msg("could not save refreshed session")
|
||||||
|
p.sessionStore.ClearSession(rw, req)
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
// The session has not exceeded it's lifetime or requires refresh
|
||||||
|
ok := p.provider.ValidateSessionState(session)
|
||||||
|
if !ok {
|
||||||
|
requestLog.Error().Msg("invalid session state")
|
||||||
|
p.sessionStore.ClearSession(rw, req)
|
||||||
|
return nil, httputil.ErrUserNotAuthorized
|
||||||
|
}
|
||||||
|
err = p.sessionStore.SaveSession(rw, req, session)
|
||||||
|
if err != nil {
|
||||||
|
requestLog.Error().Err(err).Msg("failed to save valid session")
|
||||||
|
p.sessionStore.ClearSession(rw, req)
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if !p.Validator(session.Email) {
|
||||||
|
requestLog.Error().Msg("invalid email user")
|
||||||
|
return nil, httputil.ErrUserNotAuthorized
|
||||||
|
}
|
||||||
|
|
||||||
|
return session, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// SignIn handles the /sign_in endpoint. It attempts to authenticate the user,
|
||||||
|
// and if the user is not authenticated, it renders a sign in page.
|
||||||
|
func (p *Authenticator) SignIn(rw http.ResponseWriter, req *http.Request) {
|
||||||
|
// We attempt to authenticate the user. If they cannot be authenticated, we render a sign-in
|
||||||
|
// page.
|
||||||
|
//
|
||||||
|
// If the user is authenticated, we redirect back to the proxy application
|
||||||
|
// at the `redirect_uri`, with a temporary token.
|
||||||
|
//
|
||||||
|
// TODO: It is possible for a user to visit this page without a redirect destination.
|
||||||
|
// Should we allow the user to authenticate? If not, what should be the proposed workflow?
|
||||||
|
|
||||||
|
session, err := p.authenticate(rw, req)
|
||||||
|
switch err {
|
||||||
|
case nil:
|
||||||
|
// User is authenticated, redirect back to the proxy application
|
||||||
|
// with the necessary state
|
||||||
|
p.ProxyOAuthRedirect(rw, req, session)
|
||||||
|
case http.ErrNoCookie:
|
||||||
|
log.Error().Err(err).Msg("authenticate.SignIn : err no cookie")
|
||||||
|
p.SignInPage(rw, req, http.StatusOK)
|
||||||
|
case sessions.ErrLifetimeExpired, sessions.ErrInvalidSession:
|
||||||
|
log.Error().Err(err).Msg("authenticate.SignIn : invalid cookie cookie")
|
||||||
|
p.sessionStore.ClearSession(rw, req)
|
||||||
|
p.SignInPage(rw, req, http.StatusOK)
|
||||||
|
default:
|
||||||
|
log.Error().Err(err).Msg("authenticate.SignIn : unknown error cookie")
|
||||||
|
httputil.ErrorResponse(rw, req, err.Error(), httputil.CodeForError(err))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// ProxyOAuthRedirect redirects the user back to sso proxy's redirection endpoint.
|
||||||
|
func (p *Authenticator) ProxyOAuthRedirect(rw http.ResponseWriter, req *http.Request, session *sessions.SessionState) {
|
||||||
|
// This workflow corresponds to Section 3.1.2 of the OAuth2 RFC.
|
||||||
|
// See https://tools.ietf.org/html/rfc6749#section-3.1.2 for more specific information.
|
||||||
|
//
|
||||||
|
// We redirect the user back to the proxy application's redirection endpoint; in the
|
||||||
|
// sso proxy, this is the `/oauth/callback` endpoint.
|
||||||
|
//
|
||||||
|
// We must provide the proxy with a temporary authorization code via the `code` parameter,
|
||||||
|
// which they can use to redeem an access token for subsequent API calls.
|
||||||
|
//
|
||||||
|
// We must also include the original `state` parameter received from the proxy application.
|
||||||
|
|
||||||
|
err := req.ParseForm()
|
||||||
|
if err != nil {
|
||||||
|
httputil.ErrorResponse(rw, req, err.Error(), http.StatusInternalServerError)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
state := req.Form.Get("state")
|
||||||
|
if state == "" {
|
||||||
|
httputil.ErrorResponse(rw, req, "no state parameter supplied", http.StatusForbidden)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
redirectURI := req.Form.Get("redirect_uri")
|
||||||
|
if redirectURI == "" {
|
||||||
|
httputil.ErrorResponse(rw, req, "no redirect_uri parameter supplied", http.StatusForbidden)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
redirectURL, err := url.Parse(redirectURI)
|
||||||
|
if err != nil {
|
||||||
|
httputil.ErrorResponse(rw, req, "malformed redirect_uri parameter passed", http.StatusBadRequest)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
encrypted, err := sessions.MarshalSession(session, p.cipher)
|
||||||
|
if err != nil {
|
||||||
|
httputil.ErrorResponse(rw, req, err.Error(), http.StatusInternalServerError)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
http.Redirect(rw, req, getAuthCodeRedirectURL(redirectURL, state, string(encrypted)), http.StatusFound)
|
||||||
|
}
|
||||||
|
|
||||||
|
func getAuthCodeRedirectURL(redirectURL *url.URL, state, authCode string) string {
|
||||||
|
u, _ := url.Parse(redirectURL.String())
|
||||||
|
params, _ := url.ParseQuery(u.RawQuery)
|
||||||
|
params.Set("code", authCode)
|
||||||
|
params.Set("state", state)
|
||||||
|
|
||||||
|
u.RawQuery = params.Encode()
|
||||||
|
|
||||||
|
if u.Scheme == "" {
|
||||||
|
u.Scheme = "https"
|
||||||
|
}
|
||||||
|
|
||||||
|
return u.String()
|
||||||
|
}
|
||||||
|
|
||||||
|
// SignOut signs the user out.
|
||||||
|
func (p *Authenticator) SignOut(rw http.ResponseWriter, req *http.Request) {
|
||||||
|
redirectURI := req.Form.Get("redirect_uri")
|
||||||
|
if req.Method == "GET" {
|
||||||
|
p.SignOutPage(rw, req, "")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
session, err := p.sessionStore.LoadSession(req)
|
||||||
|
switch err {
|
||||||
|
case nil:
|
||||||
|
break
|
||||||
|
case http.ErrNoCookie: // if there's no cookie in the session we can just redirect
|
||||||
|
http.Redirect(rw, req, redirectURI, http.StatusFound)
|
||||||
|
return
|
||||||
|
default:
|
||||||
|
// a different error, clear the session cookie and redirect
|
||||||
|
log.Error().Err(err).Msg("authenticate.SignOut : error loading cookie session")
|
||||||
|
p.sessionStore.ClearSession(rw, req)
|
||||||
|
http.Redirect(rw, req, redirectURI, http.StatusFound)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
err = p.provider.Revoke(session)
|
||||||
|
if err != nil {
|
||||||
|
log.Error().Err(err).Msg("authenticate.SignOut : error revoking session")
|
||||||
|
p.SignOutPage(rw, req, "An error occurred during sign out. Please try again.")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
p.sessionStore.ClearSession(rw, req)
|
||||||
|
http.Redirect(rw, req, redirectURI, http.StatusFound)
|
||||||
|
}
|
||||||
|
|
||||||
|
// SignOutPage renders a sign out page with a message
|
||||||
|
func (p *Authenticator) SignOutPage(rw http.ResponseWriter, req *http.Request, message string) {
|
||||||
|
// validateRedirectURI middleware already ensures that this is a valid URL
|
||||||
|
redirectURI := req.Form.Get("redirect_uri")
|
||||||
|
session, err := p.sessionStore.LoadSession(req)
|
||||||
|
if err != nil {
|
||||||
|
http.Redirect(rw, req, redirectURI, http.StatusFound)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
signature := req.Form.Get("sig")
|
||||||
|
timestamp := req.Form.Get("ts")
|
||||||
|
destinationURL, _ := url.Parse(redirectURI)
|
||||||
|
|
||||||
|
// An error message indicates that an internal server error occurred
|
||||||
|
if message != "" {
|
||||||
|
rw.WriteHeader(http.StatusInternalServerError)
|
||||||
|
}
|
||||||
|
|
||||||
|
t := struct {
|
||||||
|
Redirect string
|
||||||
|
Signature string
|
||||||
|
Timestamp string
|
||||||
|
Message string
|
||||||
|
Destination string
|
||||||
|
Email string
|
||||||
|
Version string
|
||||||
|
}{
|
||||||
|
Redirect: redirectURI,
|
||||||
|
Signature: signature,
|
||||||
|
Timestamp: timestamp,
|
||||||
|
Message: message,
|
||||||
|
Destination: destinationURL.Host,
|
||||||
|
Email: session.Email,
|
||||||
|
Version: version.FullVersion(),
|
||||||
|
}
|
||||||
|
p.templates.ExecuteTemplate(rw, "sign_out.html", t)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
// OAuthStart starts the authentication process by redirecting to the provider. It provides a
|
||||||
|
// `redirectURI`, allowing the provider to redirect back to the sso proxy after authentication.
|
||||||
|
func (p *Authenticator) OAuthStart(rw http.ResponseWriter, req *http.Request) {
|
||||||
|
|
||||||
|
nonce := fmt.Sprintf("%x", aead.GenerateKey())
|
||||||
|
p.csrfStore.SetCSRF(rw, req, nonce)
|
||||||
|
|
||||||
|
authRedirectURL, err := url.Parse(req.URL.Query().Get("redirect_uri"))
|
||||||
|
if err != nil || !validRedirectURI(authRedirectURL.String(), p.ProxyRootDomains) {
|
||||||
|
httputil.ErrorResponse(rw, req, "Invalid redirect parameter", http.StatusBadRequest)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
proxyRedirectURL, err := url.Parse(authRedirectURL.Query().Get("redirect_uri"))
|
||||||
|
if err != nil || !validRedirectURI(proxyRedirectURL.String(), p.ProxyRootDomains) {
|
||||||
|
httputil.ErrorResponse(rw, req, "Invalid redirect parameter", http.StatusBadRequest)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
proxyRedirectSig := authRedirectURL.Query().Get("sig")
|
||||||
|
ts := authRedirectURL.Query().Get("ts")
|
||||||
|
if !validSignature(proxyRedirectURL.String(), proxyRedirectSig, ts, p.ProxyClientSecret) {
|
||||||
|
httputil.ErrorResponse(rw, req, "Invalid redirect parameter", http.StatusBadRequest)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
state := base64.URLEncoding.EncodeToString([]byte(fmt.Sprintf("%v:%v", nonce, authRedirectURL.String())))
|
||||||
|
|
||||||
|
signInURL := p.provider.GetSignInURL(state)
|
||||||
|
|
||||||
|
http.Redirect(rw, req, signInURL, http.StatusFound)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (p *Authenticator) redeemCode(host, code string) (*sessions.SessionState, error) {
|
||||||
|
session, err := p.provider.Redeem(code)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
if session.Email == "" {
|
||||||
|
return nil, fmt.Errorf("no email included in session")
|
||||||
|
}
|
||||||
|
|
||||||
|
return session, nil
|
||||||
|
|
||||||
|
}
|
||||||
|
|
||||||
|
// getOAuthCallback completes the oauth cycle from an identity provider's callback
|
||||||
|
func (p *Authenticator) getOAuthCallback(rw http.ResponseWriter, req *http.Request) (string, error) {
|
||||||
|
requestLog := log.WithRequest(req, "authenticate.getOAuthCallback")
|
||||||
|
// finish the oauth cycle
|
||||||
|
err := req.ParseForm()
|
||||||
|
if err != nil {
|
||||||
|
return "", httputil.HTTPError{Code: http.StatusInternalServerError, Message: err.Error()}
|
||||||
|
}
|
||||||
|
errorString := req.Form.Get("error")
|
||||||
|
if errorString != "" {
|
||||||
|
return "", httputil.HTTPError{Code: http.StatusForbidden, Message: errorString}
|
||||||
|
}
|
||||||
|
code := req.Form.Get("code")
|
||||||
|
if code == "" {
|
||||||
|
return "", httputil.HTTPError{Code: http.StatusBadRequest, Message: "Missing Code"}
|
||||||
|
}
|
||||||
|
|
||||||
|
session, err := p.redeemCode(req.Host, code)
|
||||||
|
if err != nil {
|
||||||
|
requestLog.Error().Err(err).Msg("error redeeming authentication code")
|
||||||
|
return "", err
|
||||||
|
}
|
||||||
|
|
||||||
|
bytes, err := base64.URLEncoding.DecodeString(req.Form.Get("state"))
|
||||||
|
if err != nil {
|
||||||
|
return "", httputil.HTTPError{Code: http.StatusInternalServerError, Message: "Invalid State"}
|
||||||
|
}
|
||||||
|
s := strings.SplitN(string(bytes), ":", 2)
|
||||||
|
if len(s) != 2 {
|
||||||
|
return "", httputil.HTTPError{Code: http.StatusInternalServerError, Message: "Invalid State"}
|
||||||
|
}
|
||||||
|
nonce := s[0]
|
||||||
|
redirect := s[1]
|
||||||
|
c, err := p.csrfStore.GetCSRF(req)
|
||||||
|
if err != nil {
|
||||||
|
return "", httputil.HTTPError{Code: http.StatusForbidden, Message: "Missing CSRF token"}
|
||||||
|
}
|
||||||
|
p.csrfStore.ClearCSRF(rw, req)
|
||||||
|
if c.Value != nonce {
|
||||||
|
requestLog.Error().Err(err).Msg("csrf token mismatch")
|
||||||
|
return "", httputil.HTTPError{Code: http.StatusForbidden, Message: "csrf failed"}
|
||||||
|
}
|
||||||
|
|
||||||
|
if !validRedirectURI(redirect, p.ProxyRootDomains) {
|
||||||
|
return "", httputil.HTTPError{Code: http.StatusForbidden, Message: "Invalid Redirect URI"}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Set cookie, or deny: The authenticator validates the session email and group
|
||||||
|
// - for p.Validator see validator.go#newValidatorImpl for more info
|
||||||
|
// - for p.provider.ValidateGroup see providers/google.go#ValidateGroup for more info
|
||||||
|
if !p.Validator(session.Email) {
|
||||||
|
requestLog.Error().Err(err).Str("email", session.Email).Msg("invalid email permissions denied")
|
||||||
|
return "", httputil.HTTPError{Code: http.StatusForbidden, Message: "Invalid Account"}
|
||||||
|
}
|
||||||
|
requestLog.Info().Str("email", session.Email).Msg("authentication complete")
|
||||||
|
err = p.sessionStore.SaveSession(rw, req, session)
|
||||||
|
if err != nil {
|
||||||
|
requestLog.Error().Err(err).Msg("internal error")
|
||||||
|
return "", httputil.HTTPError{Code: http.StatusInternalServerError, Message: "Internal Error"}
|
||||||
|
}
|
||||||
|
return redirect, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// OAuthCallback handles the callback from the provider, and returns an error response if there is an error.
|
||||||
|
// If there is no error it will redirect to the redirect url.
|
||||||
|
func (p *Authenticator) OAuthCallback(rw http.ResponseWriter, req *http.Request) {
|
||||||
|
redirect, err := p.getOAuthCallback(rw, req)
|
||||||
|
switch h := err.(type) {
|
||||||
|
case nil:
|
||||||
|
break
|
||||||
|
case httputil.HTTPError:
|
||||||
|
httputil.ErrorResponse(rw, req, h.Message, h.Code)
|
||||||
|
return
|
||||||
|
default:
|
||||||
|
httputil.ErrorResponse(rw, req, "Internal Error", http.StatusInternalServerError)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
http.Redirect(rw, req, redirect, http.StatusFound)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Redeem has a signed access token, and provides the user information associated with the access token.
|
||||||
|
func (p *Authenticator) Redeem(rw http.ResponseWriter, req *http.Request) {
|
||||||
|
// The auth code is redeemed by the sso proxy for an access token, refresh token,
|
||||||
|
// expiration, and email.
|
||||||
|
requestLog := log.WithRequest(req, "authenticate.Redeem")
|
||||||
|
err := req.ParseForm()
|
||||||
|
if err != nil {
|
||||||
|
http.Error(rw, fmt.Sprintf("Bad Request: %s", err.Error()), http.StatusBadRequest)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
session, err := sessions.UnmarshalSession(req.Form.Get("code"), p.cipher)
|
||||||
|
if err != nil {
|
||||||
|
requestLog.Error().Err(err).Int("http-status", http.StatusUnauthorized).Msg("invalid auth code")
|
||||||
|
http.Error(rw, fmt.Sprintf("invalid auth code: %s", err.Error()), http.StatusUnauthorized)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
if session == nil {
|
||||||
|
requestLog.Error().Err(err).Int("http-status", http.StatusUnauthorized).Msg("invalid session")
|
||||||
|
http.Error(rw, fmt.Sprintf("invalid session: %s", err.Error()), http.StatusUnauthorized)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
if session != nil && (session.RefreshPeriodExpired() || session.LifetimePeriodExpired()) {
|
||||||
|
requestLog.Error().Msg("expired session")
|
||||||
|
p.sessionStore.ClearSession(rw, req)
|
||||||
|
http.Error(rw, fmt.Sprintf("expired session"), http.StatusUnauthorized)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
response := struct {
|
||||||
|
AccessToken string `json:"access_token"`
|
||||||
|
RefreshToken string `json:"refresh_token"`
|
||||||
|
IDToken string `json:"id_token"`
|
||||||
|
ExpiresIn int64 `json:"expires_in"`
|
||||||
|
Email string `json:"email"`
|
||||||
|
}{
|
||||||
|
AccessToken: session.AccessToken,
|
||||||
|
RefreshToken: session.RefreshToken,
|
||||||
|
IDToken: session.IDToken,
|
||||||
|
ExpiresIn: int64(session.RefreshDeadline.Sub(time.Now()).Seconds()),
|
||||||
|
Email: session.Email,
|
||||||
|
}
|
||||||
|
|
||||||
|
jsonBytes, err := json.Marshal(response)
|
||||||
|
if err != nil {
|
||||||
|
rw.WriteHeader(http.StatusInternalServerError)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
rw.Header().Set("GAP-Auth", session.Email)
|
||||||
|
rw.Header().Set("Content-Type", "application/json")
|
||||||
|
rw.Write(jsonBytes)
|
||||||
|
|
||||||
|
}
|
||||||
|
|
||||||
|
// Refresh takes a refresh token and returns a new access token
|
||||||
|
func (p *Authenticator) Refresh(rw http.ResponseWriter, req *http.Request) {
|
||||||
|
err := req.ParseForm()
|
||||||
|
if err != nil {
|
||||||
|
http.Error(rw, fmt.Sprintf("Bad Request: %s", err.Error()), http.StatusBadRequest)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
refreshToken := req.Form.Get("refresh_token")
|
||||||
|
if refreshToken == "" {
|
||||||
|
http.Error(rw, "Bad Request: No Refresh Token", http.StatusBadRequest)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
accessToken, expiresIn, err := p.provider.RefreshAccessToken(refreshToken)
|
||||||
|
if err != nil {
|
||||||
|
httputil.ErrorResponse(rw, req, err.Error(), httputil.CodeForError(err))
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
response := struct {
|
||||||
|
AccessToken string `json:"access_token"`
|
||||||
|
ExpiresIn int64 `json:"expires_in"`
|
||||||
|
}{
|
||||||
|
AccessToken: accessToken,
|
||||||
|
ExpiresIn: int64(expiresIn.Seconds()),
|
||||||
|
}
|
||||||
|
|
||||||
|
bytes, err := json.Marshal(response)
|
||||||
|
if err != nil {
|
||||||
|
rw.WriteHeader(http.StatusInternalServerError)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
rw.WriteHeader(http.StatusCreated)
|
||||||
|
rw.Header().Set("Content-Type", "application/json")
|
||||||
|
rw.Write(bytes)
|
||||||
|
}
|
||||||
|
|
||||||
|
// GetProfile gets a list of groups of which a user is a member.
|
||||||
|
func (p *Authenticator) GetProfile(rw http.ResponseWriter, req *http.Request) {
|
||||||
|
// The sso proxy sends the user's email to this endpoint to get a list of Google groups that
|
||||||
|
// the email is a member of. The proxy will compare these groups to the list of allowed
|
||||||
|
// groups for the upstream service the user is trying to access.
|
||||||
|
|
||||||
|
email := req.FormValue("email")
|
||||||
|
if email == "" {
|
||||||
|
http.Error(rw, "no email address included", http.StatusBadRequest)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
// groupsFormValue := req.FormValue("groups")
|
||||||
|
// allowedGroups := []string{}
|
||||||
|
// if groupsFormValue != "" {
|
||||||
|
// allowedGroups = strings.Split(groupsFormValue, ",")
|
||||||
|
// }
|
||||||
|
|
||||||
|
// groups, err := p.provider.ValidateGroupMembership(email, allowedGroups)
|
||||||
|
// if err != nil {
|
||||||
|
// log.Error().Err(err).Msg("authenticate.GetProfile : error retrieving groups")
|
||||||
|
// httputil.ErrorResponse(rw, req, err.Error(), httputil.CodeForError(err))
|
||||||
|
// return
|
||||||
|
// }
|
||||||
|
|
||||||
|
response := struct {
|
||||||
|
Email string `json:"email"`
|
||||||
|
}{
|
||||||
|
Email: email,
|
||||||
|
}
|
||||||
|
|
||||||
|
jsonBytes, err := json.Marshal(response)
|
||||||
|
if err != nil {
|
||||||
|
http.Error(rw, fmt.Sprintf("error marshaling response: %s", err.Error()), http.StatusInternalServerError)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
rw.Header().Set("GAP-Auth", email)
|
||||||
|
rw.Header().Set("Content-Type", "application/json")
|
||||||
|
rw.Write(jsonBytes)
|
||||||
|
}
|
||||||
|
|
||||||
|
// ValidateToken validates the X-Access-Token from the header and returns an error response
|
||||||
|
// if it's invalid
|
||||||
|
func (p *Authenticator) ValidateToken(rw http.ResponseWriter, req *http.Request) {
|
||||||
|
accessToken := req.Header.Get("X-Access-Token")
|
||||||
|
idToken := req.Header.Get("X-Id-Token")
|
||||||
|
|
||||||
|
if accessToken == "" {
|
||||||
|
rw.WriteHeader(http.StatusBadRequest)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
if idToken == "" {
|
||||||
|
rw.WriteHeader(http.StatusBadRequest)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
ok := p.provider.ValidateSessionState(&sessions.SessionState{
|
||||||
|
AccessToken: accessToken,
|
||||||
|
IDToken: idToken,
|
||||||
|
})
|
||||||
|
|
||||||
|
if !ok {
|
||||||
|
rw.WriteHeader(http.StatusUnauthorized)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
rw.WriteHeader(http.StatusOK)
|
||||||
|
return
|
||||||
|
}
|
98
authenticate/middleware.go
Normal file
98
authenticate/middleware.go
Normal file
|
@ -0,0 +1,98 @@
|
||||||
|
package authenticate // import "github.com/pomerium/pomerium/authenticate"
|
||||||
|
|
||||||
|
import (
|
||||||
|
"crypto/hmac"
|
||||||
|
"crypto/sha256"
|
||||||
|
"encoding/base64"
|
||||||
|
"fmt"
|
||||||
|
"net/http"
|
||||||
|
"net/url"
|
||||||
|
"strconv"
|
||||||
|
"strings"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"github.com/pomerium/pomerium/internal/httputil"
|
||||||
|
)
|
||||||
|
|
||||||
|
// validateRedirectURI checks the redirect uri in the query parameters and ensures that
|
||||||
|
// the url's domain is one in the list of proxy root domains.
|
||||||
|
func validateRedirectURI(f http.HandlerFunc, proxyRootDomains []string) http.HandlerFunc {
|
||||||
|
return func(rw http.ResponseWriter, req *http.Request) {
|
||||||
|
err := req.ParseForm()
|
||||||
|
if err != nil {
|
||||||
|
httputil.ErrorResponse(rw, req, err.Error(), http.StatusBadRequest)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
redirectURI := req.Form.Get("redirect_uri")
|
||||||
|
if !validRedirectURI(redirectURI, proxyRootDomains) {
|
||||||
|
httputil.ErrorResponse(rw, req, "Invalid redirect parameter", http.StatusBadRequest)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
f(rw, req)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func validRedirectURI(uri string, rootDomains []string) bool {
|
||||||
|
redirectURL, err := url.Parse(uri)
|
||||||
|
if uri == "" || err != nil || redirectURL.Host == "" {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
for _, domain := range rootDomains {
|
||||||
|
if strings.HasSuffix(redirectURL.Hostname(), domain) {
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
|
||||||
|
func validateSignature(f http.HandlerFunc, proxyClientSecret string) http.HandlerFunc {
|
||||||
|
return func(rw http.ResponseWriter, req *http.Request) {
|
||||||
|
err := req.ParseForm()
|
||||||
|
if err != nil {
|
||||||
|
httputil.ErrorResponse(rw, req, err.Error(), http.StatusBadRequest)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
redirectURI := req.Form.Get("redirect_uri")
|
||||||
|
sigVal := req.Form.Get("sig")
|
||||||
|
timestamp := req.Form.Get("ts")
|
||||||
|
if !validSignature(redirectURI, sigVal, timestamp, proxyClientSecret) {
|
||||||
|
httputil.ErrorResponse(rw, req, "Invalid redirect parameter", http.StatusBadRequest)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
f(rw, req)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func validSignature(redirectURI, sigVal, timestamp, secret string) bool {
|
||||||
|
if redirectURI == "" || sigVal == "" || timestamp == "" || secret == "" {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
_, err := url.Parse(redirectURI)
|
||||||
|
if err != nil {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
requestSig, err := base64.URLEncoding.DecodeString(sigVal)
|
||||||
|
if err != nil {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
i, err := strconv.ParseInt(timestamp, 10, 64)
|
||||||
|
if err != nil {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
tm := time.Unix(i, 0)
|
||||||
|
ttl := 5 * time.Minute
|
||||||
|
if time.Now().Sub(tm) > ttl {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
localSig := redirectURLSignature(redirectURI, tm, secret)
|
||||||
|
return hmac.Equal(requestSig, localSig)
|
||||||
|
}
|
||||||
|
|
||||||
|
func redirectURLSignature(rawRedirect string, timestamp time.Time, secret string) []byte {
|
||||||
|
h := hmac.New(sha256.New, []byte(secret))
|
||||||
|
h.Write([]byte(rawRedirect))
|
||||||
|
h.Write([]byte(fmt.Sprint(timestamp.Unix())))
|
||||||
|
return h.Sum(nil)
|
||||||
|
}
|
100
authenticate/providers/google.go
Normal file
100
authenticate/providers/google.go
Normal file
|
@ -0,0 +1,100 @@
|
||||||
|
package providers // import "github.com/pomerium/pomerium/internal/providers"
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"net/url"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
oidc "github.com/pomerium/go-oidc"
|
||||||
|
"github.com/pomerium/pomerium/authenticate/circuit"
|
||||||
|
"github.com/pomerium/pomerium/internal/httputil"
|
||||||
|
"github.com/pomerium/pomerium/internal/log"
|
||||||
|
"github.com/pomerium/pomerium/internal/sessions"
|
||||||
|
"github.com/pomerium/pomerium/internal/version"
|
||||||
|
"golang.org/x/oauth2"
|
||||||
|
)
|
||||||
|
|
||||||
|
// GoogleProvider is an implementation of the Provider interface.
|
||||||
|
type GoogleProvider struct {
|
||||||
|
*ProviderData
|
||||||
|
cb *circuit.Breaker
|
||||||
|
// non-standard oidc fields
|
||||||
|
RevokeURL *url.URL
|
||||||
|
}
|
||||||
|
|
||||||
|
// NewGoogleProvider returns a new GoogleProvider and sets the provider url endpoints.
|
||||||
|
func NewGoogleProvider(p *ProviderData) (*GoogleProvider, error) {
|
||||||
|
ctx := context.Background()
|
||||||
|
provider, err := oidc.NewProvider(ctx, "https://accounts.google.com")
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
p.verifier = provider.Verifier(&oidc.Config{ClientID: p.ClientID})
|
||||||
|
p.oauth = &oauth2.Config{
|
||||||
|
ClientID: p.ClientID,
|
||||||
|
ClientSecret: p.ClientSecret,
|
||||||
|
Endpoint: provider.Endpoint(),
|
||||||
|
RedirectURL: p.RedirectURL.String(),
|
||||||
|
Scopes: []string{oidc.ScopeOpenID, "profile", "email"},
|
||||||
|
}
|
||||||
|
|
||||||
|
googleProvider := &GoogleProvider{
|
||||||
|
ProviderData: p,
|
||||||
|
}
|
||||||
|
// google supports a revokation endpoint
|
||||||
|
var claims struct {
|
||||||
|
RevokeURL string `json:"revocation_endpoint"`
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := provider.Claims(&claims); err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
googleProvider.RevokeURL, err = url.Parse(claims.RevokeURL)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
googleProvider.cb = circuit.NewBreaker(&circuit.Options{
|
||||||
|
HalfOpenConcurrentRequests: 2,
|
||||||
|
OnStateChange: googleProvider.cbStateChange,
|
||||||
|
OnBackoff: googleProvider.cbBackoff,
|
||||||
|
ShouldTripFunc: func(c circuit.Counts) bool { return c.ConsecutiveFailures >= 3 },
|
||||||
|
ShouldResetFunc: func(c circuit.Counts) bool { return c.ConsecutiveSuccesses >= 6 },
|
||||||
|
BackoffDurationFunc: circuit.ExponentialBackoffDuration(
|
||||||
|
time.Duration(200)*time.Second,
|
||||||
|
time.Duration(500)*time.Millisecond),
|
||||||
|
})
|
||||||
|
|
||||||
|
return googleProvider, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (p *GoogleProvider) cbBackoff(duration time.Duration, reset time.Time) {
|
||||||
|
log.Info().Dur("duration", duration).Msg("authenticate/providers/google.cbBackoff")
|
||||||
|
|
||||||
|
}
|
||||||
|
|
||||||
|
func (p *GoogleProvider) cbStateChange(from, to circuit.State) {
|
||||||
|
log.Info().Str("from", from.String()).Str("to", to.String()).Msg("authenticate/providers/google.cbStateChange")
|
||||||
|
}
|
||||||
|
|
||||||
|
// Revoke revokes the access token a given session state.
|
||||||
|
//
|
||||||
|
// https://developers.google.com/identity/protocols/OAuth2WebServer#tokenrevoke
|
||||||
|
// https://github.com/googleapis/google-api-dotnet-client/issues/1285
|
||||||
|
func (p *GoogleProvider) Revoke(s *sessions.SessionState) error {
|
||||||
|
params := url.Values{}
|
||||||
|
params.Add("token", s.AccessToken)
|
||||||
|
err := httputil.Client("POST", p.RevokeURL.String(), version.UserAgent(), params, nil)
|
||||||
|
if err != nil && err != httputil.ErrTokenRevoked {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// GetSignInURL returns the sign in url with typical oauth parameters
|
||||||
|
// Google requires access type offline
|
||||||
|
func (p *GoogleProvider) GetSignInURL(state string) string {
|
||||||
|
return p.oauth.AuthCodeURL(state, oauth2.AccessTypeOffline, oauth2.ApprovalForce)
|
||||||
|
}
|
32
authenticate/providers/oidc.go
Normal file
32
authenticate/providers/oidc.go
Normal file
|
@ -0,0 +1,32 @@
|
||||||
|
package providers // import "github.com/pomerium/pomerium/internal/providers"
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
|
||||||
|
oidc "github.com/pomerium/go-oidc"
|
||||||
|
"golang.org/x/oauth2"
|
||||||
|
)
|
||||||
|
|
||||||
|
// OIDCProvider provides a standard, OpenID Connect implementation
|
||||||
|
// of an authorization identity provider.
|
||||||
|
type OIDCProvider struct {
|
||||||
|
*ProviderData
|
||||||
|
}
|
||||||
|
|
||||||
|
// NewOIDCProvider creates a new instance of an OpenID Connect provider.
|
||||||
|
func NewOIDCProvider(p *ProviderData) (*OIDCProvider, error) {
|
||||||
|
ctx := context.Background()
|
||||||
|
provider, err := oidc.NewProvider(ctx, "https://accounts.google.com")
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
p.verifier = provider.Verifier(&oidc.Config{ClientID: p.ClientID})
|
||||||
|
p.oauth = &oauth2.Config{
|
||||||
|
ClientID: p.ClientID,
|
||||||
|
ClientSecret: p.ClientSecret,
|
||||||
|
Endpoint: provider.Endpoint(),
|
||||||
|
RedirectURL: p.RedirectURL.String(),
|
||||||
|
Scopes: []string{oidc.ScopeOpenID, "profile", "email"},
|
||||||
|
}
|
||||||
|
return &OIDCProvider{ProviderData: p}, nil
|
||||||
|
}
|
69
authenticate/providers/okta.go
Normal file
69
authenticate/providers/okta.go
Normal file
|
@ -0,0 +1,69 @@
|
||||||
|
package providers // import "github.com/pomerium/pomerium/internal/providers"
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"net/url"
|
||||||
|
|
||||||
|
oidc "github.com/pomerium/go-oidc"
|
||||||
|
"github.com/pomerium/pomerium/internal/httputil"
|
||||||
|
"github.com/pomerium/pomerium/internal/sessions"
|
||||||
|
"github.com/pomerium/pomerium/internal/version"
|
||||||
|
"golang.org/x/oauth2"
|
||||||
|
)
|
||||||
|
|
||||||
|
// OktaProvider provides a standard, OpenID Connect implementation
|
||||||
|
// of an authorization identity provider.
|
||||||
|
type OktaProvider struct {
|
||||||
|
*ProviderData
|
||||||
|
|
||||||
|
// non-standard oidc fields
|
||||||
|
RevokeURL *url.URL
|
||||||
|
}
|
||||||
|
|
||||||
|
// NewOktaProvider creates a new instance of an OpenID Connect provider.
|
||||||
|
func NewOktaProvider(p *ProviderData) (*OktaProvider, error) {
|
||||||
|
ctx := context.Background()
|
||||||
|
provider, err := oidc.NewProvider(ctx, "https://dev-108295.oktapreview.com/oauth2/default")
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
p.verifier = provider.Verifier(&oidc.Config{ClientID: p.ClientID})
|
||||||
|
p.oauth = &oauth2.Config{
|
||||||
|
ClientID: p.ClientID,
|
||||||
|
ClientSecret: p.ClientSecret,
|
||||||
|
Endpoint: provider.Endpoint(),
|
||||||
|
RedirectURL: p.RedirectURL.String(),
|
||||||
|
Scopes: []string{oidc.ScopeOpenID, "profile", "email"},
|
||||||
|
}
|
||||||
|
oktaProvider := OktaProvider{ProviderData: p}
|
||||||
|
|
||||||
|
// okta supports a revokation endpoint
|
||||||
|
var claims struct {
|
||||||
|
RevokeURL string `json:"revocation_endpoint"`
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := provider.Claims(&claims); err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
oktaProvider.RevokeURL, err = url.Parse(claims.RevokeURL)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
return &oktaProvider, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// Revoke revokes the access token a given session state.
|
||||||
|
// https://developer.okta.com/docs/api/resources/oidc#revoke
|
||||||
|
func (p *OktaProvider) Revoke(s *sessions.SessionState) error {
|
||||||
|
params := url.Values{}
|
||||||
|
params.Add("client_id", p.ClientID)
|
||||||
|
params.Add("client_secret", p.ClientSecret)
|
||||||
|
params.Add("token", s.IDToken)
|
||||||
|
params.Add("token_type_hint", "refresh_token")
|
||||||
|
err := httputil.Client("POST", p.RevokeURL.String(), version.UserAgent(), params, nil)
|
||||||
|
if err != nil && err != httputil.ErrTokenRevoked {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
256
authenticate/providers/providers.go
Normal file
256
authenticate/providers/providers.go
Normal file
|
@ -0,0 +1,256 @@
|
||||||
|
package providers // import "github.com/pomerium/pomerium/internal/providers"
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"errors"
|
||||||
|
"fmt"
|
||||||
|
"net/url"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
oidc "github.com/pomerium/go-oidc"
|
||||||
|
"github.com/pomerium/pomerium/internal/log"
|
||||||
|
"github.com/pomerium/pomerium/internal/sessions"
|
||||||
|
"golang.org/x/oauth2"
|
||||||
|
)
|
||||||
|
|
||||||
|
const (
|
||||||
|
// GoogleProviderName identifies the Google provider
|
||||||
|
GoogleProviderName = "google"
|
||||||
|
// OIDCProviderName identifes a generic OpenID connect provider
|
||||||
|
OIDCProviderName = "oidc"
|
||||||
|
// OktaProviderName identifes the Okta identity provider
|
||||||
|
OktaProviderName = "okta"
|
||||||
|
)
|
||||||
|
|
||||||
|
// Provider is an interface exposing functions necessary to authenticate with a given provider.
|
||||||
|
type Provider interface {
|
||||||
|
Data() *ProviderData
|
||||||
|
Redeem(string) (*sessions.SessionState, error)
|
||||||
|
ValidateSessionState(*sessions.SessionState) bool
|
||||||
|
GetSignInURL(state string) string
|
||||||
|
RefreshSessionIfNeeded(*sessions.SessionState) (bool, error)
|
||||||
|
Revoke(*sessions.SessionState) error
|
||||||
|
RefreshAccessToken(string) (string, time.Duration, error)
|
||||||
|
// Stop()
|
||||||
|
}
|
||||||
|
|
||||||
|
// New returns a new identity provider based on available name.
|
||||||
|
// Defaults to google.
|
||||||
|
func New(provider string, p *ProviderData) (Provider, error) {
|
||||||
|
switch provider {
|
||||||
|
case OIDCProviderName:
|
||||||
|
p, err := NewOIDCProvider(p)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
return p, nil
|
||||||
|
case OktaProviderName:
|
||||||
|
log.Info().Msg("Okta!")
|
||||||
|
p, err := NewOktaProvider(p)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
return p, nil
|
||||||
|
default:
|
||||||
|
p, err := NewGoogleProvider(p)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
return p, nil
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// ProviderData holds the fields associated with providers
|
||||||
|
// necessary to implement the Provider interface.
|
||||||
|
type ProviderData struct {
|
||||||
|
RedirectURL *url.URL
|
||||||
|
ProviderName string
|
||||||
|
ClientID string
|
||||||
|
ClientSecret string
|
||||||
|
ProviderURL *url.URL
|
||||||
|
Scopes []string
|
||||||
|
ApprovalPrompt string
|
||||||
|
SessionLifetimeTTL time.Duration
|
||||||
|
|
||||||
|
verifier *oidc.IDTokenVerifier
|
||||||
|
oauth *oauth2.Config
|
||||||
|
}
|
||||||
|
|
||||||
|
// Data returns a ProviderData.
|
||||||
|
func (p *ProviderData) Data() *ProviderData { return p }
|
||||||
|
|
||||||
|
// GetSignInURL returns the sign in url with typical oauth parameters
|
||||||
|
func (p *ProviderData) GetSignInURL(state string) string {
|
||||||
|
return p.oauth.AuthCodeURL(state)
|
||||||
|
}
|
||||||
|
|
||||||
|
// ValidateSessionState validates a given session's from it's JWT token
|
||||||
|
// The function verifies it's been signed by the provider, preforms
|
||||||
|
// any additional checks depending on the Config, and returns the payload.
|
||||||
|
//
|
||||||
|
// ValidateSessionState does NOT do nonce validation.
|
||||||
|
func (p *ProviderData) ValidateSessionState(s *sessions.SessionState) bool {
|
||||||
|
ctx := context.Background()
|
||||||
|
_, err := p.verifier.Verify(ctx, s.IDToken)
|
||||||
|
if err != nil {
|
||||||
|
log.Error().Err(err).Msg("authenticate/providers.ValidateSessionState : failed to verify session state")
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
|
||||||
|
// Redeem creates a session with an identity provider from a authorization code
|
||||||
|
func (p *ProviderData) Redeem(code string) (*sessions.SessionState, error) {
|
||||||
|
ctx := context.Background()
|
||||||
|
// convert authorization code into a token
|
||||||
|
token, err := p.oauth.Exchange(ctx, code)
|
||||||
|
if err != nil {
|
||||||
|
log.Error().Err(err).Msg("authenticate/providers.Redeem : token exchange failed")
|
||||||
|
return nil, fmt.Errorf("token exchange: %v", err)
|
||||||
|
}
|
||||||
|
s, err := p.createSessionState(ctx, token)
|
||||||
|
if err != nil {
|
||||||
|
log.Error().Err(err).Msg("authenticate/providers.Redeem : unable to update session")
|
||||||
|
return nil, fmt.Errorf("unable to update session: %v", err)
|
||||||
|
}
|
||||||
|
return s, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// RefreshSessionIfNeeded will refresh the session state if it's deadline is expired
|
||||||
|
func (p *ProviderData) RefreshSessionIfNeeded(s *sessions.SessionState) (bool, error) {
|
||||||
|
if !sessionRefreshRequired(s) {
|
||||||
|
log.Info().Msg("authenticate/providers.RefreshSessionIfNeeded : session refresh not needed")
|
||||||
|
return false, nil
|
||||||
|
}
|
||||||
|
origExpiration := s.RefreshDeadline
|
||||||
|
err := p.redeemRefreshToken(s)
|
||||||
|
if err != nil {
|
||||||
|
log.Error().Err(err).Msg("authenticate/providers.RefreshSession")
|
||||||
|
return false, fmt.Errorf("unable to redeem refresh token: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
log.Info().Msgf("authenticate/providers.Redeem refreshed id token %s (expired on %s)", s, origExpiration)
|
||||||
|
return true, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (p *ProviderData) redeemRefreshToken(s *sessions.SessionState) error {
|
||||||
|
log.Info().Msg("authenticate/providers.oidc.redeemRefreshToken 1")
|
||||||
|
ctx := context.Background()
|
||||||
|
t := &oauth2.Token{
|
||||||
|
RefreshToken: s.RefreshToken,
|
||||||
|
Expiry: time.Now().Add(-time.Hour),
|
||||||
|
}
|
||||||
|
log.Info().Msg("authenticate/providers.oidc.redeemRefreshToken 3")
|
||||||
|
|
||||||
|
// returns a TokenSource automatically refreshing it as necessary using the provided context
|
||||||
|
token, err := p.oauth.TokenSource(ctx, t).Token()
|
||||||
|
if err != nil {
|
||||||
|
log.Error().Err(err).Msg("authenticate/providers failed to get token")
|
||||||
|
return fmt.Errorf("failed to get token: %v", err)
|
||||||
|
}
|
||||||
|
log.Info().Msg("authenticate/providers.oidc.redeemRefreshToken 4")
|
||||||
|
|
||||||
|
newSession, err := p.createSessionState(ctx, token)
|
||||||
|
if err != nil {
|
||||||
|
log.Error().Err(err).Msg("authenticate/providers unable to update session")
|
||||||
|
return fmt.Errorf("unable to update session: %v", err)
|
||||||
|
}
|
||||||
|
s.AccessToken = newSession.AccessToken
|
||||||
|
s.IDToken = newSession.IDToken
|
||||||
|
s.RefreshToken = newSession.RefreshToken
|
||||||
|
s.RefreshDeadline = newSession.RefreshDeadline
|
||||||
|
s.Email = newSession.Email
|
||||||
|
|
||||||
|
log.Info().
|
||||||
|
Str("AccessToken", s.AccessToken).
|
||||||
|
Str("IdToken", s.IDToken).
|
||||||
|
Time("RefreshDeadline", s.RefreshDeadline).
|
||||||
|
Str("RefreshToken", s.RefreshToken).
|
||||||
|
Str("Email", s.Email).
|
||||||
|
Msg("authenticate/providers.redeemRefreshToken")
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (p *ProviderData) createSessionState(ctx context.Context, token *oauth2.Token) (*sessions.SessionState, error) {
|
||||||
|
rawIDToken, ok := token.Extra("id_token").(string)
|
||||||
|
if !ok {
|
||||||
|
return nil, fmt.Errorf("token response did not contain an id_token")
|
||||||
|
}
|
||||||
|
log.Info().
|
||||||
|
Bool("ctx", ctx == nil).
|
||||||
|
Bool("Verifier", p.verifier == nil).
|
||||||
|
Str("rawIDToken", rawIDToken).
|
||||||
|
Msg("authenticate/providers.oidc.createSessionState 2")
|
||||||
|
|
||||||
|
// Parse and verify ID Token payload.
|
||||||
|
idToken, err := p.verifier.Verify(ctx, rawIDToken)
|
||||||
|
if err != nil {
|
||||||
|
log.Error().Err(err).Msg("authenticate/providers could not verify id_token")
|
||||||
|
return nil, fmt.Errorf("could not verify id_token: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Extract custom claims.
|
||||||
|
var claims struct {
|
||||||
|
Email string `json:"email"`
|
||||||
|
Verified *bool `json:"email_verified"`
|
||||||
|
}
|
||||||
|
// parse claims from the raw, encoded jwt token
|
||||||
|
if err := idToken.Claims(&claims); err != nil {
|
||||||
|
return nil, fmt.Errorf("failed to parse id_token claims: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if claims.Email == "" {
|
||||||
|
return nil, fmt.Errorf("id_token did not contain an email")
|
||||||
|
}
|
||||||
|
if claims.Verified != nil && !*claims.Verified {
|
||||||
|
return nil, fmt.Errorf("email in id_token (%s) isn't verified", claims.Email)
|
||||||
|
}
|
||||||
|
|
||||||
|
return &sessions.SessionState{
|
||||||
|
AccessToken: token.AccessToken,
|
||||||
|
IDToken: rawIDToken,
|
||||||
|
RefreshToken: token.RefreshToken,
|
||||||
|
RefreshDeadline: token.Expiry,
|
||||||
|
LifetimeDeadline: token.Expiry,
|
||||||
|
Email: claims.Email,
|
||||||
|
}, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// RefreshAccessToken allows the service to refresh an access token without
|
||||||
|
// prompting the user for permission.
|
||||||
|
func (p *ProviderData) RefreshAccessToken(refreshToken string) (string, time.Duration, error) {
|
||||||
|
if refreshToken == "" {
|
||||||
|
return "", 0, errors.New("missing refresh token")
|
||||||
|
}
|
||||||
|
ctx := context.Background()
|
||||||
|
c := oauth2.Config{
|
||||||
|
ClientID: p.ClientID,
|
||||||
|
ClientSecret: p.ClientSecret,
|
||||||
|
Endpoint: oauth2.Endpoint{TokenURL: p.ProviderURL.String()},
|
||||||
|
}
|
||||||
|
t := oauth2.Token{RefreshToken: refreshToken}
|
||||||
|
ts := c.TokenSource(ctx, &t)
|
||||||
|
log.Info().
|
||||||
|
Str("RefreshToken", refreshToken).
|
||||||
|
Msg("authenticate/providers.RefreshAccessToken")
|
||||||
|
|
||||||
|
newToken, err := ts.Token()
|
||||||
|
if err != nil {
|
||||||
|
log.Error().Err(err).Msg("authenticate/providers.RefreshAccessToken")
|
||||||
|
return "", 0, err
|
||||||
|
}
|
||||||
|
return newToken.AccessToken, newToken.Expiry.Sub(time.Now()), nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// Revoke enables a user to revoke her tokenn. Though many providers such as
|
||||||
|
// google and okta provide revoke endpoints, since it's not officially supported
|
||||||
|
// as part of OpenID Connect, the default implementation throws an error.
|
||||||
|
func (p *ProviderData) Revoke(s *sessions.SessionState) error {
|
||||||
|
return errors.New("revoke not implemented")
|
||||||
|
}
|
||||||
|
|
||||||
|
func sessionRefreshRequired(s *sessions.SessionState) bool {
|
||||||
|
return s == nil || s.RefreshDeadline.After(time.Now()) || s.RefreshToken == ""
|
||||||
|
|
||||||
|
}
|
142
authenticate/providers/singleflight_middleware.go
Normal file
142
authenticate/providers/singleflight_middleware.go
Normal file
|
@ -0,0 +1,142 @@
|
||||||
|
package providers // import "github.com/pomerium/pomerium/internal/providers"
|
||||||
|
|
||||||
|
import (
|
||||||
|
"errors"
|
||||||
|
"fmt"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"github.com/pomerium/pomerium/internal/sessions"
|
||||||
|
"github.com/pomerium/pomerium/internal/singleflight"
|
||||||
|
)
|
||||||
|
|
||||||
|
var (
|
||||||
|
_ Provider = &SingleFlightProvider{}
|
||||||
|
)
|
||||||
|
|
||||||
|
// ErrUnexpectedReturnType is an error for an unexpected return type
|
||||||
|
var (
|
||||||
|
ErrUnexpectedReturnType = errors.New("received unexpected return type from single flight func call")
|
||||||
|
)
|
||||||
|
|
||||||
|
// SingleFlightProvider middleware provider that multiple requests for the same object
|
||||||
|
// to be processed as a single request. This is often called request collpasing or coalesce.
|
||||||
|
// This middleware leverages the golang singlelflight provider, with modifications for metrics.
|
||||||
|
//
|
||||||
|
// It's common among HTTP reverse proxy cache servers such as nginx, Squid or Varnish - they all call it something else but works similarly.
|
||||||
|
//
|
||||||
|
// * https://www.varnish-cache.org/docs/3.0/tutorial/handling_misbehaving_servers.html
|
||||||
|
// * http://nginx.org/en/docs/http/ngx_http_proxy_module.html#proxy_cache_lock
|
||||||
|
// * http://wiki.squid-cache.org/Features/CollapsedForwarding
|
||||||
|
type SingleFlightProvider struct {
|
||||||
|
provider Provider
|
||||||
|
|
||||||
|
single *singleflight.Group
|
||||||
|
}
|
||||||
|
|
||||||
|
// NewSingleFlightProvider returns a new SingleFlightProvider
|
||||||
|
func NewSingleFlightProvider(provider Provider) *SingleFlightProvider {
|
||||||
|
return &SingleFlightProvider{
|
||||||
|
provider: provider,
|
||||||
|
single: &singleflight.Group{},
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (p *SingleFlightProvider) do(endpoint, key string, fn func() (interface{}, error)) (interface{}, error) {
|
||||||
|
compositeKey := fmt.Sprintf("%s/%s", endpoint, key)
|
||||||
|
resp, _, err := p.single.Do(compositeKey, fn)
|
||||||
|
return resp, err
|
||||||
|
}
|
||||||
|
|
||||||
|
// Data returns the provider data
|
||||||
|
func (p *SingleFlightProvider) Data() *ProviderData {
|
||||||
|
return p.provider.Data()
|
||||||
|
}
|
||||||
|
|
||||||
|
// Redeem wraps the provider's Redeem function.
|
||||||
|
func (p *SingleFlightProvider) Redeem(code string) (*sessions.SessionState, error) {
|
||||||
|
return p.provider.Redeem(code)
|
||||||
|
}
|
||||||
|
|
||||||
|
// ValidateSessionState wraps the provider's ValidateSessionState in a single flight call.
|
||||||
|
func (p *SingleFlightProvider) ValidateSessionState(s *sessions.SessionState) bool {
|
||||||
|
response, err := p.do("ValidateSessionState", s.AccessToken, func() (interface{}, error) {
|
||||||
|
valid := p.provider.ValidateSessionState(s)
|
||||||
|
return valid, nil
|
||||||
|
})
|
||||||
|
if err != nil {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
|
||||||
|
valid, ok := response.(bool)
|
||||||
|
if !ok {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
|
||||||
|
return valid
|
||||||
|
}
|
||||||
|
|
||||||
|
// GetSignInURL calls the provider's GetSignInURL function.
|
||||||
|
func (p *SingleFlightProvider) GetSignInURL(finalRedirect string) string {
|
||||||
|
return p.provider.GetSignInURL(finalRedirect)
|
||||||
|
}
|
||||||
|
|
||||||
|
// RefreshSessionIfNeeded wraps the provider's RefreshSessionIfNeeded function in a single flight
|
||||||
|
// call.
|
||||||
|
func (p *SingleFlightProvider) RefreshSessionIfNeeded(s *sessions.SessionState) (bool, error) {
|
||||||
|
response, err := p.do("RefreshSessionIfNeeded", s.RefreshToken, func() (interface{}, error) {
|
||||||
|
return p.provider.RefreshSessionIfNeeded(s)
|
||||||
|
})
|
||||||
|
if err != nil {
|
||||||
|
return false, err
|
||||||
|
}
|
||||||
|
|
||||||
|
r, ok := response.(bool)
|
||||||
|
if !ok {
|
||||||
|
return false, ErrUnexpectedReturnType
|
||||||
|
}
|
||||||
|
|
||||||
|
return r, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// Revoke wraps the provider's Revoke function in a single flight call.
|
||||||
|
func (p *SingleFlightProvider) Revoke(s *sessions.SessionState) error {
|
||||||
|
_, err := p.do("Revoke", s.AccessToken, func() (interface{}, error) {
|
||||||
|
err := p.provider.Revoke(s)
|
||||||
|
return nil, err
|
||||||
|
})
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
// RefreshAccessToken wraps the provider's RefreshAccessToken function in a single flight call.
|
||||||
|
func (p *SingleFlightProvider) RefreshAccessToken(refreshToken string) (string, time.Duration, error) {
|
||||||
|
type Response struct {
|
||||||
|
AccessToken string
|
||||||
|
ExpiresIn time.Duration
|
||||||
|
}
|
||||||
|
response, err := p.do("RefreshAccessToken", refreshToken, func() (interface{}, error) {
|
||||||
|
accessToken, expiresIn, err := p.provider.RefreshAccessToken(refreshToken)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
return &Response{
|
||||||
|
AccessToken: accessToken,
|
||||||
|
ExpiresIn: expiresIn,
|
||||||
|
}, nil
|
||||||
|
})
|
||||||
|
if err != nil {
|
||||||
|
return "", 0, err
|
||||||
|
}
|
||||||
|
|
||||||
|
r, ok := response.(*Response)
|
||||||
|
if !ok {
|
||||||
|
return "", 0, ErrUnexpectedReturnType
|
||||||
|
}
|
||||||
|
|
||||||
|
return r.AccessToken, r.ExpiresIn, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// // Stop calls the provider's stop function
|
||||||
|
// func (p *SingleFlightProvider) Stop() {
|
||||||
|
// p.provider.Stop()
|
||||||
|
// }
|
81
authenticate/providers/test_provider.go
Normal file
81
authenticate/providers/test_provider.go
Normal file
|
@ -0,0 +1,81 @@
|
||||||
|
package providers // import "github.com/pomerium/pomerium/internal/providers"
|
||||||
|
|
||||||
|
import (
|
||||||
|
"net/url"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"github.com/pomerium/pomerium/internal/sessions"
|
||||||
|
)
|
||||||
|
|
||||||
|
// TestProvider is a test implementation of the Provider interface.
|
||||||
|
type TestProvider struct {
|
||||||
|
*ProviderData
|
||||||
|
|
||||||
|
ValidToken bool
|
||||||
|
ValidGroup bool
|
||||||
|
SignInURL string
|
||||||
|
Refresh bool
|
||||||
|
RefreshFunc func(string) (string, time.Duration, error)
|
||||||
|
RefreshError error
|
||||||
|
Session *sessions.SessionState
|
||||||
|
RedeemError error
|
||||||
|
RevokeError error
|
||||||
|
Groups []string
|
||||||
|
GroupsError error
|
||||||
|
GroupsCall int
|
||||||
|
}
|
||||||
|
|
||||||
|
// NewTestProvider creates a new mock test provider.
|
||||||
|
func NewTestProvider(providerURL *url.URL) *TestProvider {
|
||||||
|
return &TestProvider{
|
||||||
|
ProviderData: &ProviderData{
|
||||||
|
ProviderName: "Test Provider",
|
||||||
|
ProviderURL: &url.URL{
|
||||||
|
Scheme: "http",
|
||||||
|
Host: providerURL.Host,
|
||||||
|
Path: "/authorize",
|
||||||
|
},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// ValidateSessionState returns the mock provider's ValidToken field value.
|
||||||
|
func (tp *TestProvider) ValidateSessionState(*sessions.SessionState) bool {
|
||||||
|
return tp.ValidToken
|
||||||
|
}
|
||||||
|
|
||||||
|
// GetSignInURL returns the mock provider's SignInURL field value.
|
||||||
|
func (tp *TestProvider) GetSignInURL(finalRedirect string) string {
|
||||||
|
return tp.SignInURL
|
||||||
|
}
|
||||||
|
|
||||||
|
// RefreshSessionIfNeeded returns the mock provider's Refresh value, or an error.
|
||||||
|
func (tp *TestProvider) RefreshSessionIfNeeded(*sessions.SessionState) (bool, error) {
|
||||||
|
return tp.Refresh, tp.RefreshError
|
||||||
|
}
|
||||||
|
|
||||||
|
// RefreshAccessToken returns the mock provider's refresh access token information
|
||||||
|
func (tp *TestProvider) RefreshAccessToken(s string) (string, time.Duration, error) {
|
||||||
|
return tp.RefreshFunc(s)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Revoke returns nil
|
||||||
|
func (tp *TestProvider) Revoke(*sessions.SessionState) error {
|
||||||
|
return tp.RevokeError
|
||||||
|
}
|
||||||
|
|
||||||
|
// ValidateGroupMembership returns the mock provider's GroupsError if not nil, or the Groups field value.
|
||||||
|
func (tp *TestProvider) ValidateGroupMembership(string, []string) ([]string, error) {
|
||||||
|
return tp.Groups, tp.GroupsError
|
||||||
|
}
|
||||||
|
|
||||||
|
// Redeem returns the mock provider's Session and RedeemError field value.
|
||||||
|
func (tp *TestProvider) Redeem(code string) (*sessions.SessionState, error) {
|
||||||
|
return tp.Session, tp.RedeemError
|
||||||
|
|
||||||
|
}
|
||||||
|
|
||||||
|
// Stop fulfills the Provider interface
|
||||||
|
func (tp *TestProvider) Stop() {
|
||||||
|
return
|
||||||
|
}
|
12
authorize/README.md
Normal file
12
authorize/README.md
Normal file
|
@ -0,0 +1,12 @@
|
||||||
|
# Authorize
|
||||||
|
|
||||||
|
## What's this package do?
|
||||||
|
The authorize packages makes a binary determination of access.
|
||||||
|
|
||||||
|
Authorization is on trust from:
|
||||||
|
- Device state (vulnerability scanned?, MDM?, BYOD? Encrypted?)
|
||||||
|
- User standing (HR status, Groups, etc)
|
||||||
|
- Context (time, location, role)
|
||||||
|
Driven by:
|
||||||
|
- Dynamic "policy as code", fine grained policy
|
||||||
|
- Machine Learning & anomaly detection based on multiple input sources
|
71
cmd/pomerium/main.go
Normal file
71
cmd/pomerium/main.go
Normal file
|
@ -0,0 +1,71 @@
|
||||||
|
package main
|
||||||
|
|
||||||
|
import (
|
||||||
|
"flag"
|
||||||
|
"fmt"
|
||||||
|
"net/http"
|
||||||
|
"os"
|
||||||
|
|
||||||
|
"github.com/rs/zerolog"
|
||||||
|
|
||||||
|
"github.com/pomerium/pomerium/authenticate"
|
||||||
|
"github.com/pomerium/pomerium/internal/https"
|
||||||
|
"github.com/pomerium/pomerium/internal/log"
|
||||||
|
"github.com/pomerium/pomerium/internal/options"
|
||||||
|
"github.com/pomerium/pomerium/internal/version"
|
||||||
|
"github.com/pomerium/pomerium/proxy"
|
||||||
|
)
|
||||||
|
|
||||||
|
var (
|
||||||
|
debugFlag = flag.Bool("debug", false, "run server in debug mode, changes log output to STDOUT and level to info")
|
||||||
|
versionFlag = flag.Bool("version", false, "prints the version")
|
||||||
|
)
|
||||||
|
|
||||||
|
func main() {
|
||||||
|
flag.Parse()
|
||||||
|
if *debugFlag {
|
||||||
|
log.Logger = log.Output(zerolog.ConsoleWriter{Out: os.Stdout})
|
||||||
|
}
|
||||||
|
if *versionFlag {
|
||||||
|
fmt.Printf("%s", version.FullVersion())
|
||||||
|
os.Exit(0)
|
||||||
|
}
|
||||||
|
log.Info().Str("version", version.FullVersion()).Str("user-agent", version.UserAgent()).Msg("cmd/pomerium")
|
||||||
|
authOpts, err := authenticate.OptionsFromEnvConfig()
|
||||||
|
if err != nil {
|
||||||
|
log.Fatal().Err(err).Msg("cmd/pomerium : failed to parse authenticator settings")
|
||||||
|
}
|
||||||
|
emailValidator := func(p *authenticate.Authenticator) error {
|
||||||
|
p.Validator = options.NewEmailValidator(authOpts.EmailDomains)
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
authenticator, err := authenticate.NewAuthenticator(authOpts, emailValidator)
|
||||||
|
if err != nil {
|
||||||
|
log.Fatal().Err(err).Msg("cmd/pomerium : failed to create authenticator")
|
||||||
|
}
|
||||||
|
|
||||||
|
proxyOpts, err := proxy.OptionsFromEnvConfig()
|
||||||
|
if err != nil {
|
||||||
|
log.Fatal().Err(err).Msg("cmd/pomerium : failed to parse proxy settings")
|
||||||
|
}
|
||||||
|
|
||||||
|
validator := func(p *proxy.Proxy) error {
|
||||||
|
p.EmailValidator = options.NewEmailValidator(proxyOpts.EmailDomains)
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
p, err := proxy.NewProxy(proxyOpts, validator)
|
||||||
|
if err != nil {
|
||||||
|
log.Fatal().Err(err).Msg("cmd/pomerium : failed to create proxy")
|
||||||
|
}
|
||||||
|
|
||||||
|
// proxyHandler := log.NewLoggingHandler(p.Handler())
|
||||||
|
authHandler := http.TimeoutHandler(authenticator.Handler(), authOpts.RequestTimeout, "")
|
||||||
|
|
||||||
|
topMux := http.NewServeMux()
|
||||||
|
topMux.Handle(authOpts.Host+"/", authHandler)
|
||||||
|
topMux.Handle("/", p.Handler())
|
||||||
|
log.Fatal().Err(https.ListenAndServeTLS(nil, topMux))
|
||||||
|
|
||||||
|
}
|
BIN
docs/logo.png
Normal file
BIN
docs/logo.png
Normal file
Binary file not shown.
After Width: | Height: | Size: 1.8 KiB |
26
env.example
Normal file
26
env.example
Normal file
|
@ -0,0 +1,26 @@
|
||||||
|
#!/bin/bash
|
||||||
|
|
||||||
|
export HOST="sso-auth.corp.beyondperimeter.com"
|
||||||
|
export REDIRECT_URL="https://sso-auth.corp.beyondperimeter.com/oauth2/callback"
|
||||||
|
export PROXY_ROOT_DOMAIN=beyondperimeter.com
|
||||||
|
export PROXY_CLIENT_ID=WLgwUNIJW6DtsnAM2ck510znU2T3l+WufPg67e50oVM=
|
||||||
|
export PROXY_CLIENT_SECRET=gFB0qsSxxPqCtoNMuF7Q1VupJSNEq0BguxlUfT0PE+Y=
|
||||||
|
|
||||||
|
# Generate 256 bitrandom key to encrypt the cookie `head -c32 /dev/urandom | base64`
|
||||||
|
export AUTH_CODE_SECRET=9wiTZq4qvmS/plYQyvzGKWPlH/UBy0DMYMA2x/zngrM=
|
||||||
|
export COOKIE_SECRET=uPGHo1ujND/k3B9V6yr52Gweq3RRYfFho98jxDG5Br8=
|
||||||
|
export COOKIE_SECURE=true
|
||||||
|
|
||||||
|
# Valid email domains
|
||||||
|
export EMAIL_DOMAIN=*
|
||||||
|
export SSO_EMAIL_DOMAIN=*
|
||||||
|
|
||||||
|
# IdP configuration
|
||||||
|
export IDP_PROVIDER="google"
|
||||||
|
export IDP_PROVIDER_URL="https://sso-auth.corp.beyondperimeter.com"
|
||||||
|
export IDP_CLIENT_ID="xxx.apps.googleusercontent.com"
|
||||||
|
export IDP_CLIENT_SECRET="xxx"
|
||||||
|
export IDP_REDIRECT_URL="https://sso-auth.corp.beyondperimeter.com/oauth2/callback"
|
||||||
|
|
||||||
|
# proxy'd routes
|
||||||
|
export ROUTES='news.corp.beyondperimeter.com':'news.ycombinator.com','github.corp.beyondperimeter.com':'github.com'
|
19
go.mod
Normal file
19
go.mod
Normal file
|
@ -0,0 +1,19 @@
|
||||||
|
module github.com/pomerium/pomerium
|
||||||
|
|
||||||
|
require (
|
||||||
|
github.com/benbjohnson/clock v0.0.0-20161215174838-7dc76406b6d3
|
||||||
|
github.com/davecgh/go-spew v1.1.1 // indirect
|
||||||
|
github.com/miscreant/miscreant-go v0.0.0-20181010193435-325cbd69228b
|
||||||
|
github.com/pmezard/go-difflib v1.0.0 // indirect
|
||||||
|
github.com/pomerium/envconfig v1.3.1-0.20180517194557-dd1402a4d99d
|
||||||
|
github.com/pomerium/go-oidc v2.0.0+incompatible
|
||||||
|
github.com/pquerna/cachecontrol v0.0.0-20180517163645-1555304b9b35 // indirect
|
||||||
|
github.com/rs/zerolog v1.11.0
|
||||||
|
github.com/stretchr/testify v1.2.2 // indirect
|
||||||
|
golang.org/x/crypto v0.0.0-20181203042331-505ab145d0a9
|
||||||
|
golang.org/x/net v0.0.0-20181220203305-927f97764cc3 // indirect
|
||||||
|
golang.org/x/oauth2 v0.0.0-20181203162652-d668ce993890
|
||||||
|
golang.org/x/sync v0.0.0-20181221193216-37e7f081c4d4 // indirect
|
||||||
|
google.golang.org/appengine v1.4.0 // indirect
|
||||||
|
gopkg.in/square/go-jose.v2 v2.2.1 // indirect
|
||||||
|
)
|
34
go.sum
Normal file
34
go.sum
Normal file
|
@ -0,0 +1,34 @@
|
||||||
|
github.com/benbjohnson/clock v0.0.0-20161215174838-7dc76406b6d3 h1:wOysYcIdqv3WnvwqFFzrYCFALPED7qkUGaLXu359GSc=
|
||||||
|
github.com/benbjohnson/clock v0.0.0-20161215174838-7dc76406b6d3/go.mod h1:UMqtWQTnOe4byzwe7Zhwh8f8s+36uszN51sJrSIZlTE=
|
||||||
|
github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c=
|
||||||
|
github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
|
||||||
|
github.com/golang/protobuf v1.2.0 h1:P3YflyNX/ehuJFLhxviNdFxQPkGK5cDcApsge1SqnvM=
|
||||||
|
github.com/golang/protobuf v1.2.0/go.mod h1:6lQm79b+lXiMfvg/cZm0SGofjICqVBUtrP5yJMmIC1U=
|
||||||
|
github.com/miscreant/miscreant-go v0.0.0-20181010193435-325cbd69228b h1:VPhrxAgvd0d0xSZP4P8zzWZntso2NE0m69dmDEXM53I=
|
||||||
|
github.com/miscreant/miscreant-go v0.0.0-20181010193435-325cbd69228b/go.mod h1:Vj6lPE3LxPymcFxg7hm9aDIJWCyhJMnxSNC/y9ZHtN8=
|
||||||
|
github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM=
|
||||||
|
github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4=
|
||||||
|
github.com/pomerium/envconfig v1.3.1-0.20180517194557-dd1402a4d99d h1:iAavGhspmmDa8LCcQLpHV/JBPWKL7EfvMN7YrGHBN3o=
|
||||||
|
github.com/pomerium/envconfig v1.3.1-0.20180517194557-dd1402a4d99d/go.mod h1:1Kz8Ca8PhJDtLYqgvbDZGn6GsJCvrT52SxQ3sPNJkDc=
|
||||||
|
github.com/pomerium/go-oidc v2.0.0+incompatible h1:gVvG/ExWsHQqatV+uceROnGmbVYF44mDNx5nayBhC0o=
|
||||||
|
github.com/pomerium/go-oidc v2.0.0+incompatible/go.mod h1:DRsGVw6MOgxbfq4Y57jKOE8lbEfayxeiY0A8/4vxjBM=
|
||||||
|
github.com/pquerna/cachecontrol v0.0.0-20180517163645-1555304b9b35 h1:J9b7z+QKAmPf4YLrFg6oQUotqHQeUNWwkvo7jZp1GLU=
|
||||||
|
github.com/pquerna/cachecontrol v0.0.0-20180517163645-1555304b9b35/go.mod h1:prYjPmNq4d1NPVmpShWobRqXY3q7Vp+80DqgxxUrUIA=
|
||||||
|
github.com/rs/zerolog v1.11.0 h1:DRuq/S+4k52uJzBQciUcofXx45GrMC6yrEbb/CoK6+M=
|
||||||
|
github.com/rs/zerolog v1.11.0/go.mod h1:YbFCdg8HfsridGWAh22vktObvhZbQsZXe4/zB0OKkWU=
|
||||||
|
github.com/stretchr/testify v1.2.2 h1:bSDNvY7ZPG5RlJ8otE/7V6gMiyenm9RtJ7IUVIAoJ1w=
|
||||||
|
github.com/stretchr/testify v1.2.2/go.mod h1:a8OnRcib4nhh0OaRAV+Yts87kKdq0PP7pXfy6kDkUVs=
|
||||||
|
golang.org/x/crypto v0.0.0-20181203042331-505ab145d0a9 h1:mKdxBk7AujPs8kU4m80U72y/zjbZ3UcXC7dClwKbUI0=
|
||||||
|
golang.org/x/crypto v0.0.0-20181203042331-505ab145d0a9/go.mod h1:6SG95UA2DQfeDnfUPMdvaQW0Q7yPrPDi9nlGo2tz2b4=
|
||||||
|
golang.org/x/net v0.0.0-20180724234803-3673e40ba225/go.mod h1:mL1N/T3taQHkDXs73rZJwtUhF3w3ftmwwsq0BUmARs4=
|
||||||
|
golang.org/x/net v0.0.0-20181220203305-927f97764cc3 h1:eH6Eip3UpmR+yM/qI9Ijluzb1bNv/cAU/n+6l8tRSis=
|
||||||
|
golang.org/x/net v0.0.0-20181220203305-927f97764cc3/go.mod h1:mL1N/T3taQHkDXs73rZJwtUhF3w3ftmwwsq0BUmARs4=
|
||||||
|
golang.org/x/oauth2 v0.0.0-20181203162652-d668ce993890 h1:uESlIz09WIHT2I+pasSXcpLYqYK8wHcdCetU3VuMBJE=
|
||||||
|
golang.org/x/oauth2 v0.0.0-20181203162652-d668ce993890/go.mod h1:N/0e6XlmueqKjAGxoOufVs8QHGRruUQn6yWY3a++T0U=
|
||||||
|
golang.org/x/sync v0.0.0-20181221193216-37e7f081c4d4 h1:YUO/7uOKsKeq9UokNS62b8FYywz3ker1l1vDZRCRefw=
|
||||||
|
golang.org/x/sync v0.0.0-20181221193216-37e7f081c4d4/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM=
|
||||||
|
golang.org/x/text v0.3.0/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ=
|
||||||
|
google.golang.org/appengine v1.4.0 h1:/wp5JvzpHIxhs/dumFmF7BXTf3Z+dd4uXta4kVyO508=
|
||||||
|
google.golang.org/appengine v1.4.0/go.mod h1:xpcJRLb0r/rnEns0DIKYYv+WjYCduHsrkT7/EB5XEv4=
|
||||||
|
gopkg.in/square/go-jose.v2 v2.2.1 h1:uRIz/V7RfMsMgGnCp+YybIdstDIz8wc0H283wHQfwic=
|
||||||
|
gopkg.in/square/go-jose.v2 v2.2.1/go.mod h1:M9dMgbHiYLoDGQrXy7OpJDJWiKiU//h+vD76mk0e1AI=
|
172
internal/aead/aead.go
Normal file
172
internal/aead/aead.go
Normal file
|
@ -0,0 +1,172 @@
|
||||||
|
package aead // import "github.com/pomerium/pomerium/internal/aead"
|
||||||
|
|
||||||
|
import (
|
||||||
|
"bytes"
|
||||||
|
"compress/gzip"
|
||||||
|
"crypto/cipher"
|
||||||
|
"encoding/base64"
|
||||||
|
"encoding/json"
|
||||||
|
"fmt"
|
||||||
|
"io"
|
||||||
|
"sync"
|
||||||
|
|
||||||
|
miscreant "github.com/miscreant/miscreant-go"
|
||||||
|
)
|
||||||
|
|
||||||
|
const miscreantNonceSize = 16
|
||||||
|
|
||||||
|
var algorithmType = "AES-CMAC-SIV"
|
||||||
|
|
||||||
|
// Cipher provides methods to encrypt and decrypt values.
|
||||||
|
type Cipher interface {
|
||||||
|
Encrypt([]byte) ([]byte, error)
|
||||||
|
Decrypt([]byte) ([]byte, error)
|
||||||
|
Marshal(interface{}) (string, error)
|
||||||
|
Unmarshal(string, interface{}) error
|
||||||
|
}
|
||||||
|
|
||||||
|
// MiscreantCipher provides methods to encrypt and decrypt values.
|
||||||
|
// Using an AEAD is a cipher providing authenticated encryption with associated data.
|
||||||
|
// For a description of the methodology, see https://en.wikipedia.org/wiki/Authenticated_encryption
|
||||||
|
type MiscreantCipher struct {
|
||||||
|
aead cipher.AEAD
|
||||||
|
|
||||||
|
mux sync.Mutex
|
||||||
|
}
|
||||||
|
|
||||||
|
// NewMiscreantCipher returns a new AES Cipher for encrypting values
|
||||||
|
func NewMiscreantCipher(secret []byte) (*MiscreantCipher, error) {
|
||||||
|
aead, err := miscreant.NewAEAD(algorithmType, secret, miscreantNonceSize)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
return &MiscreantCipher{
|
||||||
|
aead: aead,
|
||||||
|
}, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// GenerateKey wraps miscreant's GenerateKey function
|
||||||
|
func GenerateKey() []byte {
|
||||||
|
return miscreant.GenerateKey(32)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Encrypt a value using AES-CMAC-SIV
|
||||||
|
func (c *MiscreantCipher) Encrypt(plaintext []byte) (joined []byte, err error) {
|
||||||
|
c.mux.Lock()
|
||||||
|
defer c.mux.Unlock()
|
||||||
|
|
||||||
|
defer func() {
|
||||||
|
if r := recover(); r != nil {
|
||||||
|
err = fmt.Errorf("miscreant error encrypting bytes: %v", r)
|
||||||
|
}
|
||||||
|
}()
|
||||||
|
nonce := miscreant.GenerateNonce(c.aead)
|
||||||
|
ciphertext := c.aead.Seal(nil, nonce, plaintext, nil)
|
||||||
|
|
||||||
|
// we return the nonce as part of the returned value
|
||||||
|
joined = append(ciphertext[:], nonce[:]...)
|
||||||
|
return joined, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// Decrypt a value using AES-CMAC-SIV
|
||||||
|
func (c *MiscreantCipher) Decrypt(joined []byte) ([]byte, error) {
|
||||||
|
c.mux.Lock()
|
||||||
|
defer c.mux.Unlock()
|
||||||
|
|
||||||
|
if len(joined) <= miscreantNonceSize {
|
||||||
|
return nil, fmt.Errorf("invalid input size: %d", len(joined))
|
||||||
|
}
|
||||||
|
// grab out the nonce
|
||||||
|
pivot := len(joined) - miscreantNonceSize
|
||||||
|
ciphertext := joined[:pivot]
|
||||||
|
nonce := joined[pivot:]
|
||||||
|
|
||||||
|
plaintext, err := c.aead.Open(nil, nonce, ciphertext, nil)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
return plaintext, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// Marshal marshals the interface state as JSON, encrypts the JSON using the cipher
|
||||||
|
// and base64 encodes the binary value as a string and returns the result
|
||||||
|
func (c *MiscreantCipher) Marshal(s interface{}) (string, error) {
|
||||||
|
// encode json value
|
||||||
|
plaintext, err := json.Marshal(s)
|
||||||
|
if err != nil {
|
||||||
|
return "", err
|
||||||
|
}
|
||||||
|
compressed, err := compress(plaintext)
|
||||||
|
if err != nil {
|
||||||
|
return "", err
|
||||||
|
}
|
||||||
|
// encrypt the JSON
|
||||||
|
ciphertext, err := c.Encrypt(compressed)
|
||||||
|
if err != nil {
|
||||||
|
return "", err
|
||||||
|
}
|
||||||
|
|
||||||
|
// base64-encode the result
|
||||||
|
encoded := base64.RawURLEncoding.EncodeToString(ciphertext)
|
||||||
|
return encoded, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// Unmarshal takes the marshaled string, base64-decodes into a byte slice, decrypts the
|
||||||
|
// byte slice the pased cipher, and unmarshals the resulting JSON into the struct pointer passed
|
||||||
|
func (c *MiscreantCipher) Unmarshal(value string, s interface{}) error {
|
||||||
|
// convert base64 string value to bytes
|
||||||
|
ciphertext, err := base64.RawURLEncoding.DecodeString(value)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
// decrypt the bytes
|
||||||
|
compressed, err := c.Decrypt(ciphertext)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
// decompress
|
||||||
|
plaintext, err := decompress(compressed)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
// unmarshal bytes
|
||||||
|
err = json.Unmarshal(plaintext, s)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func compress(data []byte) ([]byte, error) {
|
||||||
|
var buf bytes.Buffer
|
||||||
|
writer, err := gzip.NewWriterLevel(&buf, gzip.DefaultCompression)
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("aead/compress: failed to create a gzip writer: %q", err)
|
||||||
|
}
|
||||||
|
if writer == nil {
|
||||||
|
return nil, fmt.Errorf("aead/compress: failed to create a gzip writer")
|
||||||
|
}
|
||||||
|
if _, err = writer.Write(data); err != nil {
|
||||||
|
return nil, fmt.Errorf("aead/compress: failed to compress data with err: %q", err)
|
||||||
|
}
|
||||||
|
if err = writer.Close(); err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
return buf.Bytes(), nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func decompress(data []byte) ([]byte, error) {
|
||||||
|
reader, err := gzip.NewReader(bytes.NewReader(data))
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("aead/compress: failed to create a gzip reader: %q", err)
|
||||||
|
}
|
||||||
|
defer reader.Close()
|
||||||
|
var buf bytes.Buffer
|
||||||
|
if _, err = io.Copy(&buf, reader); err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
return buf.Bytes(), nil
|
||||||
|
}
|
173
internal/aead/aead_test.go
Normal file
173
internal/aead/aead_test.go
Normal file
|
@ -0,0 +1,173 @@
|
||||||
|
package aead // import "github.com/pomerium/pomerium/internal/aead"
|
||||||
|
|
||||||
|
import (
|
||||||
|
"crypto/rand"
|
||||||
|
"crypto/sha1"
|
||||||
|
"fmt"
|
||||||
|
"reflect"
|
||||||
|
"sync"
|
||||||
|
"testing"
|
||||||
|
)
|
||||||
|
|
||||||
|
func TestEncodeAndDecodeAccessToken(t *testing.T) {
|
||||||
|
plaintext := []byte("my plain text value")
|
||||||
|
|
||||||
|
key := GenerateKey()
|
||||||
|
c, err := NewMiscreantCipher([]byte(key))
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("unexpected err: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
ciphertext, err := c.Encrypt(plaintext)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("unexpected err: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if reflect.DeepEqual(plaintext, ciphertext) {
|
||||||
|
t.Fatalf("plaintext is not encrypted plaintext:%v ciphertext:%x", plaintext, ciphertext)
|
||||||
|
}
|
||||||
|
|
||||||
|
got, err := c.Decrypt(ciphertext)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("unexpected err decrypting: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if !reflect.DeepEqual(got, plaintext) {
|
||||||
|
t.Logf(" got: %v", got)
|
||||||
|
t.Logf("want: %v", plaintext)
|
||||||
|
t.Fatal("got unexpected decrypted value")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestMarshalAndUnmarshalStruct(t *testing.T) {
|
||||||
|
key := GenerateKey()
|
||||||
|
|
||||||
|
c, err := NewMiscreantCipher([]byte(key))
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("unexpected err: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
type TC struct {
|
||||||
|
Field string `json:"field"`
|
||||||
|
}
|
||||||
|
|
||||||
|
tc := &TC{
|
||||||
|
Field: "my plain text value",
|
||||||
|
}
|
||||||
|
|
||||||
|
value1, err := c.Marshal(tc)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("unexpected err: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
value2, err := c.Marshal(tc)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("unexpected err: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if value1 == value2 {
|
||||||
|
t.Fatalf("expected marshaled values to not be equal %v != %v", value1, value2)
|
||||||
|
}
|
||||||
|
|
||||||
|
got1 := &TC{}
|
||||||
|
err = c.Unmarshal(value1, got1)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("unexpected err unmarshalling struct: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if !reflect.DeepEqual(got1, tc) {
|
||||||
|
t.Logf("want: %#v", tc)
|
||||||
|
t.Logf(" got: %#v", got1)
|
||||||
|
t.Fatalf("expected structs to be equal")
|
||||||
|
}
|
||||||
|
|
||||||
|
got2 := &TC{}
|
||||||
|
err = c.Unmarshal(value2, got2)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("unexpected err unmarshalling struct: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if !reflect.DeepEqual(got1, got2) {
|
||||||
|
t.Logf("got2: %#v", got2)
|
||||||
|
t.Logf("got1: %#v", got1)
|
||||||
|
t.Fatalf("expected structs to be equal")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// TestCipherDataRace exercises a simple concurrency test for the MicscreantCipher.
|
||||||
|
// In https://github.com/pomerium/pomerium/pull/75 we investigated why, on random occasion,
|
||||||
|
// unmarshalling session states would fail, triggering users to get kicked out of
|
||||||
|
// authenticated states. We narrowed our investigation to a data race we uncovered
|
||||||
|
// from our misuse of the underlying miscreant library which makes no attempt
|
||||||
|
// at thread-safety.
|
||||||
|
//
|
||||||
|
// In https://github.com/pomerium/pomerium/pull/77 we added this test to exercise the
|
||||||
|
// data race condition and resolved said race by introducing a simple mutex.
|
||||||
|
func TestCipherDataRace(t *testing.T) {
|
||||||
|
miscreantCipher, err := NewMiscreantCipher(GenerateKey())
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("unexpected generating cipher err: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
type TC struct {
|
||||||
|
Field string `json:"field"`
|
||||||
|
}
|
||||||
|
|
||||||
|
wg := &sync.WaitGroup{}
|
||||||
|
for i := 0; i < 100; i++ {
|
||||||
|
wg.Add(1)
|
||||||
|
go func(c *MiscreantCipher, wg *sync.WaitGroup) {
|
||||||
|
defer wg.Done()
|
||||||
|
b := make([]byte, 32)
|
||||||
|
_, err := rand.Read(b)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("unexecpted error reading random bytes: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
sha := fmt.Sprintf("%x", sha1.New().Sum(b))
|
||||||
|
tc := &TC{
|
||||||
|
Field: sha,
|
||||||
|
}
|
||||||
|
|
||||||
|
value1, err := c.Marshal(tc)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("unexpected err: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
value2, err := c.Marshal(tc)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("unexpected err: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if value1 == value2 {
|
||||||
|
t.Fatalf("expected marshaled values to not be equal %v != %v", value1, value2)
|
||||||
|
}
|
||||||
|
|
||||||
|
got1 := &TC{}
|
||||||
|
err = c.Unmarshal(value1, got1)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("unexpected err unmarshalling struct: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if !reflect.DeepEqual(got1, tc) {
|
||||||
|
t.Logf("want: %#v", tc)
|
||||||
|
t.Logf(" got: %#v", got1)
|
||||||
|
t.Fatalf("expected structs to be equal")
|
||||||
|
}
|
||||||
|
|
||||||
|
got2 := &TC{}
|
||||||
|
err = c.Unmarshal(value2, got2)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("unexpected err unmarshalling struct: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if !reflect.DeepEqual(got1, got2) {
|
||||||
|
t.Logf("got2: %#v", got2)
|
||||||
|
t.Logf("got1: %#v", got1)
|
||||||
|
t.Fatalf("expected structs to be equal")
|
||||||
|
}
|
||||||
|
|
||||||
|
}(miscreantCipher, wg)
|
||||||
|
}
|
||||||
|
wg.Wait()
|
||||||
|
}
|
34
internal/aead/mock_cipher.go
Normal file
34
internal/aead/mock_cipher.go
Normal file
|
@ -0,0 +1,34 @@
|
||||||
|
package aead // import "github.com/pomerium/pomerium/internal/aead"
|
||||||
|
|
||||||
|
import (
|
||||||
|
"encoding/json"
|
||||||
|
)
|
||||||
|
|
||||||
|
// MockCipher is a mock of the cipher interface
|
||||||
|
type MockCipher struct {
|
||||||
|
MarshalError error
|
||||||
|
MarshalString string
|
||||||
|
UnmarshalError error
|
||||||
|
UnmarshalBytes []byte
|
||||||
|
}
|
||||||
|
|
||||||
|
// Encrypt returns an empty byte array and nil
|
||||||
|
func (mc *MockCipher) Encrypt([]byte) ([]byte, error) {
|
||||||
|
return []byte{}, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// Decrypt returns an empty byte array and nil
|
||||||
|
func (mc *MockCipher) Decrypt([]byte) ([]byte, error) {
|
||||||
|
return []byte{}, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// Marshal returns the marshal string and marsha error
|
||||||
|
func (mc *MockCipher) Marshal(interface{}) (string, error) {
|
||||||
|
return mc.MarshalString, mc.MarshalError
|
||||||
|
}
|
||||||
|
|
||||||
|
// Unmarshal unmarshals the unmarshal bytes to be set in s and returns the unmarshal error
|
||||||
|
func (mc *MockCipher) Unmarshal(b string, s interface{}) error {
|
||||||
|
json.Unmarshal(mc.UnmarshalBytes, s)
|
||||||
|
return mc.UnmarshalError
|
||||||
|
}
|
36
internal/cryptutil/README.md
Normal file
36
internal/cryptutil/README.md
Normal file
|
@ -0,0 +1,36 @@
|
||||||
|
|
||||||
|
## Generating random seeds
|
||||||
|
In order of preference:
|
||||||
|
- `head -c32 /dev/urandom | base64`
|
||||||
|
- `openssl rand -base64 32 | head -c 32 | base64`
|
||||||
|
## Encrypting data
|
||||||
|
|
||||||
|
TL;DR -- Nonce reuse is a problem. AEAD isn't a clear choice right now.
|
||||||
|
|
||||||
|
[Miscreant](https://github.com/miscreant/miscreant.go)
|
||||||
|
+ AES-GCM-SIV seems to have ideal properties
|
||||||
|
+ random nonces
|
||||||
|
- ~30% slower encryption
|
||||||
|
- [not maintained by a BigCo](https://github.com/miscreant/miscreant.go/graphs/contributors)
|
||||||
|
|
||||||
|
[nacl/secretbot](https://godoc.org/golang.org/x/crypto/nacl/secretbox)
|
||||||
|
+ Fast
|
||||||
|
+ XSalsa20 wutg Poly1305 MAC provides encryption and authentication together
|
||||||
|
+ A newer standard and may not be considered acceptable in environments that require high levels of review.
|
||||||
|
-/+ maintained as an [/x/ package](https://godoc.org/golang.org/x/crypto/nacl/secretbox)
|
||||||
|
- doesn't use the underlying cipher.AEAD api.
|
||||||
|
|
||||||
|
|
||||||
|
GCM with random nonces
|
||||||
|
+ Fastest
|
||||||
|
+ Go standard library, supported by google $
|
||||||
|
- Easy to get wrong
|
||||||
|
- IV reuse is a known weakness so keys must be rotated before birthday attack. [NIST SP 800-38D](http://csrc.nist.gov/publications/nistpubs/800-38D/SP-800-38D.pdf) recommends using the same key with random 96-bit nonces (the default nonce length) no more than 2^32 times
|
||||||
|
|
||||||
|
Further reading on tradeoffs:
|
||||||
|
- [Introducing Miscreant](https://tonyarcieri.com/introducing-miscreant-a-multi-language-misuse-resistant-encryption-library)
|
||||||
|
- [agl's post AES-GCM-SIV](https://www.imperialviolet.org/2017/05/14/aesgcmsiv.html)
|
||||||
|
- [x/crypto: add chacha20, xchacha20](https://github.com/golang/go/issues/24485s)
|
||||||
|
- [GCM cannot be used with random nonces](https://github.com/gtank/cryptopasta/issues/14s)
|
||||||
|
- [proposal: x/crypto/chacha20poly1305: add support for XChaCha20](https://github.com/golang/go/issues/23885)
|
||||||
|
- [kubernetes](https://kubernetes.io/docs/tasks/administer-cluster/encrypt-data/#providers)
|
30
internal/cryptutil/hash.go
Normal file
30
internal/cryptutil/hash.go
Normal file
|
@ -0,0 +1,30 @@
|
||||||
|
package cryptutil // import "github.com/pomerium/pomerium/internal/cryptutil"
|
||||||
|
|
||||||
|
import (
|
||||||
|
"crypto/hmac"
|
||||||
|
"crypto/sha512"
|
||||||
|
|
||||||
|
"golang.org/x/crypto/bcrypt"
|
||||||
|
)
|
||||||
|
|
||||||
|
// Hash generates a hash of data using HMAC-SHA-512/256. The tag is intended to
|
||||||
|
// be a natural-language string describing the purpose of the hash, such as
|
||||||
|
// "hash file for lookup key" or "master secret to client secret". It serves
|
||||||
|
// as an HMAC "key" and ensures that different purposes will have different
|
||||||
|
// hash output. This function is NOT suitable for hashing passwords.
|
||||||
|
func Hash(tag string, data []byte) []byte {
|
||||||
|
h := hmac.New(sha512.New512_256, []byte(tag))
|
||||||
|
h.Write(data)
|
||||||
|
return h.Sum(nil)
|
||||||
|
}
|
||||||
|
|
||||||
|
// HashPassword generates a bcrypt hash of the password using work factor 14.
|
||||||
|
func HashPassword(password []byte) ([]byte, error) {
|
||||||
|
return bcrypt.GenerateFromPassword(password, 14)
|
||||||
|
}
|
||||||
|
|
||||||
|
// CheckPasswordHash securely compares a bcrypt hashed password with its possible
|
||||||
|
// plaintext equivalent. Returns nil on success, or an error on failure.
|
||||||
|
func CheckPasswordHash(hash, password []byte) error {
|
||||||
|
return bcrypt.CompareHashAndPassword(hash, password)
|
||||||
|
}
|
80
internal/cryptutil/hash_test.go
Normal file
80
internal/cryptutil/hash_test.go
Normal file
|
@ -0,0 +1,80 @@
|
||||||
|
package cryptutil // import "github.com/pomerium/pomerium/internal/cryptutil"
|
||||||
|
|
||||||
|
import (
|
||||||
|
"crypto/sha256"
|
||||||
|
"crypto/sha512"
|
||||||
|
"encoding/hex"
|
||||||
|
"fmt"
|
||||||
|
"io/ioutil"
|
||||||
|
"os"
|
||||||
|
"testing"
|
||||||
|
)
|
||||||
|
|
||||||
|
func TestPasswordHashing(t *testing.T) {
|
||||||
|
bcryptTests := []struct {
|
||||||
|
plaintext []byte
|
||||||
|
hash []byte
|
||||||
|
}{
|
||||||
|
{
|
||||||
|
plaintext: []byte("password"),
|
||||||
|
hash: []byte("$2a$14$uALAQb/Lwl59oHVbuUa5m.xEFmQBc9ME/IiSgJK/VHtNJJXASCDoS"),
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, tt := range bcryptTests {
|
||||||
|
hashed, err := HashPassword(tt.plaintext)
|
||||||
|
if err != nil {
|
||||||
|
t.Error(err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if err = CheckPasswordHash(hashed, tt.plaintext); err != nil {
|
||||||
|
t.Error(err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Benchmarks SHA256 on 16K of random data.
|
||||||
|
func BenchmarkSHA256(b *testing.B) {
|
||||||
|
data, err := ioutil.ReadFile("testdata/random")
|
||||||
|
if err != nil {
|
||||||
|
b.Fatal(err)
|
||||||
|
}
|
||||||
|
b.SetBytes(int64(len(data)))
|
||||||
|
for i := 0; i < b.N; i++ {
|
||||||
|
_ = sha256.Sum256(data)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Benchmarks SHA512/256 on 16K of random data.
|
||||||
|
func BenchmarkSHA512_256(b *testing.B) {
|
||||||
|
data, err := ioutil.ReadFile("testdata/random")
|
||||||
|
if err != nil {
|
||||||
|
b.Fatal(err)
|
||||||
|
}
|
||||||
|
b.SetBytes(int64(len(data)))
|
||||||
|
for i := 0; i < b.N; i++ {
|
||||||
|
_ = sha512.Sum512_256(data)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func BenchmarkBcrypt(b *testing.B) {
|
||||||
|
for i := 0; i < b.N; i++ {
|
||||||
|
_, err := HashPassword([]byte("thisisareallybadpassword"))
|
||||||
|
if err != nil {
|
||||||
|
b.Error(err)
|
||||||
|
break
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func ExampleHash() {
|
||||||
|
tag := "hashing file for lookup key"
|
||||||
|
contents, err := ioutil.ReadFile("testdata/random")
|
||||||
|
if err != nil {
|
||||||
|
fmt.Printf("could not read file: %v\n", err)
|
||||||
|
os.Exit(1)
|
||||||
|
}
|
||||||
|
digest := Hash(tag, contents)
|
||||||
|
fmt.Println(hex.EncodeToString(digest))
|
||||||
|
// Output: 9f4c795d8ae5c207f19184ccebee6a606c1fdfe509c793614066d613580f03e1
|
||||||
|
}
|
BIN
internal/cryptutil/testdata/random
vendored
Normal file
BIN
internal/cryptutil/testdata/random
vendored
Normal file
Binary file not shown.
32
internal/fileutil/fileutil.go
Normal file
32
internal/fileutil/fileutil.go
Normal file
|
@ -0,0 +1,32 @@
|
||||||
|
package fileutil // import "github.com/pomerium/pomerium/internal/fileutil"
|
||||||
|
|
||||||
|
import (
|
||||||
|
"errors"
|
||||||
|
"os"
|
||||||
|
)
|
||||||
|
|
||||||
|
// IsReadableFile reports whether the file exists and is readable.
|
||||||
|
// If the error is non-nil, it means there might be a file or directory
|
||||||
|
// with that name but we cannot read it.
|
||||||
|
//
|
||||||
|
// Adapted from the upspin.io source code.
|
||||||
|
func IsReadableFile(path string) (bool, error) {
|
||||||
|
// Is it stattable and is it a plain file?
|
||||||
|
info, err := os.Stat(path)
|
||||||
|
if err != nil {
|
||||||
|
if os.IsNotExist(err) {
|
||||||
|
return false, nil // Item does not exist.
|
||||||
|
}
|
||||||
|
return false, err // Item is problematic.
|
||||||
|
}
|
||||||
|
if info.IsDir() {
|
||||||
|
return false, errors.New("is directory")
|
||||||
|
}
|
||||||
|
// Is it readable?
|
||||||
|
fd, err := os.Open(path)
|
||||||
|
if err != nil {
|
||||||
|
return false, errors.New("permission denied")
|
||||||
|
}
|
||||||
|
fd.Close()
|
||||||
|
return true, nil // Item exists and is readable.
|
||||||
|
}
|
29
internal/fileutil/fileutil_test.go
Normal file
29
internal/fileutil/fileutil_test.go
Normal file
|
@ -0,0 +1,29 @@
|
||||||
|
package fileutil // import "github.com/pomerium/pomerium/internal/fileutil"
|
||||||
|
|
||||||
|
import "testing"
|
||||||
|
|
||||||
|
func TestIsReadableFile(t *testing.T) {
|
||||||
|
|
||||||
|
tests := []struct {
|
||||||
|
name string
|
||||||
|
args string
|
||||||
|
want bool
|
||||||
|
wantErr bool
|
||||||
|
}{
|
||||||
|
{"good file", "fileutil.go", true, false},
|
||||||
|
{"file doesn't exist", "file-no-exist/nope", false, false},
|
||||||
|
{"can't read dir", "./", false, true},
|
||||||
|
}
|
||||||
|
for _, tt := range tests {
|
||||||
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
|
got, err := IsReadableFile(tt.args)
|
||||||
|
if (err != nil) != tt.wantErr {
|
||||||
|
t.Errorf("IsReadableFile() error = %+v, wantErr %v", err, tt.wantErr)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
if got != tt.want {
|
||||||
|
t.Errorf("IsReadableFile() = %v, want %v", got, tt.want)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
125
internal/https/https.go
Normal file
125
internal/https/https.go
Normal file
|
@ -0,0 +1,125 @@
|
||||||
|
package https // import "github.com/pomerium/pomerium/internal/https"
|
||||||
|
|
||||||
|
import (
|
||||||
|
"crypto/tls"
|
||||||
|
"fmt"
|
||||||
|
"net"
|
||||||
|
"net/http"
|
||||||
|
"os"
|
||||||
|
"path/filepath"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"github.com/pomerium/pomerium/internal/fileutil"
|
||||||
|
)
|
||||||
|
|
||||||
|
// Options contains the configurations settings for a TLS http server
|
||||||
|
type Options struct {
|
||||||
|
// Addr specifies the host and port on which the server should serve
|
||||||
|
// HTTPS requests. If empty, ":https" is used.
|
||||||
|
Addr string
|
||||||
|
|
||||||
|
// CertFile and KeyFile specifies the TLS certificates to use.
|
||||||
|
CertFile string
|
||||||
|
KeyFile string
|
||||||
|
}
|
||||||
|
|
||||||
|
var defaultOptions = &Options{
|
||||||
|
Addr: ":https",
|
||||||
|
CertFile: filepath.Join(findKeyDir(), "cert.pem"),
|
||||||
|
KeyFile: filepath.Join(findKeyDir(), "privkey.pem"),
|
||||||
|
}
|
||||||
|
|
||||||
|
func findKeyDir() string {
|
||||||
|
p, err := os.Getwd()
|
||||||
|
if err != nil {
|
||||||
|
return "."
|
||||||
|
}
|
||||||
|
return p
|
||||||
|
}
|
||||||
|
|
||||||
|
func (opt *Options) applyDefaults() {
|
||||||
|
if opt.Addr == "" {
|
||||||
|
opt.Addr = defaultOptions.Addr
|
||||||
|
}
|
||||||
|
if opt.CertFile == "" {
|
||||||
|
opt.CertFile = defaultOptions.CertFile
|
||||||
|
}
|
||||||
|
if opt.KeyFile == "" {
|
||||||
|
opt.KeyFile = defaultOptions.KeyFile
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// ListenAndServeTLS serves the provided handlers by HTTPS
|
||||||
|
// using the provided options.
|
||||||
|
func ListenAndServeTLS(opt *Options, handler http.Handler) error {
|
||||||
|
if opt == nil {
|
||||||
|
opt = defaultOptions
|
||||||
|
} else {
|
||||||
|
opt.applyDefaults()
|
||||||
|
}
|
||||||
|
|
||||||
|
config, err := newDefaultTLSConfig(opt.CertFile, opt.KeyFile)
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("https: setting up TLS config: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
ln, err := net.Listen("tcp", opt.Addr)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
ln = tls.NewListener(ln, config)
|
||||||
|
|
||||||
|
// Set up the main server.
|
||||||
|
server := &http.Server{
|
||||||
|
ReadHeaderTimeout: 5 * time.Second,
|
||||||
|
ReadTimeout: 15 * time.Second,
|
||||||
|
// WriteTimeout is set to 0 because it also pertains to
|
||||||
|
// streaming replies, e.g., the DirServer.Watch interface.
|
||||||
|
WriteTimeout: 0,
|
||||||
|
IdleTimeout: 60 * time.Second,
|
||||||
|
TLSConfig: config,
|
||||||
|
Handler: handler,
|
||||||
|
}
|
||||||
|
|
||||||
|
return server.Serve(ln)
|
||||||
|
}
|
||||||
|
|
||||||
|
// newDefaultTLSConfig creates a new TLS config based on the certificate files given.
|
||||||
|
func newDefaultTLSConfig(certFile string, certKeyFile string) (*tls.Config, error) {
|
||||||
|
certReadable, err := fileutil.IsReadableFile(certFile)
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("TLS certificate in %q: %q", certFile, err)
|
||||||
|
}
|
||||||
|
if !certReadable {
|
||||||
|
return nil, fmt.Errorf("certificate file %q not readable", certFile)
|
||||||
|
}
|
||||||
|
keyReadable, err := fileutil.IsReadableFile(certKeyFile)
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("TLS key in %q: %v", certKeyFile, err)
|
||||||
|
}
|
||||||
|
if !keyReadable {
|
||||||
|
return nil, fmt.Errorf("certificate key file %q not readable", certKeyFile)
|
||||||
|
}
|
||||||
|
|
||||||
|
cert, err := tls.LoadX509KeyPair(certFile, certKeyFile)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
tlsConfig := &tls.Config{
|
||||||
|
CipherSuites: []uint16{
|
||||||
|
tls.TLS_ECDHE_ECDSA_WITH_AES_256_GCM_SHA384,
|
||||||
|
tls.TLS_ECDHE_ECDSA_WITH_AES_128_GCM_SHA256,
|
||||||
|
tls.TLS_ECDHE_ECDSA_WITH_CHACHA20_POLY1305,
|
||||||
|
tls.TLS_ECDHE_RSA_WITH_AES_256_GCM_SHA384,
|
||||||
|
tls.TLS_ECDHE_RSA_WITH_AES_128_GCM_SHA256,
|
||||||
|
tls.TLS_ECDHE_RSA_WITH_CHACHA20_POLY1305,
|
||||||
|
},
|
||||||
|
MinVersion: tls.VersionTLS12,
|
||||||
|
PreferServerCipherSuites: true,
|
||||||
|
Certificates: []tls.Certificate{cert},
|
||||||
|
}
|
||||||
|
tlsConfig.BuildNameToCertificate()
|
||||||
|
return tlsConfig, nil
|
||||||
|
}
|
87
internal/httputil/client.go
Normal file
87
internal/httputil/client.go
Normal file
|
@ -0,0 +1,87 @@
|
||||||
|
package httputil // import "github.com/pomerium/pomerium/internal/httputil"
|
||||||
|
|
||||||
|
import (
|
||||||
|
"bytes"
|
||||||
|
"encoding/json"
|
||||||
|
"errors"
|
||||||
|
"fmt"
|
||||||
|
"io"
|
||||||
|
"io/ioutil"
|
||||||
|
"net"
|
||||||
|
"net/http"
|
||||||
|
"net/url"
|
||||||
|
"time"
|
||||||
|
)
|
||||||
|
|
||||||
|
// ErrTokenRevoked signifies a token revokation or expiration error
|
||||||
|
var ErrTokenRevoked = errors.New("Token expired or revoked")
|
||||||
|
|
||||||
|
var httpClient = &http.Client{
|
||||||
|
Timeout: time.Second * 5,
|
||||||
|
Transport: &http.Transport{
|
||||||
|
Dial: (&net.Dialer{
|
||||||
|
Timeout: 2 * time.Second,
|
||||||
|
}).Dial,
|
||||||
|
TLSHandshakeTimeout: 2 * time.Second,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
// Client provides a simple helper interface to make HTTP requests
|
||||||
|
func Client(method, endpoint, userAgent string, params url.Values, response interface{}) error {
|
||||||
|
var body io.Reader
|
||||||
|
switch method {
|
||||||
|
case "POST":
|
||||||
|
body = bytes.NewBufferString(params.Encode())
|
||||||
|
case "GET":
|
||||||
|
// error checking skipped because we are just parsing in
|
||||||
|
// order to make a copy of an existing URL
|
||||||
|
u, _ := url.Parse(endpoint)
|
||||||
|
u.RawQuery = params.Encode()
|
||||||
|
endpoint = u.String()
|
||||||
|
default:
|
||||||
|
return fmt.Errorf(http.StatusText(http.StatusBadRequest))
|
||||||
|
}
|
||||||
|
req, err := http.NewRequest(method, endpoint, body)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
req.Header.Set("Content-Type", "application/x-www-form-urlencoded")
|
||||||
|
req.Header.Set("User-Agent", userAgent)
|
||||||
|
|
||||||
|
resp, err := httpClient.Do(req)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
var respBody []byte
|
||||||
|
respBody, err = ioutil.ReadAll(resp.Body)
|
||||||
|
resp.Body.Close()
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
if resp.StatusCode != http.StatusOK {
|
||||||
|
switch resp.StatusCode {
|
||||||
|
case http.StatusBadRequest:
|
||||||
|
var response struct {
|
||||||
|
Error string `json:"error"`
|
||||||
|
ErrorDescription string `json:"error_description"`
|
||||||
|
}
|
||||||
|
e := json.Unmarshal(respBody, &response)
|
||||||
|
if e == nil && response.ErrorDescription == "Token expired or revoked" {
|
||||||
|
return ErrTokenRevoked
|
||||||
|
}
|
||||||
|
return fmt.Errorf(http.StatusText(http.StatusBadRequest))
|
||||||
|
default:
|
||||||
|
return fmt.Errorf(http.StatusText(resp.StatusCode))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if response != nil {
|
||||||
|
err := json.Unmarshal(respBody, &response)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
82
internal/httputil/errors.go
Normal file
82
internal/httputil/errors.go
Normal file
|
@ -0,0 +1,82 @@
|
||||||
|
package httputil // import "github.com/pomerium/pomerium/internal/httputil"
|
||||||
|
|
||||||
|
import (
|
||||||
|
"encoding/json"
|
||||||
|
"errors"
|
||||||
|
"fmt"
|
||||||
|
"io"
|
||||||
|
"net/http"
|
||||||
|
|
||||||
|
"github.com/pomerium/pomerium/internal/log"
|
||||||
|
"github.com/pomerium/pomerium/internal/templates"
|
||||||
|
"github.com/pomerium/pomerium/internal/version"
|
||||||
|
)
|
||||||
|
|
||||||
|
var (
|
||||||
|
// ErrUserNotAuthorized is an error for unauthorized users.
|
||||||
|
ErrUserNotAuthorized = errors.New("user not authorized")
|
||||||
|
)
|
||||||
|
|
||||||
|
// HTTPError stores the status code and a message for a given HTTP error.
|
||||||
|
type HTTPError struct {
|
||||||
|
Code int
|
||||||
|
Message string
|
||||||
|
}
|
||||||
|
|
||||||
|
// Error fulfills the error interface, returning a string representation of the error.
|
||||||
|
func (h HTTPError) Error() string {
|
||||||
|
return fmt.Sprintf("%d %s: %s", h.Code, http.StatusText(h.Code), h.Message)
|
||||||
|
}
|
||||||
|
|
||||||
|
// CodeForError maps an error type and returns a corresponding http.Status
|
||||||
|
func CodeForError(err error) int {
|
||||||
|
switch err {
|
||||||
|
case ErrTokenRevoked:
|
||||||
|
return http.StatusUnauthorized
|
||||||
|
}
|
||||||
|
return http.StatusInternalServerError
|
||||||
|
}
|
||||||
|
|
||||||
|
// ErrorResponse renders an error page for errors given a message and a status code.
|
||||||
|
func ErrorResponse(rw http.ResponseWriter, req *http.Request, message string, code int) {
|
||||||
|
if req.Header.Get("Accept") == "application/json" {
|
||||||
|
var response struct {
|
||||||
|
Error string `json:"error"`
|
||||||
|
}
|
||||||
|
response.Error = message
|
||||||
|
writeJSONResponse(rw, code, response)
|
||||||
|
} else {
|
||||||
|
title := http.StatusText(code)
|
||||||
|
|
||||||
|
log.Error().
|
||||||
|
Int("http-status", code).
|
||||||
|
Str("page-title", title).
|
||||||
|
Str("page-message", message).
|
||||||
|
Msg("authenticate/errors.ErrorResponse")
|
||||||
|
|
||||||
|
rw.WriteHeader(code)
|
||||||
|
t := struct {
|
||||||
|
Code int
|
||||||
|
Title string
|
||||||
|
Message string
|
||||||
|
Version string
|
||||||
|
}{
|
||||||
|
Code: code,
|
||||||
|
Title: title,
|
||||||
|
Message: message,
|
||||||
|
Version: version.FullVersion(),
|
||||||
|
}
|
||||||
|
templates.New().ExecuteTemplate(rw, "error.html", t)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// writeJSONResponse is a helper that sets the application/json header and writes a response.
|
||||||
|
func writeJSONResponse(rw http.ResponseWriter, code int, response interface{}) {
|
||||||
|
rw.Header().Set("Content-Type", "application/json")
|
||||||
|
rw.WriteHeader(code)
|
||||||
|
|
||||||
|
err := json.NewEncoder(rw).Encode(response)
|
||||||
|
if err != nil {
|
||||||
|
io.WriteString(rw, err.Error())
|
||||||
|
}
|
||||||
|
}
|
129
internal/log/log.go
Normal file
129
internal/log/log.go
Normal file
|
@ -0,0 +1,129 @@
|
||||||
|
// Package log provides a global logger for zerolog.
|
||||||
|
package log // import "github.com/pomerium/pomerium/internal/log"
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"io"
|
||||||
|
"net/http"
|
||||||
|
"os"
|
||||||
|
|
||||||
|
"github.com/rs/zerolog"
|
||||||
|
)
|
||||||
|
|
||||||
|
// Logger is the global logger.
|
||||||
|
var Logger = zerolog.New(os.Stderr).With().Timestamp().Logger()
|
||||||
|
|
||||||
|
// Output duplicates the global logger and sets w as its output.
|
||||||
|
func Output(w io.Writer) zerolog.Logger {
|
||||||
|
return Logger.Output(w)
|
||||||
|
}
|
||||||
|
|
||||||
|
// With creates a child logger with the field added to its context.
|
||||||
|
func With() zerolog.Context {
|
||||||
|
return Logger.With()
|
||||||
|
}
|
||||||
|
|
||||||
|
// WithRequest creates a child logger with the remote user added to its context.
|
||||||
|
func WithRequest(req *http.Request, function string) zerolog.Logger {
|
||||||
|
remoteUser := getRemoteAddr(req)
|
||||||
|
return Logger.With().
|
||||||
|
Str("function", function).
|
||||||
|
Str("req-remote-user", remoteUser).
|
||||||
|
Str("req-http-method", req.Method).
|
||||||
|
Str("req-host", req.Host).
|
||||||
|
Str("req-url", req.URL.String()).
|
||||||
|
Str("req-user-agent", req.Header.Get("User-Agent")).
|
||||||
|
Logger()
|
||||||
|
}
|
||||||
|
|
||||||
|
// Level creates a child logger with the minimum accepted level set to level.
|
||||||
|
func Level(level zerolog.Level) zerolog.Logger {
|
||||||
|
return Logger.Level(level)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Sample returns a logger with the s sampler.
|
||||||
|
func Sample(s zerolog.Sampler) zerolog.Logger {
|
||||||
|
return Logger.Sample(s)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Hook returns a logger with the h Hook.
|
||||||
|
func Hook(h zerolog.Hook) zerolog.Logger {
|
||||||
|
return Logger.Hook(h)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Debug starts a new message with debug level.
|
||||||
|
//
|
||||||
|
// You must call Msg on the returned event in order to send the event.
|
||||||
|
func Debug() *zerolog.Event {
|
||||||
|
return Logger.Debug()
|
||||||
|
}
|
||||||
|
|
||||||
|
// Info starts a new message with info level.
|
||||||
|
//
|
||||||
|
// You must call Msg on the returned event in order to send the event.
|
||||||
|
func Info() *zerolog.Event {
|
||||||
|
return Logger.Info()
|
||||||
|
}
|
||||||
|
|
||||||
|
// Warn starts a new message with warn level.
|
||||||
|
//
|
||||||
|
// You must call Msg on the returned event in order to send the event.
|
||||||
|
func Warn() *zerolog.Event {
|
||||||
|
return Logger.Warn()
|
||||||
|
}
|
||||||
|
|
||||||
|
// Error starts a new message with error level.
|
||||||
|
//
|
||||||
|
// You must call Msg on the returned event in order to send the event.
|
||||||
|
func Error() *zerolog.Event {
|
||||||
|
return Logger.Error()
|
||||||
|
}
|
||||||
|
|
||||||
|
// Fatal starts a new message with fatal level. The os.Exit(1) function
|
||||||
|
// is called by the Msg method.
|
||||||
|
//
|
||||||
|
// You must call Msg on the returned event in order to send the event.
|
||||||
|
func Fatal() *zerolog.Event {
|
||||||
|
return Logger.Fatal()
|
||||||
|
}
|
||||||
|
|
||||||
|
// Panic starts a new message with panic level. The message is also sent
|
||||||
|
// to the panic function.
|
||||||
|
//
|
||||||
|
// You must call Msg on the returned event in order to send the event.
|
||||||
|
func Panic() *zerolog.Event {
|
||||||
|
return Logger.Panic()
|
||||||
|
}
|
||||||
|
|
||||||
|
// WithLevel starts a new message with level.
|
||||||
|
//
|
||||||
|
// You must call Msg on the returned event in order to send the event.
|
||||||
|
func WithLevel(level zerolog.Level) *zerolog.Event {
|
||||||
|
return Logger.WithLevel(level)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Log starts a new message with no level. Setting zerolog.GlobalLevel to
|
||||||
|
// zerolog.Disabled will still disable events produced by this method.
|
||||||
|
//
|
||||||
|
// You must call Msg on the returned event in order to send the event.
|
||||||
|
func Log() *zerolog.Event {
|
||||||
|
return Logger.Log()
|
||||||
|
}
|
||||||
|
|
||||||
|
// Print sends a log event using debug level and no extra field.
|
||||||
|
// Arguments are handled in the manner of fmt.Print.
|
||||||
|
func Print(v ...interface{}) {
|
||||||
|
Logger.Print(v...)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Printf sends a log event using debug level and no extra field.
|
||||||
|
// Arguments are handled in the manner of fmt.Printf.
|
||||||
|
func Printf(format string, v ...interface{}) {
|
||||||
|
Logger.Printf(format, v...)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Ctx returns the Logger associated with the ctx. If no logger
|
||||||
|
// is associated, a disabled logger is returned.
|
||||||
|
func Ctx(ctx context.Context) *zerolog.Logger {
|
||||||
|
return zerolog.Ctx(ctx)
|
||||||
|
}
|
145
internal/log/request_logger.go
Normal file
145
internal/log/request_logger.go
Normal file
|
@ -0,0 +1,145 @@
|
||||||
|
package log // import "github.com/pomerium/pomerium/internal/log"
|
||||||
|
|
||||||
|
import (
|
||||||
|
"net/http"
|
||||||
|
"net/url"
|
||||||
|
"strings"
|
||||||
|
"time"
|
||||||
|
)
|
||||||
|
|
||||||
|
// Used to stash the authenticated user in the response for access when logging requests.
|
||||||
|
const loggingUserHeader = "SSO-Authenticated-User"
|
||||||
|
const gapMetaDataHeader = "GAP-Auth"
|
||||||
|
|
||||||
|
// responseLogger is wrapper of http.ResponseWriter that keeps track of its HTTP status
|
||||||
|
// code and body size
|
||||||
|
type responseLogger struct {
|
||||||
|
w http.ResponseWriter
|
||||||
|
status int
|
||||||
|
size int
|
||||||
|
proxyHost string
|
||||||
|
authInfo string
|
||||||
|
}
|
||||||
|
|
||||||
|
func (l *responseLogger) Header() http.Header {
|
||||||
|
return l.w.Header()
|
||||||
|
}
|
||||||
|
|
||||||
|
func (l *responseLogger) extractUser() {
|
||||||
|
authInfo := l.w.Header().Get(loggingUserHeader)
|
||||||
|
if authInfo != "" {
|
||||||
|
l.authInfo = authInfo
|
||||||
|
l.w.Header().Del(loggingUserHeader)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (l *responseLogger) ExtractGAPMetadata() {
|
||||||
|
authInfo := l.w.Header().Get(gapMetaDataHeader)
|
||||||
|
if authInfo != "" {
|
||||||
|
l.authInfo = authInfo
|
||||||
|
|
||||||
|
l.w.Header().Del(gapMetaDataHeader)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (l *responseLogger) Write(b []byte) (int, error) {
|
||||||
|
if l.status == 0 {
|
||||||
|
// The status will be StatusOK if WriteHeader has not been called yet
|
||||||
|
l.status = http.StatusOK
|
||||||
|
}
|
||||||
|
l.extractUser()
|
||||||
|
l.ExtractGAPMetadata()
|
||||||
|
|
||||||
|
size, err := l.w.Write(b)
|
||||||
|
l.size += size
|
||||||
|
return size, err
|
||||||
|
}
|
||||||
|
|
||||||
|
func (l *responseLogger) WriteHeader(s int) {
|
||||||
|
l.extractUser()
|
||||||
|
l.ExtractGAPMetadata()
|
||||||
|
|
||||||
|
l.w.WriteHeader(s)
|
||||||
|
l.status = s
|
||||||
|
}
|
||||||
|
|
||||||
|
func (l *responseLogger) Status() int {
|
||||||
|
return l.status
|
||||||
|
}
|
||||||
|
|
||||||
|
func (l *responseLogger) Size() int {
|
||||||
|
return l.size
|
||||||
|
}
|
||||||
|
|
||||||
|
func (l *responseLogger) Flush() {
|
||||||
|
f := l.w.(http.Flusher)
|
||||||
|
f.Flush()
|
||||||
|
}
|
||||||
|
|
||||||
|
// loggingHandler is the http.Handler implementation for LoggingHandlerTo and its friends
|
||||||
|
type loggingHandler struct {
|
||||||
|
handler http.Handler
|
||||||
|
}
|
||||||
|
|
||||||
|
// NewLoggingHandler returns a new loggingHandler that wraps a handler, and writer.
|
||||||
|
func NewLoggingHandler(h http.Handler) http.Handler {
|
||||||
|
return loggingHandler{
|
||||||
|
handler: h,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (h loggingHandler) ServeHTTP(w http.ResponseWriter, req *http.Request) {
|
||||||
|
t := time.Now()
|
||||||
|
url := *req.URL
|
||||||
|
logger := &responseLogger{w: w, proxyHost: getProxyHost(req)}
|
||||||
|
h.handler.ServeHTTP(logger, req)
|
||||||
|
requestDuration := time.Since(t)
|
||||||
|
|
||||||
|
logRequest(logger.proxyHost, logger.authInfo, req, url, requestDuration, logger.Status())
|
||||||
|
}
|
||||||
|
|
||||||
|
// logRequest logs information about a request
|
||||||
|
func logRequest(proxyHost, username string, req *http.Request, url url.URL, requestDuration time.Duration, status int) {
|
||||||
|
uri := req.Host + url.RequestURI()
|
||||||
|
Info().
|
||||||
|
Int("http-status", status).
|
||||||
|
Str("request-method", req.Method).
|
||||||
|
Str("request-uri", uri).
|
||||||
|
Str("proxy-host", proxyHost).
|
||||||
|
Str("user-agent", req.Header.Get("User-Agent")).
|
||||||
|
Str("remote-address", getRemoteAddr(req)).
|
||||||
|
Dur("duration", requestDuration).
|
||||||
|
Str("user", username).
|
||||||
|
Msg("request")
|
||||||
|
|
||||||
|
}
|
||||||
|
|
||||||
|
// getRemoteAddr returns the client IP address from a request. If present, the
|
||||||
|
// X-Forwarded-For header is assumed to be set by a load balancer, and its
|
||||||
|
// rightmost entry (the client IP that connected to the LB) is returned.
|
||||||
|
func getRemoteAddr(req *http.Request) string {
|
||||||
|
addr := req.RemoteAddr
|
||||||
|
forwardedHeader := req.Header.Get("X-Forwarded-For")
|
||||||
|
if forwardedHeader != "" {
|
||||||
|
forwardedList := strings.Split(forwardedHeader, ",")
|
||||||
|
forwardedAddr := strings.TrimSpace(forwardedList[len(forwardedList)-1])
|
||||||
|
if forwardedAddr != "" {
|
||||||
|
addr = forwardedAddr
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return addr
|
||||||
|
}
|
||||||
|
|
||||||
|
// getProxyHost attempts to get the proxy host from the redirect_uri parameter
|
||||||
|
func getProxyHost(req *http.Request) string {
|
||||||
|
err := req.ParseForm()
|
||||||
|
if err != nil {
|
||||||
|
return ""
|
||||||
|
}
|
||||||
|
redirect := req.Form.Get("redirect_uri")
|
||||||
|
redirectURL, err := url.Parse(redirect)
|
||||||
|
if err != nil {
|
||||||
|
return ""
|
||||||
|
}
|
||||||
|
return redirectURL.Host
|
||||||
|
}
|
72
internal/log/request_logger_test.go
Normal file
72
internal/log/request_logger_test.go
Normal file
|
@ -0,0 +1,72 @@
|
||||||
|
package log // import "github.com/pomerium/pomerium/internal/log"
|
||||||
|
|
||||||
|
import (
|
||||||
|
"net/http/httptest"
|
||||||
|
"testing"
|
||||||
|
)
|
||||||
|
|
||||||
|
func TestGetRemoteAddr(t *testing.T) {
|
||||||
|
testCases := []struct {
|
||||||
|
name string
|
||||||
|
remoteAddr string
|
||||||
|
forwardedHeader string
|
||||||
|
expectedAddr string
|
||||||
|
}{
|
||||||
|
{
|
||||||
|
name: "RemoteAddr used when no X-Forwarded-For header is given",
|
||||||
|
remoteAddr: "1.1.1.1",
|
||||||
|
expectedAddr: "1.1.1.1",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "RemoteAddr used when no X-Forwarded-For header is only whitespace",
|
||||||
|
remoteAddr: "1.1.1.1",
|
||||||
|
forwardedHeader: " ",
|
||||||
|
expectedAddr: "1.1.1.1",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "RemoteAddr used when no X-Forwarded-For header is only comma-separated whitespace",
|
||||||
|
remoteAddr: "1.1.1.1",
|
||||||
|
forwardedHeader: " , , ",
|
||||||
|
expectedAddr: "1.1.1.1",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "X-Forwarded-For header is preferred to RemoteAddr",
|
||||||
|
remoteAddr: "1.1.1.1",
|
||||||
|
forwardedHeader: "9.9.9.9",
|
||||||
|
expectedAddr: "9.9.9.9",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "rightmost entry in X-Forwarded-For header is used",
|
||||||
|
remoteAddr: "1.1.1.1",
|
||||||
|
forwardedHeader: "2.2.2.2, 3.3.3.3, 4.4.4.4.4, 5.5.5.5",
|
||||||
|
expectedAddr: "5.5.5.5",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "RemoteAddr is used if rightmost entry in X-Forwarded-For header is empty",
|
||||||
|
remoteAddr: "1.1.1.1",
|
||||||
|
forwardedHeader: "2.2.2.2, 3.3.3.3, ",
|
||||||
|
expectedAddr: "1.1.1.1",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "X-Forwaded-For header entries are stripped",
|
||||||
|
remoteAddr: "1.1.1.1",
|
||||||
|
forwardedHeader: " 2.2.2.2, 3.3.3.3, 4.4.4.4, 5.5.5.5 ",
|
||||||
|
expectedAddr: "5.5.5.5",
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, tc := range testCases {
|
||||||
|
t.Run(tc.name, func(t *testing.T) {
|
||||||
|
req := httptest.NewRequest("GET", "/", nil)
|
||||||
|
req.RemoteAddr = tc.remoteAddr
|
||||||
|
if tc.forwardedHeader != "" {
|
||||||
|
req.Header.Set("X-Forwarded-For", tc.forwardedHeader)
|
||||||
|
}
|
||||||
|
|
||||||
|
addr := getRemoteAddr(req)
|
||||||
|
if addr != tc.expectedAddr {
|
||||||
|
t.Errorf("expected remote addr = %q, got %q", tc.expectedAddr, addr)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
205
internal/middleware/middleware.go
Normal file
205
internal/middleware/middleware.go
Normal file
|
@ -0,0 +1,205 @@
|
||||||
|
package middleware // import "github.com/pomerium/pomerium/internal/middleware"
|
||||||
|
|
||||||
|
import (
|
||||||
|
"crypto/hmac"
|
||||||
|
"crypto/sha256"
|
||||||
|
"encoding/base64"
|
||||||
|
"fmt"
|
||||||
|
"net/http"
|
||||||
|
"net/url"
|
||||||
|
"strconv"
|
||||||
|
"strings"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"github.com/pomerium/pomerium/internal/httputil"
|
||||||
|
"github.com/pomerium/pomerium/internal/log"
|
||||||
|
)
|
||||||
|
|
||||||
|
// SetHeaders ensures that every response includes some basic security headers
|
||||||
|
func SetHeaders(h http.Handler, securityHeaders map[string]string) http.Handler {
|
||||||
|
return http.HandlerFunc(func(rw http.ResponseWriter, req *http.Request) {
|
||||||
|
for key, val := range securityHeaders {
|
||||||
|
rw.Header().Set(key, val)
|
||||||
|
}
|
||||||
|
h.ServeHTTP(rw, req)
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
// WithMethods writes an error response if the method of the request is not included.
|
||||||
|
func WithMethods(f http.HandlerFunc, methods ...string) http.HandlerFunc {
|
||||||
|
methodMap := make(map[string]struct{}, len(methods))
|
||||||
|
for _, m := range methods {
|
||||||
|
methodMap[m] = struct{}{}
|
||||||
|
}
|
||||||
|
return func(rw http.ResponseWriter, req *http.Request) {
|
||||||
|
if _, ok := methodMap[req.Method]; !ok {
|
||||||
|
httputil.ErrorResponse(rw, req, fmt.Sprintf("method %s not allowed", req.Method), http.StatusMethodNotAllowed)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
f(rw, req)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// ValidateClientID checks the request body or url for the client id and returns an error
|
||||||
|
// if it does not match the proxy client id
|
||||||
|
func ValidateClientID(f http.HandlerFunc, proxyClientID string) http.HandlerFunc {
|
||||||
|
return func(rw http.ResponseWriter, req *http.Request) {
|
||||||
|
// try to get the client id from the request body
|
||||||
|
err := req.ParseForm()
|
||||||
|
if err != nil {
|
||||||
|
httputil.ErrorResponse(rw, req, err.Error(), http.StatusInternalServerError)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
clientID := req.FormValue("client_id")
|
||||||
|
if clientID == "" {
|
||||||
|
// try to get the clientID from the query parameters
|
||||||
|
clientID = req.URL.Query().Get("client_id")
|
||||||
|
}
|
||||||
|
|
||||||
|
if clientID != proxyClientID {
|
||||||
|
httputil.ErrorResponse(rw, req, "Invalid client_id parameter", http.StatusUnauthorized)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
f(rw, req)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// ValidateClientSecret checks the request header for the client secret and returns
|
||||||
|
// an error if it does not match the proxy client secret
|
||||||
|
func ValidateClientSecret(f http.HandlerFunc, proxyClientSecret string) http.HandlerFunc {
|
||||||
|
return func(rw http.ResponseWriter, req *http.Request) {
|
||||||
|
err := req.ParseForm()
|
||||||
|
if err != nil {
|
||||||
|
httputil.ErrorResponse(rw, req, err.Error(), http.StatusInternalServerError)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
clientSecret := req.Form.Get("client_secret")
|
||||||
|
// check the request header for the client secret
|
||||||
|
if clientSecret == "" {
|
||||||
|
clientSecret = req.Header.Get("X-Client-Secret")
|
||||||
|
}
|
||||||
|
if clientSecret != proxyClientSecret {
|
||||||
|
log.Error().Str("clientSecret", clientSecret).Str("proxyClientSecret", proxyClientSecret).Msg("oh")
|
||||||
|
httputil.ErrorResponse(rw, req, "Invalid client secret", http.StatusUnauthorized)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
f(rw, req)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// ValidateRedirectURI checks the redirect uri in the query parameters and ensures that
|
||||||
|
// the url's domain is one in the list of proxy root domains.
|
||||||
|
func ValidateRedirectURI(f http.HandlerFunc, proxyRootDomains []string) http.HandlerFunc {
|
||||||
|
return func(rw http.ResponseWriter, req *http.Request) {
|
||||||
|
err := req.ParseForm()
|
||||||
|
if err != nil {
|
||||||
|
httputil.ErrorResponse(rw, req, err.Error(), http.StatusBadRequest)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
redirectURI := req.Form.Get("redirect_uri")
|
||||||
|
if !validRedirectURI(redirectURI, proxyRootDomains) {
|
||||||
|
httputil.ErrorResponse(rw, req, "Invalid redirect parameter", http.StatusBadRequest)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
f(rw, req)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func validRedirectURI(uri string, rootDomains []string) bool {
|
||||||
|
redirectURL, err := url.Parse(uri)
|
||||||
|
if uri == "" || err != nil || redirectURL.Host == "" {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
for _, domain := range rootDomains {
|
||||||
|
if strings.HasSuffix(redirectURL.Hostname(), domain) {
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
|
||||||
|
// ValidateSignature ensures the request is valid and has been signed with
|
||||||
|
// the correspdoning client secret key
|
||||||
|
func ValidateSignature(f http.HandlerFunc, proxyClientSecret string) http.HandlerFunc {
|
||||||
|
return func(rw http.ResponseWriter, req *http.Request) {
|
||||||
|
err := req.ParseForm()
|
||||||
|
if err != nil {
|
||||||
|
httputil.ErrorResponse(rw, req, err.Error(), http.StatusBadRequest)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
redirectURI := req.Form.Get("redirect_uri")
|
||||||
|
sigVal := req.Form.Get("sig")
|
||||||
|
timestamp := req.Form.Get("ts")
|
||||||
|
if !validSignature(redirectURI, sigVal, timestamp, proxyClientSecret) {
|
||||||
|
httputil.ErrorResponse(rw, req, "Invalid redirect parameter", http.StatusBadRequest)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
f(rw, req)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// ValidateHost ensures that each request's host is valid
|
||||||
|
func ValidateHost(h http.Handler, mux map[string]*http.Handler) http.Handler {
|
||||||
|
return http.HandlerFunc(func(rw http.ResponseWriter, req *http.Request) {
|
||||||
|
if _, ok := mux[req.Host]; !ok {
|
||||||
|
httputil.ErrorResponse(rw, req, "Unknown host to route", http.StatusNotFound)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
h.ServeHTTP(rw, req)
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
// RequireHTTPS reroutes a HTTP request to HTTPS
|
||||||
|
// todo(bdd) : this is unreliable unless behind another reverser proxy
|
||||||
|
func RequireHTTPS(h http.Handler) http.Handler {
|
||||||
|
return http.HandlerFunc(func(rw http.ResponseWriter, req *http.Request) {
|
||||||
|
rw.Header().Set("Strict-Transport-Security", "max-age=31536000")
|
||||||
|
// todo(bdd) : scheme and x-forwarded-proto cannot be trusted if not behind another load balancer
|
||||||
|
if (req.URL.Scheme == "http" && req.Header.Get("X-Forwarded-Proto") == "http") || &req.TLS == nil {
|
||||||
|
dest := &url.URL{
|
||||||
|
Scheme: "https",
|
||||||
|
Host: req.Host,
|
||||||
|
Path: req.URL.Path,
|
||||||
|
RawQuery: req.URL.RawQuery,
|
||||||
|
}
|
||||||
|
http.Redirect(rw, req, dest.String(), http.StatusMovedPermanently)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
h.ServeHTTP(rw, req)
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
func validSignature(redirectURI, sigVal, timestamp, secret string) bool {
|
||||||
|
if redirectURI == "" || sigVal == "" || timestamp == "" || secret == "" {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
_, err := url.Parse(redirectURI)
|
||||||
|
if err != nil {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
requestSig, err := base64.URLEncoding.DecodeString(sigVal)
|
||||||
|
if err != nil {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
i, err := strconv.ParseInt(timestamp, 10, 64)
|
||||||
|
if err != nil {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
tm := time.Unix(i, 0)
|
||||||
|
ttl := 5 * time.Minute
|
||||||
|
if time.Since(tm) > ttl {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
localSig := redirectURLSignature(redirectURI, tm, secret)
|
||||||
|
return hmac.Equal(requestSig, localSig)
|
||||||
|
}
|
||||||
|
|
||||||
|
func redirectURLSignature(rawRedirect string, timestamp time.Time, secret string) []byte {
|
||||||
|
h := hmac.New(sha256.New, []byte(secret))
|
||||||
|
h.Write([]byte(rawRedirect))
|
||||||
|
h.Write([]byte(fmt.Sprint(timestamp.Unix())))
|
||||||
|
return h.Sum(nil)
|
||||||
|
}
|
35
internal/options/email_validator.go
Normal file
35
internal/options/email_validator.go
Normal file
|
@ -0,0 +1,35 @@
|
||||||
|
package options // import "github.com/pomerium/pomerium/internal/options"
|
||||||
|
|
||||||
|
import (
|
||||||
|
"fmt"
|
||||||
|
"strings"
|
||||||
|
)
|
||||||
|
|
||||||
|
// NewEmailValidator returns a function that checks whether a given email is valid based on a list
|
||||||
|
// of domains. The domain "*" is a wild card that matches any non-empty email.
|
||||||
|
func NewEmailValidator(domains []string) func(string) bool {
|
||||||
|
allowAll := false
|
||||||
|
for i, domain := range domains {
|
||||||
|
if domain == "*" {
|
||||||
|
allowAll = true
|
||||||
|
}
|
||||||
|
domains[i] = fmt.Sprintf("@%s", strings.ToLower(domain))
|
||||||
|
}
|
||||||
|
|
||||||
|
if allowAll {
|
||||||
|
return func(email string) bool { return email != "" }
|
||||||
|
}
|
||||||
|
|
||||||
|
return func(email string) bool {
|
||||||
|
if email == "" {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
email = strings.ToLower(email)
|
||||||
|
for _, domain := range domains {
|
||||||
|
if strings.HasSuffix(email, domain) {
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
}
|
121
internal/options/email_validator_test.go
Normal file
121
internal/options/email_validator_test.go
Normal file
|
@ -0,0 +1,121 @@
|
||||||
|
package options // import "github.com/pomerium/pomerium/internal/options"
|
||||||
|
|
||||||
|
import (
|
||||||
|
"testing"
|
||||||
|
)
|
||||||
|
|
||||||
|
func TestEmailValidatorValidator(t *testing.T) {
|
||||||
|
testCases := []struct {
|
||||||
|
name string
|
||||||
|
domains []string
|
||||||
|
email string
|
||||||
|
expectValid bool
|
||||||
|
}{
|
||||||
|
{
|
||||||
|
name: "nothing should validate when domain list is empty",
|
||||||
|
domains: []string(nil),
|
||||||
|
email: "foo@example.com",
|
||||||
|
expectValid: false,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "single domain validation",
|
||||||
|
domains: []string{"example.com"},
|
||||||
|
email: "foo@example.com",
|
||||||
|
expectValid: true,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "substring matches are rejected",
|
||||||
|
domains: []string{"example.com"},
|
||||||
|
email: "foo@hackerexample.com",
|
||||||
|
expectValid: false,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "no subdomain rollup happens",
|
||||||
|
domains: []string{"example.com"},
|
||||||
|
email: "foo@bar.example.com",
|
||||||
|
expectValid: false,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "multiple domain validation still rejects other domains",
|
||||||
|
domains: []string{"abc.com", "xyz.com"},
|
||||||
|
email: "foo@example.com",
|
||||||
|
expectValid: false,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "multiple domain validation still accepts emails from either domain",
|
||||||
|
domains: []string{"abc.com", "xyz.com"},
|
||||||
|
email: "foo@abc.com",
|
||||||
|
expectValid: true,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "multiple domain validation still rejects other domains",
|
||||||
|
domains: []string{"abc.com", "xyz.com"},
|
||||||
|
email: "bar@xyz.com",
|
||||||
|
expectValid: true,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "comparisons are case insensitive",
|
||||||
|
domains: []string{"Example.Com"},
|
||||||
|
email: "foo@example.com",
|
||||||
|
expectValid: true,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "comparisons are case insensitive",
|
||||||
|
domains: []string{"Example.Com"},
|
||||||
|
email: "foo@EXAMPLE.COM",
|
||||||
|
expectValid: true,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "comparisons are case insensitive",
|
||||||
|
domains: []string{"example.com"},
|
||||||
|
email: "foo@ExAmPlE.CoM",
|
||||||
|
expectValid: true,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "single wildcard allows all",
|
||||||
|
domains: []string{"*"},
|
||||||
|
email: "foo@example.com",
|
||||||
|
expectValid: true,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "single wildcard allows all",
|
||||||
|
domains: []string{"*"},
|
||||||
|
email: "bar@gmail.com",
|
||||||
|
expectValid: true,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "wildcard in list allows all",
|
||||||
|
domains: []string{"example.com", "*"},
|
||||||
|
email: "foo@example.com",
|
||||||
|
expectValid: true,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "wildcard in list allows all",
|
||||||
|
domains: []string{"example.com", "*"},
|
||||||
|
email: "foo@gmail.com",
|
||||||
|
expectValid: true,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "empty email rejected",
|
||||||
|
domains: []string{"example.com"},
|
||||||
|
email: "",
|
||||||
|
expectValid: false,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "wildcard still rejects empty emails",
|
||||||
|
domains: []string{"*"},
|
||||||
|
email: "",
|
||||||
|
expectValid: false,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, tc := range testCases {
|
||||||
|
t.Run(tc.name, func(t *testing.T) {
|
||||||
|
emailValidator := NewEmailValidator(tc.domains)
|
||||||
|
valid := emailValidator(tc.email)
|
||||||
|
if valid != tc.expectValid {
|
||||||
|
t.Fatalf("expected %v, got %v", tc.expectValid, valid)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
163
internal/sessions/cookie_store.go
Normal file
163
internal/sessions/cookie_store.go
Normal file
|
@ -0,0 +1,163 @@
|
||||||
|
package sessions // import "github.com/pomerium/pomerium/internal/sessions"
|
||||||
|
|
||||||
|
import (
|
||||||
|
"errors"
|
||||||
|
"fmt"
|
||||||
|
"net"
|
||||||
|
"net/http"
|
||||||
|
"strings"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"github.com/pomerium/pomerium/internal/aead"
|
||||||
|
"github.com/pomerium/pomerium/internal/log"
|
||||||
|
)
|
||||||
|
|
||||||
|
// ErrInvalidSession is an error for invalid sessions.
|
||||||
|
var ErrInvalidSession = errors.New("invalid session")
|
||||||
|
|
||||||
|
// CSRFStore has the functions for setting, getting, and clearing the CSRF cookie
|
||||||
|
type CSRFStore interface {
|
||||||
|
SetCSRF(http.ResponseWriter, *http.Request, string)
|
||||||
|
GetCSRF(*http.Request) (*http.Cookie, error)
|
||||||
|
ClearCSRF(http.ResponseWriter, *http.Request)
|
||||||
|
}
|
||||||
|
|
||||||
|
// SessionStore has the functions for setting, getting, and clearing the Session cookie
|
||||||
|
type SessionStore interface {
|
||||||
|
ClearSession(http.ResponseWriter, *http.Request)
|
||||||
|
LoadSession(*http.Request) (*SessionState, error)
|
||||||
|
SaveSession(http.ResponseWriter, *http.Request, *SessionState) error
|
||||||
|
}
|
||||||
|
|
||||||
|
// CookieStore represents all the cookie related configurations
|
||||||
|
type CookieStore struct {
|
||||||
|
Name string
|
||||||
|
CSRFCookieName string
|
||||||
|
CookieExpire time.Duration
|
||||||
|
CookieRefresh time.Duration
|
||||||
|
CookieSecure bool
|
||||||
|
CookieHTTPOnly bool
|
||||||
|
CookieDomain string
|
||||||
|
CookieCipher aead.Cipher
|
||||||
|
SessionLifetimeTTL time.Duration
|
||||||
|
}
|
||||||
|
|
||||||
|
// CreateMiscreantCookieCipher creates a new miscreant cipher with the cookie secret
|
||||||
|
func CreateMiscreantCookieCipher(cookieSecret []byte) func(s *CookieStore) error {
|
||||||
|
return func(s *CookieStore) error {
|
||||||
|
cipher, err := aead.NewMiscreantCipher(cookieSecret)
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("miscreant cookie-secret error: %s", err.Error())
|
||||||
|
}
|
||||||
|
s.CookieCipher = cipher
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// NewCookieStore returns a new session with ciphers for each of the cookie secrets
|
||||||
|
func NewCookieStore(cookieName string, optFuncs ...func(*CookieStore) error) (*CookieStore, error) {
|
||||||
|
c := &CookieStore{
|
||||||
|
Name: cookieName,
|
||||||
|
CookieSecure: true,
|
||||||
|
CookieHTTPOnly: true,
|
||||||
|
CookieExpire: 168 * time.Hour,
|
||||||
|
CSRFCookieName: fmt.Sprintf("%v_%v", cookieName, "csrf"),
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, f := range optFuncs {
|
||||||
|
err := f(c)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
domain := c.CookieDomain
|
||||||
|
if domain == "" {
|
||||||
|
domain = "<default>"
|
||||||
|
}
|
||||||
|
|
||||||
|
return c, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *CookieStore) makeCookie(req *http.Request, name string, value string, expiration time.Duration, now time.Time) *http.Cookie {
|
||||||
|
domain := req.Host
|
||||||
|
if h, _, err := net.SplitHostPort(domain); err == nil {
|
||||||
|
domain = h
|
||||||
|
}
|
||||||
|
if s.CookieDomain != "" {
|
||||||
|
if !strings.HasSuffix(domain, s.CookieDomain) {
|
||||||
|
log.Warn().Str("cookie-domain", s.CookieDomain).Msg("using configured cookie domain")
|
||||||
|
}
|
||||||
|
domain = s.CookieDomain
|
||||||
|
}
|
||||||
|
|
||||||
|
return &http.Cookie{
|
||||||
|
Name: name,
|
||||||
|
Value: value,
|
||||||
|
Path: "/",
|
||||||
|
Domain: domain,
|
||||||
|
HttpOnly: s.CookieHTTPOnly,
|
||||||
|
Secure: s.CookieSecure,
|
||||||
|
Expires: now.Add(expiration),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// makeSessionCookie constructs a session cookie given the request, an expiration time and the current time.
|
||||||
|
func (s *CookieStore) makeSessionCookie(req *http.Request, value string, expiration time.Duration, now time.Time) *http.Cookie {
|
||||||
|
return s.makeCookie(req, s.Name, value, expiration, now)
|
||||||
|
}
|
||||||
|
|
||||||
|
// makeCSRFCookie creates a CSRF cookie given the request, an expiration time, and the current time.
|
||||||
|
func (s *CookieStore) makeCSRFCookie(req *http.Request, value string, expiration time.Duration, now time.Time) *http.Cookie {
|
||||||
|
return s.makeCookie(req, s.CSRFCookieName, value, expiration, now)
|
||||||
|
}
|
||||||
|
|
||||||
|
// ClearCSRF clears the CSRF cookie from the request
|
||||||
|
func (s *CookieStore) ClearCSRF(rw http.ResponseWriter, req *http.Request) {
|
||||||
|
http.SetCookie(rw, s.makeCSRFCookie(req, "", time.Hour*-1, time.Now()))
|
||||||
|
}
|
||||||
|
|
||||||
|
// SetCSRF sets the CSRFCookie creates a CSRF cookie in a given request
|
||||||
|
func (s *CookieStore) SetCSRF(rw http.ResponseWriter, req *http.Request, val string) {
|
||||||
|
http.SetCookie(rw, s.makeCSRFCookie(req, val, s.CookieExpire, time.Now()))
|
||||||
|
}
|
||||||
|
|
||||||
|
// GetCSRF gets the CSRFCookie creates a CSRF cookie in a given request
|
||||||
|
func (s *CookieStore) GetCSRF(req *http.Request) (*http.Cookie, error) {
|
||||||
|
return req.Cookie(s.CSRFCookieName)
|
||||||
|
}
|
||||||
|
|
||||||
|
// ClearSession clears the session cookie from a request
|
||||||
|
func (s *CookieStore) ClearSession(rw http.ResponseWriter, req *http.Request) {
|
||||||
|
http.SetCookie(rw, s.makeSessionCookie(req, "", time.Hour*-1, time.Now()))
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *CookieStore) setSessionCookie(rw http.ResponseWriter, req *http.Request, val string) {
|
||||||
|
http.SetCookie(rw, s.makeSessionCookie(req, val, s.CookieExpire, time.Now()))
|
||||||
|
}
|
||||||
|
|
||||||
|
// LoadSession returns a SessionState from the cookie in the request.
|
||||||
|
func (s *CookieStore) LoadSession(req *http.Request) (*SessionState, error) {
|
||||||
|
c, err := req.Cookie(s.Name)
|
||||||
|
if err != nil {
|
||||||
|
// always http.ErrNoCookie
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
session, err := UnmarshalSession(c.Value, s.CookieCipher)
|
||||||
|
if err != nil {
|
||||||
|
log.Error().Err(err).Str("remote-host", req.Host).Msg("error unmarshaling session")
|
||||||
|
return nil, ErrInvalidSession
|
||||||
|
}
|
||||||
|
return session, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// SaveSession saves a session state to a request sessions.
|
||||||
|
func (s *CookieStore) SaveSession(rw http.ResponseWriter, req *http.Request, sessionState *SessionState) error {
|
||||||
|
value, err := MarshalSession(sessionState, s.CookieCipher)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
s.setSessionCookie(rw, req, value)
|
||||||
|
return nil
|
||||||
|
}
|
348
internal/sessions/cookie_store_test.go
Normal file
348
internal/sessions/cookie_store_test.go
Normal file
|
@ -0,0 +1,348 @@
|
||||||
|
package sessions // import "github.com/pomerium/pomerium/internal/sessions"
|
||||||
|
|
||||||
|
import (
|
||||||
|
"encoding/base64"
|
||||||
|
"fmt"
|
||||||
|
"net/http"
|
||||||
|
"net/http/httptest"
|
||||||
|
"testing"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"github.com/pomerium/pomerium/internal/testutil"
|
||||||
|
)
|
||||||
|
|
||||||
|
var testEncodedCookieSecret, _ = base64.StdEncoding.DecodeString("qICChm3wdjbjcWymm7PefwtPP6/PZv+udkFEubTeE38=")
|
||||||
|
|
||||||
|
func TestCreateMiscreantCookieCipher(t *testing.T) {
|
||||||
|
testCases := []struct {
|
||||||
|
name string
|
||||||
|
cookieSecret []byte
|
||||||
|
expectedError bool
|
||||||
|
}{
|
||||||
|
{
|
||||||
|
name: "normal case with base64 encoded secret",
|
||||||
|
cookieSecret: testEncodedCookieSecret,
|
||||||
|
},
|
||||||
|
|
||||||
|
{
|
||||||
|
name: "error when not base64 encoded",
|
||||||
|
cookieSecret: []byte("abcd"),
|
||||||
|
expectedError: true,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
for _, tc := range testCases {
|
||||||
|
t.Run(tc.name, func(t *testing.T) {
|
||||||
|
_, err := NewCookieStore("cookieName", CreateMiscreantCookieCipher(tc.cookieSecret))
|
||||||
|
if !tc.expectedError {
|
||||||
|
testutil.Ok(t, err)
|
||||||
|
} else {
|
||||||
|
testutil.NotEqual(t, err, nil)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestNewSession(t *testing.T) {
|
||||||
|
testCases := []struct {
|
||||||
|
name string
|
||||||
|
optFuncs []func(*CookieStore) error
|
||||||
|
expectedError bool
|
||||||
|
expectedSession *CookieStore
|
||||||
|
}{
|
||||||
|
{
|
||||||
|
name: "default with no opt funcs set",
|
||||||
|
expectedSession: &CookieStore{
|
||||||
|
Name: "cookieName",
|
||||||
|
CookieSecure: true,
|
||||||
|
CookieHTTPOnly: true,
|
||||||
|
CookieExpire: 168 * time.Hour,
|
||||||
|
CSRFCookieName: "cookieName_csrf",
|
||||||
|
},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "opt func with an error returns an error",
|
||||||
|
optFuncs: []func(*CookieStore) error{func(*CookieStore) error { return fmt.Errorf("error") }},
|
||||||
|
expectedError: true,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "opt func overrides default values",
|
||||||
|
optFuncs: []func(*CookieStore) error{func(s *CookieStore) error {
|
||||||
|
s.CookieExpire = time.Hour
|
||||||
|
return nil
|
||||||
|
}},
|
||||||
|
expectedSession: &CookieStore{
|
||||||
|
Name: "cookieName",
|
||||||
|
CookieSecure: true,
|
||||||
|
CookieHTTPOnly: true,
|
||||||
|
CookieExpire: time.Hour,
|
||||||
|
CSRFCookieName: "cookieName_csrf",
|
||||||
|
},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, tc := range testCases {
|
||||||
|
t.Run(tc.name, func(t *testing.T) {
|
||||||
|
session, err := NewCookieStore("cookieName", tc.optFuncs...)
|
||||||
|
if tc.expectedError {
|
||||||
|
testutil.NotEqual(t, err, nil)
|
||||||
|
} else {
|
||||||
|
testutil.Ok(t, err)
|
||||||
|
}
|
||||||
|
testutil.Equal(t, tc.expectedSession, session)
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestMakeSessionCookie(t *testing.T) {
|
||||||
|
now := time.Now()
|
||||||
|
cookieValue := "cookieValue"
|
||||||
|
expiration := time.Hour
|
||||||
|
cookieName := "cookieName"
|
||||||
|
testCases := []struct {
|
||||||
|
name string
|
||||||
|
optFuncs []func(*CookieStore) error
|
||||||
|
expectedCookie *http.Cookie
|
||||||
|
}{
|
||||||
|
{
|
||||||
|
name: "default cookie domain",
|
||||||
|
expectedCookie: &http.Cookie{
|
||||||
|
Name: cookieName,
|
||||||
|
Value: cookieValue,
|
||||||
|
Path: "/",
|
||||||
|
Domain: "www.example.com",
|
||||||
|
HttpOnly: true,
|
||||||
|
Secure: true,
|
||||||
|
Expires: now.Add(expiration),
|
||||||
|
},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "custom cookie domain set",
|
||||||
|
optFuncs: []func(*CookieStore) error{
|
||||||
|
func(s *CookieStore) error {
|
||||||
|
s.CookieDomain = "buzzfeed.com"
|
||||||
|
return nil
|
||||||
|
},
|
||||||
|
},
|
||||||
|
expectedCookie: &http.Cookie{
|
||||||
|
Name: cookieName,
|
||||||
|
Value: cookieValue,
|
||||||
|
Path: "/",
|
||||||
|
Domain: "buzzfeed.com",
|
||||||
|
HttpOnly: true,
|
||||||
|
Secure: true,
|
||||||
|
Expires: now.Add(expiration),
|
||||||
|
},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, tc := range testCases {
|
||||||
|
t.Run(tc.name, func(t *testing.T) {
|
||||||
|
session, err := NewCookieStore(cookieName, tc.optFuncs...)
|
||||||
|
testutil.Ok(t, err)
|
||||||
|
req := httptest.NewRequest("GET", "http://www.example.com", nil)
|
||||||
|
cookie := session.makeSessionCookie(req, cookieValue, expiration, now)
|
||||||
|
testutil.Equal(t, cookie, tc.expectedCookie)
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestMakeSessionCSRFCookie(t *testing.T) {
|
||||||
|
now := time.Now()
|
||||||
|
cookieValue := "cookieValue"
|
||||||
|
expiration := time.Hour
|
||||||
|
cookieName := "cookieName"
|
||||||
|
csrfName := "cookieName_csrf"
|
||||||
|
|
||||||
|
testCases := []struct {
|
||||||
|
name string
|
||||||
|
optFuncs []func(*CookieStore) error
|
||||||
|
expectedCookie *http.Cookie
|
||||||
|
}{
|
||||||
|
{
|
||||||
|
name: "default cookie domain",
|
||||||
|
expectedCookie: &http.Cookie{
|
||||||
|
Name: csrfName,
|
||||||
|
Value: cookieValue,
|
||||||
|
Path: "/",
|
||||||
|
Domain: "www.example.com",
|
||||||
|
HttpOnly: true,
|
||||||
|
Secure: true,
|
||||||
|
Expires: now.Add(expiration),
|
||||||
|
},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "custom cookie domain set",
|
||||||
|
optFuncs: []func(*CookieStore) error{
|
||||||
|
func(s *CookieStore) error {
|
||||||
|
s.CookieDomain = "buzzfeed.com"
|
||||||
|
return nil
|
||||||
|
},
|
||||||
|
},
|
||||||
|
expectedCookie: &http.Cookie{
|
||||||
|
Name: csrfName,
|
||||||
|
Value: cookieValue,
|
||||||
|
Path: "/",
|
||||||
|
Domain: "buzzfeed.com",
|
||||||
|
HttpOnly: true,
|
||||||
|
Secure: true,
|
||||||
|
Expires: now.Add(expiration),
|
||||||
|
},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, tc := range testCases {
|
||||||
|
t.Run(tc.name, func(t *testing.T) {
|
||||||
|
session, err := NewCookieStore(cookieName, tc.optFuncs...)
|
||||||
|
testutil.Ok(t, err)
|
||||||
|
req := httptest.NewRequest("GET", "http://www.example.com", nil)
|
||||||
|
cookie := session.makeCSRFCookie(req, cookieValue, expiration, now)
|
||||||
|
testutil.Equal(t, tc.expectedCookie, cookie)
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestSetSessionCookie(t *testing.T) {
|
||||||
|
cookieValue := "cookieValue"
|
||||||
|
cookieName := "cookieName"
|
||||||
|
|
||||||
|
t.Run("set session cookie test", func(t *testing.T) {
|
||||||
|
session, err := NewCookieStore(cookieName)
|
||||||
|
testutil.Ok(t, err)
|
||||||
|
req := httptest.NewRequest("GET", "http://www.example.com", nil)
|
||||||
|
rw := httptest.NewRecorder()
|
||||||
|
session.setSessionCookie(rw, req, cookieValue)
|
||||||
|
var found bool
|
||||||
|
for _, cookie := range rw.Result().Cookies() {
|
||||||
|
if cookie.Name == cookieName {
|
||||||
|
found = true
|
||||||
|
testutil.Equal(t, cookieValue, cookie.Value)
|
||||||
|
testutil.Assert(t, cookie.Expires.After(time.Now()), "cookie expires after now")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
testutil.Assert(t, found, "cookie in header")
|
||||||
|
})
|
||||||
|
}
|
||||||
|
func TestSetCSRFSessionCookie(t *testing.T) {
|
||||||
|
cookieValue := "cookieValue"
|
||||||
|
cookieName := "cookieName"
|
||||||
|
|
||||||
|
t.Run("set csrf cookie test", func(t *testing.T) {
|
||||||
|
session, err := NewCookieStore(cookieName)
|
||||||
|
testutil.Ok(t, err)
|
||||||
|
req := httptest.NewRequest("GET", "http://www.example.com", nil)
|
||||||
|
rw := httptest.NewRecorder()
|
||||||
|
session.SetCSRF(rw, req, cookieValue)
|
||||||
|
var found bool
|
||||||
|
for _, cookie := range rw.Result().Cookies() {
|
||||||
|
if cookie.Name == fmt.Sprintf("%s_csrf", cookieName) {
|
||||||
|
found = true
|
||||||
|
testutil.Equal(t, cookieValue, cookie.Value)
|
||||||
|
testutil.Assert(t, cookie.Expires.After(time.Now()), "cookie expires after now")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
testutil.Assert(t, found, "cookie in header")
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestClearSessionCookie(t *testing.T) {
|
||||||
|
cookieValue := "cookieValue"
|
||||||
|
cookieName := "cookieName"
|
||||||
|
|
||||||
|
t.Run("set session cookie test", func(t *testing.T) {
|
||||||
|
session, err := NewCookieStore(cookieName)
|
||||||
|
testutil.Ok(t, err)
|
||||||
|
req := httptest.NewRequest("GET", "http://www.example.com", nil)
|
||||||
|
req.AddCookie(session.makeSessionCookie(req, cookieValue, time.Hour, time.Now()))
|
||||||
|
|
||||||
|
rw := httptest.NewRecorder()
|
||||||
|
session.ClearSession(rw, req)
|
||||||
|
var found bool
|
||||||
|
for _, cookie := range rw.Result().Cookies() {
|
||||||
|
if cookie.Name == cookieName {
|
||||||
|
found = true
|
||||||
|
testutil.Equal(t, "", cookie.Value)
|
||||||
|
testutil.Assert(t, cookie.Expires.Before(time.Now()), "cookie expires before now")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
testutil.Assert(t, found, "cookie in header")
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestClearCSRFSessionCookie(t *testing.T) {
|
||||||
|
cookieValue := "cookieValue"
|
||||||
|
cookieName := "cookieName"
|
||||||
|
|
||||||
|
t.Run("clear csrf cookie test", func(t *testing.T) {
|
||||||
|
session, err := NewCookieStore(cookieName)
|
||||||
|
testutil.Ok(t, err)
|
||||||
|
req := httptest.NewRequest("GET", "http://www.example.com", nil)
|
||||||
|
req.AddCookie(session.makeCSRFCookie(req, cookieValue, time.Hour, time.Now()))
|
||||||
|
|
||||||
|
rw := httptest.NewRecorder()
|
||||||
|
session.ClearCSRF(rw, req)
|
||||||
|
var found bool
|
||||||
|
for _, cookie := range rw.Result().Cookies() {
|
||||||
|
if cookie.Name == fmt.Sprintf("%s_csrf", cookieName) {
|
||||||
|
found = true
|
||||||
|
testutil.Equal(t, "", cookie.Value)
|
||||||
|
testutil.Assert(t, cookie.Expires.Before(time.Now()), "cookie expires before now")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
testutil.Assert(t, found, "cookie in header")
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestLoadCookiedSession(t *testing.T) {
|
||||||
|
cookieName := "cookieName"
|
||||||
|
|
||||||
|
testCases := []struct {
|
||||||
|
name string
|
||||||
|
optFuncs []func(*CookieStore) error
|
||||||
|
setupCookies func(*testing.T, *http.Request, *CookieStore, *SessionState)
|
||||||
|
expectedError error
|
||||||
|
sessionState *SessionState
|
||||||
|
}{
|
||||||
|
{
|
||||||
|
name: "no cookie set returns an error",
|
||||||
|
setupCookies: func(*testing.T, *http.Request, *CookieStore, *SessionState) {},
|
||||||
|
expectedError: http.ErrNoCookie,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "cookie set with cipher set",
|
||||||
|
optFuncs: []func(*CookieStore) error{CreateMiscreantCookieCipher(testEncodedCookieSecret)},
|
||||||
|
setupCookies: func(t *testing.T, req *http.Request, s *CookieStore, sessionState *SessionState) {
|
||||||
|
value, err := MarshalSession(sessionState, s.CookieCipher)
|
||||||
|
testutil.Ok(t, err)
|
||||||
|
req.AddCookie(s.makeSessionCookie(req, value, time.Hour, time.Now()))
|
||||||
|
},
|
||||||
|
sessionState: &SessionState{
|
||||||
|
Email: "example@email.com",
|
||||||
|
RefreshToken: "abccdddd",
|
||||||
|
AccessToken: "access",
|
||||||
|
},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "cookie set with invalid value cipher set",
|
||||||
|
optFuncs: []func(*CookieStore) error{CreateMiscreantCookieCipher(testEncodedCookieSecret)},
|
||||||
|
setupCookies: func(t *testing.T, req *http.Request, s *CookieStore, sessionState *SessionState) {
|
||||||
|
value := "574b776a7c934d6b9fc42ec63a389f79"
|
||||||
|
req.AddCookie(s.makeSessionCookie(req, value, time.Hour, time.Now()))
|
||||||
|
},
|
||||||
|
expectedError: ErrInvalidSession,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, tc := range testCases {
|
||||||
|
t.Run(tc.name, func(t *testing.T) {
|
||||||
|
session, err := NewCookieStore(cookieName, tc.optFuncs...)
|
||||||
|
testutil.Ok(t, err)
|
||||||
|
req := httptest.NewRequest("GET", "https://www.example.com", nil)
|
||||||
|
tc.setupCookies(t, req, session, tc.sessionState)
|
||||||
|
s, err := session.LoadSession(req)
|
||||||
|
|
||||||
|
testutil.Equal(t, tc.expectedError, err)
|
||||||
|
testutil.Equal(t, tc.sessionState, s)
|
||||||
|
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
50
internal/sessions/mock_store.go
Normal file
50
internal/sessions/mock_store.go
Normal file
|
@ -0,0 +1,50 @@
|
||||||
|
package sessions // import "github.com/pomerium/pomerium/internal/sessions"
|
||||||
|
|
||||||
|
import (
|
||||||
|
"net/http"
|
||||||
|
)
|
||||||
|
|
||||||
|
// MockCSRFStore is a mock implementation of the CSRF store interface
|
||||||
|
type MockCSRFStore struct {
|
||||||
|
ResponseCSRF string
|
||||||
|
Cookie *http.Cookie
|
||||||
|
GetError error
|
||||||
|
}
|
||||||
|
|
||||||
|
// SetCSRF sets the ResponseCSRF string to a val
|
||||||
|
func (ms *MockCSRFStore) SetCSRF(rw http.ResponseWriter, req *http.Request, val string) {
|
||||||
|
ms.ResponseCSRF = val
|
||||||
|
}
|
||||||
|
|
||||||
|
// ClearCSRF clears the ResponseCSRF string
|
||||||
|
func (ms *MockCSRFStore) ClearCSRF(http.ResponseWriter, *http.Request) {
|
||||||
|
ms.ResponseCSRF = ""
|
||||||
|
}
|
||||||
|
|
||||||
|
// GetCSRF returns the cookie and error
|
||||||
|
func (ms *MockCSRFStore) GetCSRF(*http.Request) (*http.Cookie, error) {
|
||||||
|
return ms.Cookie, ms.GetError
|
||||||
|
}
|
||||||
|
|
||||||
|
// MockSessionStore is a mock implementation of the SessionStore interface
|
||||||
|
type MockSessionStore struct {
|
||||||
|
ResponseSession string
|
||||||
|
Session *SessionState
|
||||||
|
SaveError error
|
||||||
|
LoadError error
|
||||||
|
}
|
||||||
|
|
||||||
|
// ClearSession clears the ResponseSession
|
||||||
|
func (ms *MockSessionStore) ClearSession(http.ResponseWriter, *http.Request) {
|
||||||
|
ms.ResponseSession = ""
|
||||||
|
}
|
||||||
|
|
||||||
|
// LoadSession returns the session and a error
|
||||||
|
func (ms *MockSessionStore) LoadSession(*http.Request) (*SessionState, error) {
|
||||||
|
return ms.Session, ms.LoadError
|
||||||
|
}
|
||||||
|
|
||||||
|
// SaveSession returns a save error.
|
||||||
|
func (ms *MockSessionStore) SaveSession(http.ResponseWriter, *http.Request, *SessionState) error {
|
||||||
|
return ms.SaveError
|
||||||
|
}
|
70
internal/sessions/session_state.go
Normal file
70
internal/sessions/session_state.go
Normal file
|
@ -0,0 +1,70 @@
|
||||||
|
package sessions // import "github.com/pomerium/pomerium/internal/sessions"
|
||||||
|
|
||||||
|
import (
|
||||||
|
"errors"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"github.com/pomerium/pomerium/internal/aead"
|
||||||
|
)
|
||||||
|
|
||||||
|
var (
|
||||||
|
// ErrLifetimeExpired is an error for the lifetime deadline expiring
|
||||||
|
ErrLifetimeExpired = errors.New("user lifetime expired")
|
||||||
|
)
|
||||||
|
|
||||||
|
// SessionState is our object that keeps track of a user's session state
|
||||||
|
type SessionState struct {
|
||||||
|
AccessToken string `json:"access_token"`
|
||||||
|
RefreshToken string `json:"refresh_token"`
|
||||||
|
IDToken string `json:"id_token"` // https://openid.net/specs/openid-connect-core-1_0.html#TokenResponse
|
||||||
|
|
||||||
|
RefreshDeadline time.Time `json:"refresh_deadline"`
|
||||||
|
LifetimeDeadline time.Time `json:"lifetime_deadline"`
|
||||||
|
ValidDeadline time.Time `json:"valid_deadline"`
|
||||||
|
GracePeriodStart time.Time `json:"grace_period_start"`
|
||||||
|
|
||||||
|
Email string `json:"email"`
|
||||||
|
User string `json:"user"`
|
||||||
|
Groups []string `json:"groups"`
|
||||||
|
}
|
||||||
|
|
||||||
|
// LifetimePeriodExpired returns true if the lifetime has expired
|
||||||
|
func (s *SessionState) LifetimePeriodExpired() bool {
|
||||||
|
return isExpired(s.LifetimeDeadline)
|
||||||
|
}
|
||||||
|
|
||||||
|
// RefreshPeriodExpired returns true if the refresh period has expired
|
||||||
|
func (s *SessionState) RefreshPeriodExpired() bool {
|
||||||
|
return isExpired(s.RefreshDeadline)
|
||||||
|
}
|
||||||
|
|
||||||
|
// ValidationPeriodExpired returns true if the validation period has expired
|
||||||
|
func (s *SessionState) ValidationPeriodExpired() bool {
|
||||||
|
return isExpired(s.ValidDeadline)
|
||||||
|
}
|
||||||
|
|
||||||
|
func isExpired(t time.Time) bool {
|
||||||
|
return t.Before(time.Now())
|
||||||
|
}
|
||||||
|
|
||||||
|
// MarshalSession marshals the session state as JSON, encrypts the JSON using the
|
||||||
|
// given cipher, and base64-encodes the result
|
||||||
|
func MarshalSession(s *SessionState, c aead.Cipher) (string, error) {
|
||||||
|
return c.Marshal(s)
|
||||||
|
}
|
||||||
|
|
||||||
|
// UnmarshalSession takes the marshaled string, base64-decodes into a byte slice, decrypts the
|
||||||
|
// byte slice using the pased cipher, and unmarshals the resulting JSON into a session state struct
|
||||||
|
func UnmarshalSession(value string, c aead.Cipher) (*SessionState, error) {
|
||||||
|
s := &SessionState{}
|
||||||
|
err := c.Unmarshal(value, s)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
return s, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// ExtendDeadline returns the time extended by a given duration
|
||||||
|
func ExtendDeadline(ttl time.Duration) time.Time {
|
||||||
|
return time.Now().Add(ttl).Truncate(time.Second)
|
||||||
|
}
|
71
internal/sessions/session_state_test.go
Normal file
71
internal/sessions/session_state_test.go
Normal file
|
@ -0,0 +1,71 @@
|
||||||
|
package sessions // import "github.com/pomerium/pomerium/internal/sessions"
|
||||||
|
|
||||||
|
import (
|
||||||
|
"reflect"
|
||||||
|
"testing"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"github.com/pomerium/pomerium/internal/aead"
|
||||||
|
)
|
||||||
|
|
||||||
|
func TestSessionStateSerialization(t *testing.T) {
|
||||||
|
secret := aead.GenerateKey()
|
||||||
|
c, err := aead.NewMiscreantCipher([]byte(secret))
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("expected to be able to create cipher: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
want := &SessionState{
|
||||||
|
AccessToken: "token1234",
|
||||||
|
RefreshToken: "refresh4321",
|
||||||
|
|
||||||
|
LifetimeDeadline: time.Now().Add(1 * time.Hour).Truncate(time.Second).UTC(),
|
||||||
|
RefreshDeadline: time.Now().Add(1 * time.Hour).Truncate(time.Second).UTC(),
|
||||||
|
ValidDeadline: time.Now().Add(1 * time.Minute).Truncate(time.Second).UTC(),
|
||||||
|
|
||||||
|
Email: "user@domain.com",
|
||||||
|
User: "user",
|
||||||
|
}
|
||||||
|
|
||||||
|
ciphertext, err := MarshalSession(want, c)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("expected to be encode session: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
got, err := UnmarshalSession(ciphertext, c)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("expected to be decode session: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if !reflect.DeepEqual(want, got) {
|
||||||
|
t.Logf("want: %#v", want)
|
||||||
|
t.Logf(" got: %#v", got)
|
||||||
|
t.Errorf("encoding and decoding session resulted in unexpected output")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestSessionStateExpirations(t *testing.T) {
|
||||||
|
session := &SessionState{
|
||||||
|
AccessToken: "token1234",
|
||||||
|
RefreshToken: "refresh4321",
|
||||||
|
|
||||||
|
LifetimeDeadline: time.Now().Add(-1 * time.Hour),
|
||||||
|
RefreshDeadline: time.Now().Add(-1 * time.Hour),
|
||||||
|
ValidDeadline: time.Now().Add(-1 * time.Minute),
|
||||||
|
|
||||||
|
Email: "user@domain.com",
|
||||||
|
User: "user",
|
||||||
|
}
|
||||||
|
|
||||||
|
if !session.LifetimePeriodExpired() {
|
||||||
|
t.Errorf("expcted lifetime period to be expired")
|
||||||
|
}
|
||||||
|
|
||||||
|
if !session.RefreshPeriodExpired() {
|
||||||
|
t.Errorf("expcted lifetime period to be expired")
|
||||||
|
}
|
||||||
|
|
||||||
|
if !session.ValidationPeriodExpired() {
|
||||||
|
t.Errorf("expcted lifetime period to be expired")
|
||||||
|
}
|
||||||
|
}
|
75
internal/singleflight/singleflight.go
Normal file
75
internal/singleflight/singleflight.go
Normal file
|
@ -0,0 +1,75 @@
|
||||||
|
// Original Copyright 2013 The Go Authors. All rights reserved.
|
||||||
|
//
|
||||||
|
// Modified by BuzzFeed to return duplicate counts.
|
||||||
|
//
|
||||||
|
// Use of this source code is governed by a BSD-style
|
||||||
|
// license that can be found in the LICENSE file.
|
||||||
|
|
||||||
|
// Package singleflight provides a duplicate function call suppression mechanism.
|
||||||
|
package singleflight // import "github.com/pomerium/pomerium/internal/singleflight"
|
||||||
|
|
||||||
|
import "sync"
|
||||||
|
|
||||||
|
// call is an in-flight or completed singleflight.Do call
|
||||||
|
type call struct {
|
||||||
|
wg sync.WaitGroup
|
||||||
|
|
||||||
|
// These fields are written once before the WaitGroup is done
|
||||||
|
// and are only read after the WaitGroup is done.
|
||||||
|
val interface{}
|
||||||
|
err error
|
||||||
|
|
||||||
|
// These fields are read and written with the singleflight
|
||||||
|
// mutex held before the WaitGroup is done, and are read but
|
||||||
|
// not written after the WaitGroup is done.
|
||||||
|
dups int
|
||||||
|
}
|
||||||
|
|
||||||
|
// Group represents a class of work and forms a namespace in
|
||||||
|
// which units of work can be executed with duplicate suppression.
|
||||||
|
type Group struct {
|
||||||
|
mu sync.Mutex // protects m
|
||||||
|
m map[string]*call // lazily initialized
|
||||||
|
}
|
||||||
|
|
||||||
|
// Result holds the results of Do, so they can be passed
|
||||||
|
// on a channel.
|
||||||
|
type Result struct {
|
||||||
|
Val interface{}
|
||||||
|
Err error
|
||||||
|
Count bool
|
||||||
|
}
|
||||||
|
|
||||||
|
// Do executes and returns the results of the given function, making
|
||||||
|
// sure that only one execution is in-flight for a given key at a
|
||||||
|
// time. If a duplicate comes in, the duplicate caller waits for the
|
||||||
|
// original to complete and receives the same results.
|
||||||
|
// The return value of Count indicates how many tiems v was given to multiple callers.
|
||||||
|
// Count will be zero for requests are shared and only be non-zero for the originating request.
|
||||||
|
func (g *Group) Do(key string, fn func() (interface{}, error)) (v interface{}, count int, err error) {
|
||||||
|
g.mu.Lock()
|
||||||
|
|
||||||
|
if g.m == nil {
|
||||||
|
g.m = make(map[string]*call)
|
||||||
|
}
|
||||||
|
if c, ok := g.m[key]; ok {
|
||||||
|
c.dups++
|
||||||
|
g.mu.Unlock()
|
||||||
|
c.wg.Wait()
|
||||||
|
return c.val, 0, c.err
|
||||||
|
}
|
||||||
|
c := new(call)
|
||||||
|
c.wg.Add(1)
|
||||||
|
g.m[key] = c
|
||||||
|
|
||||||
|
g.mu.Unlock()
|
||||||
|
|
||||||
|
c.val, c.err = fn()
|
||||||
|
c.wg.Done()
|
||||||
|
|
||||||
|
g.mu.Lock()
|
||||||
|
delete(g.m, key)
|
||||||
|
g.mu.Unlock()
|
||||||
|
|
||||||
|
return c.val, c.dups, c.err
|
||||||
|
}
|
87
internal/singleflight/singleflight_test.go
Normal file
87
internal/singleflight/singleflight_test.go
Normal file
|
@ -0,0 +1,87 @@
|
||||||
|
// Copyright 2013 The Go Authors. All rights reserved.
|
||||||
|
// Use of this source code is governed by a BSD-style
|
||||||
|
// license that can be found in the LICENSE file.
|
||||||
|
|
||||||
|
package singleflight // import "github.com/pomerium/pomerium/internal/singleflight"
|
||||||
|
|
||||||
|
import (
|
||||||
|
"errors"
|
||||||
|
"fmt"
|
||||||
|
"sync"
|
||||||
|
"sync/atomic"
|
||||||
|
"testing"
|
||||||
|
"time"
|
||||||
|
)
|
||||||
|
|
||||||
|
func TestDo(t *testing.T) {
|
||||||
|
var g Group
|
||||||
|
v, _, err := g.Do("key", func() (interface{}, error) {
|
||||||
|
return "bar", nil
|
||||||
|
})
|
||||||
|
if got, want := fmt.Sprintf("%v (%T)", v, v), "bar (string)"; got != want {
|
||||||
|
t.Errorf("Do = %v; want %v", got, want)
|
||||||
|
}
|
||||||
|
if err != nil {
|
||||||
|
t.Errorf("Do error = %v", err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestDoErr(t *testing.T) {
|
||||||
|
var g Group
|
||||||
|
someErr := errors.New("Some error")
|
||||||
|
v, _, err := g.Do("key", func() (interface{}, error) {
|
||||||
|
return nil, someErr
|
||||||
|
})
|
||||||
|
if err != someErr {
|
||||||
|
t.Errorf("Do error = %v; want someErr %v", err, someErr)
|
||||||
|
}
|
||||||
|
if v != nil {
|
||||||
|
t.Errorf("unexpected non-nil value %#v", v)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestDoDupSuppress(t *testing.T) {
|
||||||
|
var g Group
|
||||||
|
var wg1, wg2 sync.WaitGroup
|
||||||
|
c := make(chan string, 1)
|
||||||
|
var calls int32
|
||||||
|
fn := func() (interface{}, error) {
|
||||||
|
if atomic.AddInt32(&calls, 1) == 1 {
|
||||||
|
// First invocation.
|
||||||
|
wg1.Done()
|
||||||
|
}
|
||||||
|
v := <-c
|
||||||
|
c <- v // pump; make available for any future calls
|
||||||
|
|
||||||
|
time.Sleep(10 * time.Millisecond) // let more goroutines enter Do
|
||||||
|
|
||||||
|
return v, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
const n = 10
|
||||||
|
wg1.Add(1)
|
||||||
|
for i := 0; i < n; i++ {
|
||||||
|
wg1.Add(1)
|
||||||
|
wg2.Add(1)
|
||||||
|
go func() {
|
||||||
|
defer wg2.Done()
|
||||||
|
wg1.Done()
|
||||||
|
v, _, err := g.Do("key", fn)
|
||||||
|
if err != nil {
|
||||||
|
t.Errorf("Do error: %v", err)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
if s, _ := v.(string); s != "bar" {
|
||||||
|
t.Errorf("Do = %T %v; want %q", v, v, "bar")
|
||||||
|
}
|
||||||
|
}()
|
||||||
|
}
|
||||||
|
wg1.Wait()
|
||||||
|
// At least one goroutine is in fn now and all of them have at
|
||||||
|
// least reached the line before the Do.
|
||||||
|
c <- "bar"
|
||||||
|
wg2.Wait()
|
||||||
|
if got := atomic.LoadInt32(&calls); got <= 0 || got >= n {
|
||||||
|
t.Errorf("number of calls = %d; want over 0 and less than %d", got, n)
|
||||||
|
}
|
||||||
|
}
|
199
internal/templates/templates.go
Normal file
199
internal/templates/templates.go
Normal file
|
@ -0,0 +1,199 @@
|
||||||
|
package templates // import "github.com/pomerium/pomerium/internal/templates"
|
||||||
|
|
||||||
|
import (
|
||||||
|
"html/template"
|
||||||
|
)
|
||||||
|
|
||||||
|
// New loads html and style resources directly. Panics on failure.
|
||||||
|
func New() *template.Template {
|
||||||
|
t := template.New("authenticate-templates")
|
||||||
|
template.Must(t.Parse(`
|
||||||
|
{{define "header.html"}}
|
||||||
|
<meta name="viewport" content="width=device-width, initial-scale=1, maximum-scale=1, user-scalable=no">
|
||||||
|
<style>
|
||||||
|
* {
|
||||||
|
margin: 0;
|
||||||
|
padding: 0;
|
||||||
|
}
|
||||||
|
body {
|
||||||
|
font-family: "Helvetica Neue",Helvetica,Arial,sans-serif;
|
||||||
|
font-size: 1em;
|
||||||
|
line-height: 1.42857143;
|
||||||
|
color: #333;
|
||||||
|
background: #f0f0f0;
|
||||||
|
}
|
||||||
|
p {
|
||||||
|
margin: 1.5em 0;
|
||||||
|
}
|
||||||
|
p:first-child {
|
||||||
|
margin-top: 0;
|
||||||
|
}
|
||||||
|
p:last-child {
|
||||||
|
margin-bottom: 0;
|
||||||
|
}
|
||||||
|
.container {
|
||||||
|
max-width: 40em;
|
||||||
|
display: block;
|
||||||
|
margin: 10% auto;
|
||||||
|
text-align: center;
|
||||||
|
}
|
||||||
|
.content, .message, button {
|
||||||
|
border: 1px solid rgba(0,0,0,.125);
|
||||||
|
border-bottom-width: 4px;
|
||||||
|
border-radius: 4px;
|
||||||
|
}
|
||||||
|
.content, .message {
|
||||||
|
background-color: #fff;
|
||||||
|
padding: 2rem;
|
||||||
|
margin: 1rem 0;
|
||||||
|
}
|
||||||
|
.error, .message {
|
||||||
|
border-bottom-color: #c00;
|
||||||
|
}
|
||||||
|
.message {
|
||||||
|
padding: 1.5rem 2rem 1.3rem;
|
||||||
|
}
|
||||||
|
header {
|
||||||
|
border-bottom: 1px solid rgba(0,0,0,.075);
|
||||||
|
margin: -2rem 0 2rem;
|
||||||
|
padding: 2rem 0 1.8rem;
|
||||||
|
}
|
||||||
|
header h1 {
|
||||||
|
font-size: 1.5em;
|
||||||
|
font-weight: normal;
|
||||||
|
}
|
||||||
|
.error header {
|
||||||
|
color: #c00;
|
||||||
|
}
|
||||||
|
.details {
|
||||||
|
font-size: .85rem;
|
||||||
|
color: #999;
|
||||||
|
}
|
||||||
|
button {
|
||||||
|
color: #fff;
|
||||||
|
background-color: #3B8686;
|
||||||
|
cursor: pointer;
|
||||||
|
font-size: 1.5rem;
|
||||||
|
font-weight: bold;
|
||||||
|
padding: 1rem 2.5rem;
|
||||||
|
text-shadow: 0 3px 1px rgba(0,0,0,.2);
|
||||||
|
outline: none;
|
||||||
|
}
|
||||||
|
button:active {
|
||||||
|
border-top-width: 4px;
|
||||||
|
border-bottom-width: 1px;
|
||||||
|
text-shadow: none;
|
||||||
|
}
|
||||||
|
footer {
|
||||||
|
font-size: 0.75em;
|
||||||
|
color: #999;
|
||||||
|
text-align: right;
|
||||||
|
margin: 1rem;
|
||||||
|
}
|
||||||
|
</style>
|
||||||
|
{{end}}`))
|
||||||
|
|
||||||
|
t = template.Must(t.Parse(`{{define "footer.html"}}Secured by <b>pomerium</b> {{end}}`))
|
||||||
|
|
||||||
|
t = template.Must(t.Parse(`
|
||||||
|
{{define "sign_in_message.html"}}
|
||||||
|
{{if eq (len .EmailDomains) 1}}
|
||||||
|
{{if eq (index .EmailDomains 0) "@*"}}
|
||||||
|
<p>You may sign in with any {{.ProviderName}} account.</p>
|
||||||
|
{{else}}
|
||||||
|
<p>You may sign in with your <b>{{index .EmailDomains 0}}</b> {{.ProviderName}} account.</p>
|
||||||
|
{{end}}
|
||||||
|
{{else if gt (len .EmailDomains) 1}}
|
||||||
|
<p>
|
||||||
|
You may sign in with any of these {{.ProviderName}} accounts:<br>
|
||||||
|
{{range $i, $e := .EmailDomains}}{{if $i}}, {{end}}<b>{{$e}}</b>{{end}}
|
||||||
|
</p>
|
||||||
|
{{end}}
|
||||||
|
{{end}}`))
|
||||||
|
|
||||||
|
t = template.Must(t.Parse(`
|
||||||
|
{{define "sign_in.html"}}
|
||||||
|
<!DOCTYPE html>
|
||||||
|
<html lang="en" charset="utf-8">
|
||||||
|
<head>
|
||||||
|
<title>Sign In</title>
|
||||||
|
{{template "header.html"}}
|
||||||
|
</head>
|
||||||
|
<body>
|
||||||
|
<div class="container">
|
||||||
|
<div class="content">
|
||||||
|
<header>
|
||||||
|
<h1>Sign in to <b>{{.Destination}}</b></h1>
|
||||||
|
</header>
|
||||||
|
|
||||||
|
{{template "sign_in_message.html" .}}
|
||||||
|
|
||||||
|
<form method="GET" action="/start">
|
||||||
|
<input type="hidden" name="redirect_uri" value="{{.Redirect}}">
|
||||||
|
<button type="submit" class="btn">Sign in with {{.ProviderName}}</button>
|
||||||
|
</form>
|
||||||
|
</div>
|
||||||
|
|
||||||
|
<footer>{{template "footer.html"}} </br> {{.Version}} </footer>
|
||||||
|
</div>
|
||||||
|
</body>
|
||||||
|
</html>
|
||||||
|
{{end}}`))
|
||||||
|
|
||||||
|
template.Must(t.Parse(`
|
||||||
|
{{define "error.html"}}
|
||||||
|
<!DOCTYPE html>
|
||||||
|
<html lang="en" charset="utf-8">
|
||||||
|
<head>
|
||||||
|
<title>Error</title>
|
||||||
|
{{template "header.html"}}
|
||||||
|
</head>
|
||||||
|
<body>
|
||||||
|
<div class="container">
|
||||||
|
<div class="content error">
|
||||||
|
<header>
|
||||||
|
<h1>{{.Title}}</h1>
|
||||||
|
</header>
|
||||||
|
<p>
|
||||||
|
{{.Message}}<br>
|
||||||
|
<span class="details">HTTP {{.Code}}</span>
|
||||||
|
</p>
|
||||||
|
</div>
|
||||||
|
<footer>{{template "footer.html"}} </br> {{.Version}} </footer>
|
||||||
|
</div>
|
||||||
|
</body>
|
||||||
|
</html>{{end}}`))
|
||||||
|
|
||||||
|
t = template.Must(t.Parse(`
|
||||||
|
{{define "sign_out.html"}}
|
||||||
|
<!DOCTYPE html>
|
||||||
|
<html lang="en" charset="utf-8">
|
||||||
|
<head>
|
||||||
|
<title>Sign Out</title>
|
||||||
|
{{template "header.html"}}
|
||||||
|
</head>
|
||||||
|
<body>
|
||||||
|
<div class="container">
|
||||||
|
{{ if .Message }}
|
||||||
|
<div class="message">{{.Message}}</div>
|
||||||
|
{{ end}}
|
||||||
|
<div class="content">
|
||||||
|
<header>
|
||||||
|
<h1>Sign out of <b>{{.Destination}}</b></h1>
|
||||||
|
</header>
|
||||||
|
|
||||||
|
<p>You're currently signed in as <b>{{.Email}}</b>. This will also sign you out of other internal apps.</p>
|
||||||
|
<form method="POST" action="/sign_out">
|
||||||
|
<input type="hidden" name="redirect_uri" value="{{.Redirect}}">
|
||||||
|
<input type="hidden" name="sig" value="{{.Signature}}">
|
||||||
|
<input type="hidden" name="ts" value="{{.Timestamp}}">
|
||||||
|
<button type="submit">Sign out</button>
|
||||||
|
</form>
|
||||||
|
</div>
|
||||||
|
<footer>{{template "footer.html"}} </br> {{.Version}}</footer>
|
||||||
|
</div>
|
||||||
|
</body>
|
||||||
|
</html>
|
||||||
|
{{end}}`))
|
||||||
|
return t
|
||||||
|
}
|
12
internal/templates/templates_test.go
Normal file
12
internal/templates/templates_test.go
Normal file
|
@ -0,0 +1,12 @@
|
||||||
|
package templates // import "github.com/pomerium/pomerium/internal/templates"
|
||||||
|
|
||||||
|
import (
|
||||||
|
"testing"
|
||||||
|
|
||||||
|
"github.com/pomerium/pomerium/internal/testutil"
|
||||||
|
)
|
||||||
|
|
||||||
|
func TestTemplatesCompile(t *testing.T) {
|
||||||
|
templates := New()
|
||||||
|
testutil.NotEqual(t, templates, nil)
|
||||||
|
}
|
46
internal/testutil/testutil.go
Normal file
46
internal/testutil/testutil.go
Normal file
|
@ -0,0 +1,46 @@
|
||||||
|
package testutil // import "github.com/pomerium/pomerium/internal/testutil"
|
||||||
|
|
||||||
|
// testing util functions copied from https://github.com/benbjohnson/testing
|
||||||
|
import (
|
||||||
|
"fmt"
|
||||||
|
"path/filepath"
|
||||||
|
"reflect"
|
||||||
|
"runtime"
|
||||||
|
"testing"
|
||||||
|
)
|
||||||
|
|
||||||
|
// Assert fails the test if the condition is false.
|
||||||
|
func Assert(tb testing.TB, condition bool, msg string, v ...interface{}) {
|
||||||
|
if !condition {
|
||||||
|
_, file, line, _ := runtime.Caller(1)
|
||||||
|
fmt.Printf("\033[31m%s:%d: "+msg+"\033[39m\n\n", append([]interface{}{filepath.Base(file), line}, v...)...)
|
||||||
|
tb.FailNow()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Ok fails the test if an err is not nil.
|
||||||
|
func Ok(tb testing.TB, err error) {
|
||||||
|
if err != nil {
|
||||||
|
_, file, line, _ := runtime.Caller(1)
|
||||||
|
fmt.Printf("\033[31m%s:%d: unexpected error: %s\033[39m\n\n", filepath.Base(file), line, err.Error())
|
||||||
|
tb.FailNow()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Equal fails the test if exp is not equal to act.
|
||||||
|
func Equal(tb testing.TB, exp, act interface{}) {
|
||||||
|
if !reflect.DeepEqual(exp, act) {
|
||||||
|
_, file, line, _ := runtime.Caller(1)
|
||||||
|
fmt.Printf("\033[31m%s:%d:\n\n\texp: %#v\n\n\tgot: %#v\033[39m\n\n", filepath.Base(file), line, exp, act)
|
||||||
|
tb.FailNow()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// NotEqual fails the test if exp is equal to act.
|
||||||
|
func NotEqual(tb testing.TB, exp, act interface{}) {
|
||||||
|
if reflect.DeepEqual(exp, act) {
|
||||||
|
_, file, line, _ := runtime.Caller(1)
|
||||||
|
fmt.Printf("\033[31m%s:%d:\n\n\texp: %#v\n\n\tgot: %#v\033[39m\n\n", filepath.Base(file), line, exp, act)
|
||||||
|
tb.FailNow()
|
||||||
|
}
|
||||||
|
}
|
42
internal/version/version.go
Normal file
42
internal/version/version.go
Normal file
|
@ -0,0 +1,42 @@
|
||||||
|
package version // import "github.com/pomerium/pomerium/internal/version"
|
||||||
|
|
||||||
|
import (
|
||||||
|
"fmt"
|
||||||
|
"runtime"
|
||||||
|
"strings"
|
||||||
|
)
|
||||||
|
|
||||||
|
var (
|
||||||
|
// ProjectName is the canonical project name set by ldl flags
|
||||||
|
ProjectName = ""
|
||||||
|
// ProjectURL is the canonical project url set by ldl flags
|
||||||
|
ProjectURL = ""
|
||||||
|
// Version specifies Semantic versioning increment (MAJOR.MINOR.PATCH).
|
||||||
|
Version = "v0.0.0"
|
||||||
|
// GitCommit specifies the git commit sha, set by the compiler.
|
||||||
|
GitCommit = ""
|
||||||
|
// BuildMeta specifies release type (dev,rc1,beta,etc)
|
||||||
|
BuildMeta = ""
|
||||||
|
|
||||||
|
runtimeVersion = runtime.Version()
|
||||||
|
)
|
||||||
|
|
||||||
|
// FullVersion returns a version string.
|
||||||
|
func FullVersion() string {
|
||||||
|
var sb strings.Builder
|
||||||
|
sb.Grow(len(Version) + len(GitCommit) + len(BuildMeta) + len("-") + len("+"))
|
||||||
|
sb.WriteString(Version)
|
||||||
|
if BuildMeta != "" {
|
||||||
|
sb.WriteString("-" + BuildMeta)
|
||||||
|
}
|
||||||
|
if GitCommit != "" {
|
||||||
|
sb.WriteString("+" + GitCommit)
|
||||||
|
}
|
||||||
|
return sb.String()
|
||||||
|
}
|
||||||
|
|
||||||
|
// UserAgent returns a user-agent string as specified in RFC 2616:14.43
|
||||||
|
// https://tools.ietf.org/html/rfc2616
|
||||||
|
func UserAgent() string {
|
||||||
|
return fmt.Sprintf("%s/%s (+%s; %s; %s)", ProjectName, Version, ProjectURL, GitCommit, runtimeVersion)
|
||||||
|
}
|
71
internal/version/version_test.go
Normal file
71
internal/version/version_test.go
Normal file
|
@ -0,0 +1,71 @@
|
||||||
|
package version // import "github.com/pomerium/pomerium/internal/version"
|
||||||
|
|
||||||
|
import (
|
||||||
|
"fmt"
|
||||||
|
"runtime"
|
||||||
|
"testing"
|
||||||
|
)
|
||||||
|
|
||||||
|
func TestFullVersionVersion(t *testing.T) {
|
||||||
|
tests := []struct {
|
||||||
|
Version string
|
||||||
|
GitCommit string
|
||||||
|
BuildMeta string
|
||||||
|
|
||||||
|
expected string
|
||||||
|
}{
|
||||||
|
{"", "", "", ""},
|
||||||
|
{"1.0.0", "", "", "1.0.0"},
|
||||||
|
{"1.0.0", "314501b", "", "1.0.0+314501b"},
|
||||||
|
{"1.0.0", "314501b", "dev", "1.0.0-dev+314501b"},
|
||||||
|
}
|
||||||
|
for _, tt := range tests {
|
||||||
|
Version = tt.Version
|
||||||
|
GitCommit = tt.GitCommit
|
||||||
|
BuildMeta = tt.BuildMeta
|
||||||
|
|
||||||
|
if got := FullVersion(); got != tt.expected {
|
||||||
|
t.Errorf("expected (%s) got (%s) for Version(%s), GitCommit(%s) BuildMeta(%s)",
|
||||||
|
tt.expected,
|
||||||
|
got,
|
||||||
|
tt.Version,
|
||||||
|
tt.GitCommit,
|
||||||
|
tt.BuildMeta)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
func BenchmarkFullVersion(b *testing.B) {
|
||||||
|
Version = "1.0.0"
|
||||||
|
GitCommit = "314501b"
|
||||||
|
BuildMeta = "dev"
|
||||||
|
for i := 0; i < b.N; i++ {
|
||||||
|
FullVersion()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestUserAgent(t *testing.T) {
|
||||||
|
tests := []struct {
|
||||||
|
name string
|
||||||
|
Version string
|
||||||
|
GitCommit string
|
||||||
|
BuildMeta string
|
||||||
|
ProjectName string
|
||||||
|
ProjectURL string
|
||||||
|
want string
|
||||||
|
}{
|
||||||
|
{"good user agent", "1.0.0", "314501b", "dev", "pomerium", "github.com/pomerium", fmt.Sprintf("pomerium/1.0.0 (+github.com/pomerium; 314501b; %s)", runtime.Version())},
|
||||||
|
}
|
||||||
|
for _, tt := range tests {
|
||||||
|
Version = tt.Version
|
||||||
|
GitCommit = tt.GitCommit
|
||||||
|
BuildMeta = tt.BuildMeta
|
||||||
|
ProjectName = tt.ProjectName
|
||||||
|
ProjectURL = tt.ProjectURL
|
||||||
|
|
||||||
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
|
if got := UserAgent(); got != tt.want {
|
||||||
|
t.Errorf("UserAgent() = %v, want %v", got, tt.want)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
317
proxy/authenticator/authenticator.go
Normal file
317
proxy/authenticator/authenticator.go
Normal file
|
@ -0,0 +1,317 @@
|
||||||
|
package authenticator // import "github.com/pomerium/pomerium/proxy/authenticator"
|
||||||
|
|
||||||
|
import (
|
||||||
|
"bytes"
|
||||||
|
"crypto/hmac"
|
||||||
|
"crypto/sha256"
|
||||||
|
"encoding/base64"
|
||||||
|
"encoding/json"
|
||||||
|
"errors"
|
||||||
|
"fmt"
|
||||||
|
"io"
|
||||||
|
"io/ioutil"
|
||||||
|
"net"
|
||||||
|
"net/http"
|
||||||
|
"net/url"
|
||||||
|
"strings"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"github.com/pomerium/pomerium/internal/log"
|
||||||
|
"github.com/pomerium/pomerium/internal/sessions"
|
||||||
|
"github.com/pomerium/pomerium/internal/version"
|
||||||
|
)
|
||||||
|
|
||||||
|
var defaultHTTPClient = &http.Client{
|
||||||
|
Timeout: time.Second * 5,
|
||||||
|
Transport: &http.Transport{
|
||||||
|
Dial: (&net.Dialer{
|
||||||
|
Timeout: 2 * time.Second,
|
||||||
|
}).Dial,
|
||||||
|
TLSHandshakeTimeout: 2 * time.Second,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
// Errors
|
||||||
|
var (
|
||||||
|
ErrMissingRefreshToken = errors.New("missing refresh token")
|
||||||
|
ErrAuthProviderUnavailable = errors.New("auth provider unavailable")
|
||||||
|
)
|
||||||
|
|
||||||
|
// AuthenticateClient holds the data associated with the AuthenticateClients
|
||||||
|
// necessary to implement a AuthenticateClient interface.
|
||||||
|
type AuthenticateClient struct {
|
||||||
|
AuthenticateServiceURL *url.URL
|
||||||
|
|
||||||
|
//
|
||||||
|
ClientID string
|
||||||
|
ClientSecret string
|
||||||
|
SignInURL *url.URL
|
||||||
|
SignOutURL *url.URL
|
||||||
|
RedeemURL *url.URL
|
||||||
|
RefreshURL *url.URL
|
||||||
|
ProfileURL *url.URL
|
||||||
|
ValidateURL *url.URL
|
||||||
|
|
||||||
|
SessionValidTTL time.Duration
|
||||||
|
SessionLifetimeTTL time.Duration
|
||||||
|
GracePeriodTTL time.Duration
|
||||||
|
}
|
||||||
|
|
||||||
|
// NewAuthenticateClient instantiates a new AuthenticateClient with provider data
|
||||||
|
func NewAuthenticateClient(uri *url.URL, clientID, clientSecret string, sessionValid, sessionLifetime, gracePeriod time.Duration) *AuthenticateClient {
|
||||||
|
return &AuthenticateClient{
|
||||||
|
AuthenticateServiceURL: uri,
|
||||||
|
|
||||||
|
ClientID: clientID,
|
||||||
|
ClientSecret: clientSecret,
|
||||||
|
|
||||||
|
SignInURL: uri.ResolveReference(&url.URL{Path: "/sign_in"}),
|
||||||
|
SignOutURL: uri.ResolveReference(&url.URL{Path: "/sign_out"}),
|
||||||
|
RedeemURL: uri.ResolveReference(&url.URL{Path: "/redeem"}),
|
||||||
|
RefreshURL: uri.ResolveReference(&url.URL{Path: "/refresh"}),
|
||||||
|
ValidateURL: uri.ResolveReference(&url.URL{Path: "/validate"}),
|
||||||
|
ProfileURL: uri.ResolveReference(&url.URL{Path: "/profile"}),
|
||||||
|
|
||||||
|
SessionValidTTL: sessionValid,
|
||||||
|
SessionLifetimeTTL: sessionLifetime,
|
||||||
|
GracePeriodTTL: gracePeriod,
|
||||||
|
}
|
||||||
|
|
||||||
|
}
|
||||||
|
|
||||||
|
func (p *AuthenticateClient) newRequest(method, url string, body io.Reader) (*http.Request, error) {
|
||||||
|
req, err := http.NewRequest(method, url, body)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
req.Header.Set("User-Agent", version.UserAgent())
|
||||||
|
req.Header.Set("Accept", "application/json")
|
||||||
|
req.Host = p.AuthenticateServiceURL.Host
|
||||||
|
return req, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func isProviderUnavailable(statusCode int) bool {
|
||||||
|
return statusCode == http.StatusTooManyRequests || statusCode == http.StatusServiceUnavailable
|
||||||
|
}
|
||||||
|
|
||||||
|
func extendDeadline(ttl time.Duration) time.Time {
|
||||||
|
return time.Now().Add(ttl).Truncate(time.Second)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (p *AuthenticateClient) withinGracePeriod(s *sessions.SessionState) bool {
|
||||||
|
if s.GracePeriodStart.IsZero() {
|
||||||
|
s.GracePeriodStart = time.Now()
|
||||||
|
}
|
||||||
|
return s.GracePeriodStart.Add(p.GracePeriodTTL).After(time.Now())
|
||||||
|
}
|
||||||
|
|
||||||
|
// Redeem takes a redirectURL and code and redeems the SessionState
|
||||||
|
func (p *AuthenticateClient) Redeem(redirectURL, code string) (*sessions.SessionState, error) {
|
||||||
|
if code == "" {
|
||||||
|
return nil, errors.New("missing code")
|
||||||
|
}
|
||||||
|
|
||||||
|
params := url.Values{}
|
||||||
|
// params that are validates by the authenticate service middleware
|
||||||
|
params.Add("client_id", p.ClientID)
|
||||||
|
params.Add("client_secret", p.ClientSecret)
|
||||||
|
params.Add("code", code)
|
||||||
|
|
||||||
|
req, err := p.newRequest("POST", p.RedeemURL.String(), bytes.NewBufferString(params.Encode()))
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
req.Header.Set("Content-Type", "application/x-www-form-urlencoded")
|
||||||
|
resp, err := defaultHTTPClient.Do(req)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
body, err := ioutil.ReadAll(resp.Body)
|
||||||
|
resp.Body.Close()
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
if resp.StatusCode != 200 {
|
||||||
|
if isProviderUnavailable(resp.StatusCode) {
|
||||||
|
return nil, ErrAuthProviderUnavailable
|
||||||
|
}
|
||||||
|
return nil, fmt.Errorf("got %d from %q %s", resp.StatusCode, p.RedeemURL.String(), body)
|
||||||
|
}
|
||||||
|
|
||||||
|
var jsonResponse struct {
|
||||||
|
AccessToken string `json:"access_token"`
|
||||||
|
IDToken string `json:"id_token"`
|
||||||
|
RefreshToken string `json:"refresh_token"`
|
||||||
|
ExpiresIn int64 `json:"expires_in"`
|
||||||
|
Email string `json:"email"`
|
||||||
|
}
|
||||||
|
err = json.Unmarshal(body, &jsonResponse)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
user := strings.Split(jsonResponse.Email, "@")[0]
|
||||||
|
return &sessions.SessionState{
|
||||||
|
AccessToken: jsonResponse.AccessToken,
|
||||||
|
RefreshToken: jsonResponse.RefreshToken,
|
||||||
|
IDToken: jsonResponse.IDToken,
|
||||||
|
|
||||||
|
RefreshDeadline: extendDeadline(time.Duration(jsonResponse.ExpiresIn) * time.Second),
|
||||||
|
LifetimeDeadline: extendDeadline(p.SessionLifetimeTTL),
|
||||||
|
ValidDeadline: extendDeadline(p.SessionValidTTL),
|
||||||
|
|
||||||
|
Email: jsonResponse.Email,
|
||||||
|
User: user,
|
||||||
|
}, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// RefreshSession refreshes the current session
|
||||||
|
func (p *AuthenticateClient) RefreshSession(s *sessions.SessionState) (bool, error) {
|
||||||
|
|
||||||
|
if s.RefreshToken == "" {
|
||||||
|
return false, ErrMissingRefreshToken
|
||||||
|
}
|
||||||
|
|
||||||
|
newToken, duration, err := p.redeemRefreshToken(s.RefreshToken)
|
||||||
|
if err != nil {
|
||||||
|
// When we detect that the auth provider is not explicitly denying
|
||||||
|
// authentication, and is merely unavailable, we refresh and continue
|
||||||
|
// as normal during the "grace period"
|
||||||
|
if err == ErrAuthProviderUnavailable && p.withinGracePeriod(s) {
|
||||||
|
s.RefreshDeadline = extendDeadline(p.SessionValidTTL)
|
||||||
|
return true, nil
|
||||||
|
}
|
||||||
|
return false, err
|
||||||
|
}
|
||||||
|
|
||||||
|
s.AccessToken = newToken
|
||||||
|
s.RefreshDeadline = extendDeadline(duration)
|
||||||
|
s.GracePeriodStart = time.Time{}
|
||||||
|
log.Info().Str("user", s.Email).Msg("proxy/authenticator.RefreshSession")
|
||||||
|
return true, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (p *AuthenticateClient) redeemRefreshToken(refreshToken string) (token string, expires time.Duration, err error) {
|
||||||
|
params := url.Values{}
|
||||||
|
params.Add("client_id", p.ClientID)
|
||||||
|
params.Add("client_secret", p.ClientSecret)
|
||||||
|
params.Add("refresh_token", refreshToken)
|
||||||
|
var req *http.Request
|
||||||
|
req, err = p.newRequest("POST", p.RefreshURL.String(), bytes.NewBufferString(params.Encode()))
|
||||||
|
if err != nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
req.Header.Set("Content-Type", "application/x-www-form-urlencoded")
|
||||||
|
resp, err := defaultHTTPClient.Do(req)
|
||||||
|
if err != nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
var body []byte
|
||||||
|
body, err = ioutil.ReadAll(resp.Body)
|
||||||
|
resp.Body.Close()
|
||||||
|
if err != nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
if resp.StatusCode != http.StatusCreated {
|
||||||
|
if isProviderUnavailable(resp.StatusCode) {
|
||||||
|
err = ErrAuthProviderUnavailable
|
||||||
|
} else {
|
||||||
|
err = fmt.Errorf("got %d from %q %s", resp.StatusCode, p.RefreshURL.String(), body)
|
||||||
|
}
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
var data struct {
|
||||||
|
AccessToken string `json:"access_token"`
|
||||||
|
ExpiresIn int64 `json:"expires_in"`
|
||||||
|
}
|
||||||
|
err = json.Unmarshal(body, &data)
|
||||||
|
if err != nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
token = data.AccessToken
|
||||||
|
expires = time.Duration(data.ExpiresIn) * time.Second
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
// ValidateSessionState validates the current sessions state
|
||||||
|
func (p *AuthenticateClient) ValidateSessionState(s *sessions.SessionState) bool {
|
||||||
|
// we validate the user's access token is valid
|
||||||
|
params := url.Values{}
|
||||||
|
params.Add("client_id", p.ClientID)
|
||||||
|
req, err := p.newRequest("GET", fmt.Sprintf("%s?%s", p.ValidateURL.String(), params.Encode()), nil)
|
||||||
|
if err != nil {
|
||||||
|
log.Error().Err(err).Str("user", s.Email).Msg("proxy/authenticator.ValidateSessionState : error validating session state")
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
req.Header.Set("X-Client-Secret", p.ClientSecret)
|
||||||
|
req.Header.Set("X-Access-Token", s.AccessToken)
|
||||||
|
req.Header.Set("X-Id-Token", s.IDToken)
|
||||||
|
|
||||||
|
resp, err := defaultHTTPClient.Do(req)
|
||||||
|
if err != nil {
|
||||||
|
log.Error().Err(err).Str("user", s.Email).Msg("proxy/authenticator.ValidateSessionState : error making request to validate access token")
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
|
||||||
|
if resp.StatusCode != http.StatusOK {
|
||||||
|
// When we detect that the auth provider is not explicitly denying
|
||||||
|
// authentication, and is merely unavailable, we validate and continue
|
||||||
|
// as normal during the "grace period"
|
||||||
|
if isProviderUnavailable(resp.StatusCode) && p.withinGracePeriod(s) {
|
||||||
|
//tags := []string{"action:validate_session", "error:validation_failed"}
|
||||||
|
s.ValidDeadline = extendDeadline(p.SessionValidTTL)
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
log.Info().Str("user", s.Email).Int("status-code", resp.StatusCode).Msg("proxy/authenticator.ValidateSessionState : could not validate user access token")
|
||||||
|
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
|
||||||
|
s.ValidDeadline = extendDeadline(p.SessionValidTTL)
|
||||||
|
s.GracePeriodStart = time.Time{}
|
||||||
|
|
||||||
|
log.Info().Str("user", s.Email).Msg("proxy/authenticator.ValidateSessionState : validated session")
|
||||||
|
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
|
||||||
|
// signRedirectURL signs the redirect url string, given a timestamp, and returns it
|
||||||
|
func (p *AuthenticateClient) signRedirectURL(rawRedirect string, timestamp time.Time) string {
|
||||||
|
h := hmac.New(sha256.New, []byte(p.ClientSecret))
|
||||||
|
h.Write([]byte(rawRedirect))
|
||||||
|
h.Write([]byte(fmt.Sprint(timestamp.Unix())))
|
||||||
|
return base64.URLEncoding.EncodeToString(h.Sum(nil))
|
||||||
|
}
|
||||||
|
|
||||||
|
// GetSignInURL with typical oauth parameters
|
||||||
|
func (p *AuthenticateClient) GetSignInURL(redirectURL *url.URL, state string) *url.URL {
|
||||||
|
a := *p.SignInURL
|
||||||
|
now := time.Now()
|
||||||
|
rawRedirect := redirectURL.String()
|
||||||
|
params, _ := url.ParseQuery(a.RawQuery)
|
||||||
|
params.Set("redirect_uri", rawRedirect)
|
||||||
|
params.Set("client_id", p.ClientID)
|
||||||
|
params.Set("response_type", "code")
|
||||||
|
params.Add("state", state)
|
||||||
|
params.Set("ts", fmt.Sprint(now.Unix()))
|
||||||
|
params.Set("sig", p.signRedirectURL(rawRedirect, now))
|
||||||
|
a.RawQuery = params.Encode()
|
||||||
|
return &a
|
||||||
|
}
|
||||||
|
|
||||||
|
// GetSignOutURL creates and returns the sign out URL, given a redirectURL
|
||||||
|
func (p *AuthenticateClient) GetSignOutURL(redirectURL *url.URL) *url.URL {
|
||||||
|
a := *p.SignOutURL
|
||||||
|
now := time.Now()
|
||||||
|
rawRedirect := redirectURL.String()
|
||||||
|
params, _ := url.ParseQuery(a.RawQuery)
|
||||||
|
params.Add("redirect_uri", rawRedirect)
|
||||||
|
params.Set("ts", fmt.Sprint(now.Unix()))
|
||||||
|
params.Set("sig", p.signRedirectURL(rawRedirect, now))
|
||||||
|
a.RawQuery = params.Encode()
|
||||||
|
return &a
|
||||||
|
}
|
520
proxy/handlers.go
Normal file
520
proxy/handlers.go
Normal file
|
@ -0,0 +1,520 @@
|
||||||
|
package proxy // import "github.com/pomerium/pomerium/proxy"
|
||||||
|
|
||||||
|
import (
|
||||||
|
"encoding/json"
|
||||||
|
"errors"
|
||||||
|
"fmt"
|
||||||
|
"net/http"
|
||||||
|
"net/url"
|
||||||
|
"reflect"
|
||||||
|
"strings"
|
||||||
|
|
||||||
|
"github.com/pomerium/pomerium/internal/aead"
|
||||||
|
"github.com/pomerium/pomerium/internal/httputil"
|
||||||
|
"github.com/pomerium/pomerium/internal/log"
|
||||||
|
"github.com/pomerium/pomerium/internal/middleware"
|
||||||
|
"github.com/pomerium/pomerium/internal/sessions"
|
||||||
|
"github.com/pomerium/pomerium/internal/version"
|
||||||
|
)
|
||||||
|
|
||||||
|
const loggingUserHeader = "SSO-Authenticated-User"
|
||||||
|
|
||||||
|
var (
|
||||||
|
//ErrUserNotAuthorized is set when user is not authorized to access a resource
|
||||||
|
ErrUserNotAuthorized = errors.New("user not authorized")
|
||||||
|
)
|
||||||
|
|
||||||
|
var securityHeaders = map[string]string{
|
||||||
|
"X-Content-Type-Options": "nosniff",
|
||||||
|
"X-Frame-Options": "SAMEORIGIN",
|
||||||
|
"X-XSS-Protection": "1; mode=block",
|
||||||
|
}
|
||||||
|
|
||||||
|
// Handler returns a http handler for an Proxy
|
||||||
|
func (p *Proxy) Handler() http.Handler {
|
||||||
|
mux := http.NewServeMux()
|
||||||
|
mux.HandleFunc("/favicon.ico", p.Favicon)
|
||||||
|
mux.HandleFunc("/robots.txt", p.RobotsTxt)
|
||||||
|
mux.HandleFunc("/.pomerium/sign_out", p.SignOut)
|
||||||
|
mux.HandleFunc("/.pomerium/callback", p.OAuthCallback)
|
||||||
|
mux.HandleFunc("/.pomerium/auth", p.AuthenticateOnly)
|
||||||
|
mux.HandleFunc("/", p.Proxy)
|
||||||
|
|
||||||
|
// Global middleware, which will be applied to each request in reverse
|
||||||
|
// order as applied here (i.e., we want to validate the host _first_ when
|
||||||
|
// processing a request)
|
||||||
|
var handler http.Handler = mux
|
||||||
|
// todo(bdd) : investigate if setting non-overridable headers makes sense
|
||||||
|
// handler = p.setResponseHeaderOverrides(handler)
|
||||||
|
handler = middleware.SetHeaders(handler, securityHeaders)
|
||||||
|
handler = middleware.ValidateHost(handler, p.mux)
|
||||||
|
handler = middleware.RequireHTTPS(handler)
|
||||||
|
handler = log.NewLoggingHandler(handler)
|
||||||
|
|
||||||
|
return http.HandlerFunc(func(rw http.ResponseWriter, req *http.Request) {
|
||||||
|
// Skip host validation for /ping requests because they hit the LB directly.
|
||||||
|
if req.URL.Path == "/ping" {
|
||||||
|
p.PingPage(rw, req)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
handler.ServeHTTP(rw, req)
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
// RobotsTxt sets the User-Agent header in the response to be "Disallow"
|
||||||
|
func (p *Proxy) RobotsTxt(rw http.ResponseWriter, _ *http.Request) {
|
||||||
|
rw.WriteHeader(http.StatusOK)
|
||||||
|
fmt.Fprintf(rw, "User-agent: *\nDisallow: /")
|
||||||
|
}
|
||||||
|
|
||||||
|
// Favicon will proxy the request as usual if the user is already authenticated
|
||||||
|
// but responds with a 404 otherwise, to avoid spurious and confusing
|
||||||
|
// authentication attempts when a browser automatically requests the favicon on
|
||||||
|
// an error page.
|
||||||
|
func (p *Proxy) Favicon(rw http.ResponseWriter, req *http.Request) {
|
||||||
|
err := p.Authenticate(rw, req)
|
||||||
|
if err != nil {
|
||||||
|
rw.WriteHeader(http.StatusNotFound)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
p.Proxy(rw, req)
|
||||||
|
}
|
||||||
|
|
||||||
|
// PingPage send back a 200 OK response.
|
||||||
|
func (p *Proxy) PingPage(rw http.ResponseWriter, _ *http.Request) {
|
||||||
|
rw.WriteHeader(http.StatusOK)
|
||||||
|
fmt.Fprintf(rw, "OK")
|
||||||
|
}
|
||||||
|
|
||||||
|
// SignOut redirects the request to the sign out url.
|
||||||
|
func (p *Proxy) SignOut(rw http.ResponseWriter, req *http.Request) {
|
||||||
|
p.sessionStore.ClearSession(rw, req)
|
||||||
|
|
||||||
|
var scheme string
|
||||||
|
|
||||||
|
// Build redirect URI from request host
|
||||||
|
if req.URL.Scheme == "" {
|
||||||
|
scheme = "https"
|
||||||
|
}
|
||||||
|
|
||||||
|
redirectURL := &url.URL{
|
||||||
|
Scheme: scheme,
|
||||||
|
Host: req.Host,
|
||||||
|
Path: "/",
|
||||||
|
}
|
||||||
|
fullURL := p.authenticateClient.GetSignOutURL(redirectURL)
|
||||||
|
http.Redirect(rw, req, fullURL.String(), http.StatusFound)
|
||||||
|
}
|
||||||
|
|
||||||
|
// XHRError returns a simple error response with an error message to the application if the request is an XML request
|
||||||
|
func (p *Proxy) XHRError(rw http.ResponseWriter, req *http.Request, code int, err error) {
|
||||||
|
jsonError := struct {
|
||||||
|
Error error `json:"error"`
|
||||||
|
}{
|
||||||
|
Error: err,
|
||||||
|
}
|
||||||
|
|
||||||
|
jsonBytes, err := json.Marshal(jsonError)
|
||||||
|
if err != nil {
|
||||||
|
rw.WriteHeader(http.StatusInternalServerError)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
requestLog := log.WithRequest(req, "proxy.ErrorPage")
|
||||||
|
requestLog.Error().Err(err).Int("http-status", code).Msg("proxy.XHRError")
|
||||||
|
rw.Header().Set("Content-Type", "application/json")
|
||||||
|
rw.WriteHeader(code)
|
||||||
|
rw.Write(jsonBytes)
|
||||||
|
}
|
||||||
|
|
||||||
|
// ErrorPage renders an error page with a given status code, title, and message.
|
||||||
|
func (p *Proxy) ErrorPage(rw http.ResponseWriter, req *http.Request, code int, title string, message string) {
|
||||||
|
if p.isXHR(req) {
|
||||||
|
p.XHRError(rw, req, code, errors.New(message))
|
||||||
|
return
|
||||||
|
}
|
||||||
|
requestLog := log.WithRequest(req, "proxy.ErrorPage")
|
||||||
|
requestLog.Info().
|
||||||
|
Str("page-title", title).
|
||||||
|
Str("page-message", message).
|
||||||
|
Msg("proxy.ErrorPage")
|
||||||
|
|
||||||
|
rw.WriteHeader(code)
|
||||||
|
t := struct {
|
||||||
|
Code int
|
||||||
|
Title string
|
||||||
|
Message string
|
||||||
|
Version string
|
||||||
|
}{
|
||||||
|
Code: code,
|
||||||
|
Title: title,
|
||||||
|
Message: message,
|
||||||
|
Version: version.FullVersion(),
|
||||||
|
}
|
||||||
|
p.templates.ExecuteTemplate(rw, "error.html", t)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (p *Proxy) isXHR(req *http.Request) bool {
|
||||||
|
return req.Header.Get("X-Requested-With") == "XMLHttpRequest"
|
||||||
|
}
|
||||||
|
|
||||||
|
// OAuthStart begins the authentication flow, encrypting the redirect url
|
||||||
|
// in a request to the provider's sign in endpoint.
|
||||||
|
func (p *Proxy) OAuthStart(rw http.ResponseWriter, req *http.Request) {
|
||||||
|
// The proxy redirects to the authenticator, and provides it with redirectURI (which points
|
||||||
|
// back to the sso proxy).
|
||||||
|
requestLog := log.WithRequest(req, "proxy.OAuthStart")
|
||||||
|
|
||||||
|
if p.isXHR(req) {
|
||||||
|
e := errors.New("cannot continue oauth flow on xhr")
|
||||||
|
requestLog.Error().Err(e).Msg("isXHR")
|
||||||
|
p.XHRError(rw, req, http.StatusUnauthorized, e)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
requestURI := req.URL.String()
|
||||||
|
callbackURL := p.GetRedirectURL(req.Host)
|
||||||
|
|
||||||
|
// generate nonce
|
||||||
|
key := aead.GenerateKey()
|
||||||
|
|
||||||
|
// state prevents cross site forgery and maintain state across the client and server
|
||||||
|
state := &StateParameter{
|
||||||
|
SessionID: fmt.Sprintf("%x", key), // nonce
|
||||||
|
RedirectURI: requestURI, // where to redirect the user back to
|
||||||
|
}
|
||||||
|
|
||||||
|
// we encrypt this value to be opaque the browser cookie
|
||||||
|
// this value will be unique since we always use a randomized nonce as part of marshaling
|
||||||
|
encryptedCSRF, err := p.CookieCipher.Marshal(state)
|
||||||
|
if err != nil {
|
||||||
|
requestLog.Error().Err(err).Msg("failed to marshal csrf")
|
||||||
|
p.ErrorPage(rw, req, http.StatusInternalServerError, "Internal Error", err.Error())
|
||||||
|
return
|
||||||
|
}
|
||||||
|
p.csrfStore.SetCSRF(rw, req, encryptedCSRF)
|
||||||
|
|
||||||
|
// we encrypt this value to be opaque the uri query value
|
||||||
|
// this value will be unique since we always use a randomized nonce as part of marshaling
|
||||||
|
encryptedState, err := p.CookieCipher.Marshal(state)
|
||||||
|
if err != nil {
|
||||||
|
requestLog.Error().Err(err).Msg("failed to encrypt cookie")
|
||||||
|
p.ErrorPage(rw, req, http.StatusInternalServerError, "Internal Error", err.Error())
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
signinURL := p.authenticateClient.GetSignInURL(callbackURL, encryptedState)
|
||||||
|
requestLog.Info().Msg("redirecting to begin auth flow")
|
||||||
|
http.Redirect(rw, req, signinURL.String(), http.StatusFound)
|
||||||
|
}
|
||||||
|
|
||||||
|
// OAuthCallback validates the cookie sent back from the provider, then validates
|
||||||
|
// the user information, and if authorized, redirects the user back to the original
|
||||||
|
// application.
|
||||||
|
func (p *Proxy) OAuthCallback(rw http.ResponseWriter, req *http.Request) {
|
||||||
|
// We receive the callback from the SSO Authenticator. This request will either contain an
|
||||||
|
// error, or it will contain a `code`; the code can be used to fetch an access token, and
|
||||||
|
// other metadata, from the authenticator.
|
||||||
|
requestLog := log.WithRequest(req, "proxy.OAuthCallback")
|
||||||
|
// finish the oauth cycle
|
||||||
|
err := req.ParseForm()
|
||||||
|
if err != nil {
|
||||||
|
requestLog.Error().Err(err).Msg("failed parsing request form")
|
||||||
|
p.ErrorPage(rw, req, http.StatusInternalServerError, "Internal Error", err.Error())
|
||||||
|
return
|
||||||
|
}
|
||||||
|
errorString := req.Form.Get("error")
|
||||||
|
if errorString != "" {
|
||||||
|
p.ErrorPage(rw, req, http.StatusForbidden, "Permission Denied", errorString)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
// We begin the process of redeeming the code for an access token.
|
||||||
|
session, err := p.redeemCode(req.Host, req.Form.Get("code"))
|
||||||
|
if err != nil {
|
||||||
|
requestLog.Error().Err(err).Msg("error redeeming authorization code")
|
||||||
|
p.ErrorPage(rw, req, http.StatusInternalServerError, "Internal Error", "Internal Error")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
encryptedState := req.Form.Get("state")
|
||||||
|
stateParameter := &StateParameter{}
|
||||||
|
err = p.CookieCipher.Unmarshal(encryptedState, stateParameter)
|
||||||
|
if err != nil {
|
||||||
|
requestLog.Error().Err(err).Msg("could not unmarshal state")
|
||||||
|
p.ErrorPage(rw, req, http.StatusInternalServerError, "Internal Error", "Internal Error")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
c, err := req.Cookie(p.CSRFCookieName)
|
||||||
|
if err != nil {
|
||||||
|
requestLog.Error().Err(err).Msg("failed parsing csrf cookie")
|
||||||
|
p.ErrorPage(rw, req, http.StatusBadRequest, "Bad Request", err.Error())
|
||||||
|
return
|
||||||
|
}
|
||||||
|
p.csrfStore.ClearCSRF(rw, req)
|
||||||
|
|
||||||
|
encryptedCSRF := c.Value
|
||||||
|
csrfParameter := &StateParameter{}
|
||||||
|
err = p.CookieCipher.Unmarshal(encryptedCSRF, csrfParameter)
|
||||||
|
if err != nil {
|
||||||
|
requestLog.Error().Err(err).Msg("couldn't unmarshal CSRF")
|
||||||
|
p.ErrorPage(rw, req, http.StatusInternalServerError, "Internal Error", "Internal Error")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
if encryptedState == encryptedCSRF {
|
||||||
|
requestLog.Error().Msg("encrypted state and CSRF should not be equal")
|
||||||
|
p.ErrorPage(rw, req, http.StatusBadRequest, "Bad Request", "Bad Request")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
if !reflect.DeepEqual(stateParameter, csrfParameter) {
|
||||||
|
requestLog.Error().Msg("state and CSRF should be equal")
|
||||||
|
p.ErrorPage(rw, req, http.StatusBadRequest, "Bad Request", "Bad Request")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
// We validate the user information, and check that this user has proper authorization
|
||||||
|
// for the resources requested. This can be set via the email address or any groups.
|
||||||
|
//
|
||||||
|
// set cookie, or deny
|
||||||
|
if !p.EmailValidator(session.Email) {
|
||||||
|
requestLog.Error().Str("user", session.Email).Msg("permission denied: unauthorized")
|
||||||
|
p.ErrorPage(rw, req, http.StatusForbidden, "Permission Denied", "Invalid Account")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
// We store the session in a cookie and redirect the user back to the application
|
||||||
|
err = p.sessionStore.SaveSession(rw, req, session)
|
||||||
|
if err != nil {
|
||||||
|
requestLog.Error().Msg("error saving session")
|
||||||
|
p.ErrorPage(rw, req, http.StatusInternalServerError, "Internal Error", "Internal Error")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
// This is the redirect back to the original requested application
|
||||||
|
http.Redirect(rw, req, stateParameter.RedirectURI, http.StatusFound)
|
||||||
|
}
|
||||||
|
|
||||||
|
// AuthenticateOnly calls the Authenticate handler.
|
||||||
|
func (p *Proxy) AuthenticateOnly(rw http.ResponseWriter, req *http.Request) {
|
||||||
|
err := p.Authenticate(rw, req)
|
||||||
|
if err != nil {
|
||||||
|
http.Error(rw, "unauthorized request", http.StatusUnauthorized)
|
||||||
|
}
|
||||||
|
rw.WriteHeader(http.StatusAccepted)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Proxy authenticates a request, either proxying the request if it is authenticated, or starting the authentication process if not.
|
||||||
|
func (p *Proxy) Proxy(rw http.ResponseWriter, req *http.Request) {
|
||||||
|
// Attempts to validate the user and their cookie.
|
||||||
|
// start := time.Now()
|
||||||
|
var err error
|
||||||
|
err = p.Authenticate(rw, req)
|
||||||
|
// If the authentication is not successful we proceed to start the OAuth Flow with
|
||||||
|
// OAuthStart. If authentication is successful, we proceed to proxy to the configured
|
||||||
|
// upstream.
|
||||||
|
requestLog := log.WithRequest(req, "proxy.Proxy")
|
||||||
|
if err != nil {
|
||||||
|
switch err {
|
||||||
|
case http.ErrNoCookie:
|
||||||
|
// No cookie is set, start the oauth flow
|
||||||
|
p.OAuthStart(rw, req)
|
||||||
|
return
|
||||||
|
case ErrUserNotAuthorized:
|
||||||
|
// We know the user is not authorized for the request, we show them a forbidden page
|
||||||
|
p.ErrorPage(rw, req, http.StatusForbidden, "Forbidden", "You're not authorized to view this page")
|
||||||
|
return
|
||||||
|
case sessions.ErrLifetimeExpired:
|
||||||
|
// User's lifetime expired, we trigger the start of the oauth flow
|
||||||
|
p.OAuthStart(rw, req)
|
||||||
|
return
|
||||||
|
case sessions.ErrInvalidSession:
|
||||||
|
// The user session is invalid and we can't decode it.
|
||||||
|
// This can happen for a variety of reasons but the most common non-malicious
|
||||||
|
// case occurs when the session encoding schema changes. We manage this ux
|
||||||
|
// by triggering the start of the oauth flow.
|
||||||
|
p.OAuthStart(rw, req)
|
||||||
|
return
|
||||||
|
default:
|
||||||
|
requestLog.Error().Err(err).Msg("unknown error")
|
||||||
|
// We don't know exactly what happened, but authenticating the user failed, show an error
|
||||||
|
p.ErrorPage(rw, req, http.StatusInternalServerError, "Internal Error", "An unexpected error occurred")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// We have validated the users request and now proxy their request to the provided upstream.
|
||||||
|
route, ok := p.router(req)
|
||||||
|
if !ok {
|
||||||
|
httputil.ErrorResponse(rw, req, "Unknown host to route", http.StatusNotFound)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
// overhead := time.Now().Sub(start)
|
||||||
|
route.ServeHTTP(rw, req)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Authenticate authenticates a request by checking for a session cookie, and validating its expiration,
|
||||||
|
// clearing the session cookie if it's invalid and returning an error if necessary..
|
||||||
|
func (p *Proxy) Authenticate(rw http.ResponseWriter, req *http.Request) (err error) {
|
||||||
|
|
||||||
|
// Clear the session cookie if anything goes wrong.
|
||||||
|
defer func() {
|
||||||
|
if err != nil {
|
||||||
|
p.sessionStore.ClearSession(rw, req)
|
||||||
|
}
|
||||||
|
}()
|
||||||
|
requestLog := log.WithRequest(req, "proxy.Authenticate")
|
||||||
|
|
||||||
|
session, err := p.sessionStore.LoadSession(req)
|
||||||
|
if err != nil {
|
||||||
|
// We loaded a cookie but it wasn't valid, clear it, and reject the request
|
||||||
|
requestLog.Error().Err(err).Msg("error authenticating user")
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
// Lifetime period is the entire duration in which the session is valid.
|
||||||
|
// This should be set to something like 14 to 30 days.
|
||||||
|
if session.LifetimePeriodExpired() {
|
||||||
|
requestLog.Warn().Str("user", session.Email).Msg("session lifetime has expired")
|
||||||
|
return sessions.ErrLifetimeExpired
|
||||||
|
} else if session.RefreshPeriodExpired() {
|
||||||
|
// Refresh period is the period in which the access token is valid. This is ultimately
|
||||||
|
// controlled by the upstream provider and tends to be around 1 hour.
|
||||||
|
ok, err := p.authenticateClient.RefreshSession(session)
|
||||||
|
|
||||||
|
// We failed to refresh the session successfully
|
||||||
|
// clear the cookie and reject the request
|
||||||
|
if err != nil {
|
||||||
|
requestLog.Error().Err(err).Str("user", session.Email).Msg("refreshing session failed")
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
if !ok {
|
||||||
|
// User is not authorized after refresh
|
||||||
|
// clear the cookie and reject the request
|
||||||
|
requestLog.Error().Str("user", session.Email).Msg("not authorized after refreshing session")
|
||||||
|
return ErrUserNotAuthorized
|
||||||
|
}
|
||||||
|
|
||||||
|
err = p.sessionStore.SaveSession(rw, req, session)
|
||||||
|
if err != nil {
|
||||||
|
// We refreshed the session successfully, but failed to save it.
|
||||||
|
//
|
||||||
|
// This could be from failing to encode the session properly.
|
||||||
|
// But, we clear the session cookie and reject the request!
|
||||||
|
requestLog.Error().Err(err).Str("user", session.Email).Msg("could not save refresh session")
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
} else if session.ValidationPeriodExpired() {
|
||||||
|
// Validation period has expired, this is the shortest interval we use to
|
||||||
|
// check for valid requests. This should be set to something like a minute.
|
||||||
|
// This calls up the provider chain to validate this user is still active
|
||||||
|
// and hasn't been de-authorized.
|
||||||
|
ok := p.authenticateClient.ValidateSessionState(session)
|
||||||
|
if !ok {
|
||||||
|
// This user is now no longer authorized, or we failed to
|
||||||
|
// validate the user.
|
||||||
|
// Clear the cookie and reject the request
|
||||||
|
requestLog.Error().Str("user", session.Email).Msg("no longer authorized after validation period")
|
||||||
|
return ErrUserNotAuthorized
|
||||||
|
}
|
||||||
|
|
||||||
|
err = p.sessionStore.SaveSession(rw, req, session)
|
||||||
|
if err != nil {
|
||||||
|
// We validated the session successfully, but failed to save it.
|
||||||
|
|
||||||
|
// This could be from failing to encode the session properly.
|
||||||
|
// But, we clear the session cookie and reject the request!
|
||||||
|
requestLog.Error().Err(err).Str("user", session.Email).Msg("could not save validated session")
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if !p.EmailValidator(session.Email) {
|
||||||
|
requestLog.Error().Str("user", session.Email).Msg("email failed to validate, unauthorized")
|
||||||
|
return ErrUserNotAuthorized
|
||||||
|
}
|
||||||
|
|
||||||
|
req.Header.Set("X-Forwarded-User", session.User)
|
||||||
|
|
||||||
|
if p.PassAccessToken && session.AccessToken != "" {
|
||||||
|
req.Header.Set("X-Forwarded-Access-Token", session.AccessToken)
|
||||||
|
}
|
||||||
|
|
||||||
|
req.Header.Set("X-Forwarded-Email", session.Email)
|
||||||
|
req.Header.Set("X-Forwarded-Groups", strings.Join(session.Groups, ","))
|
||||||
|
|
||||||
|
// stash authenticated user so that it can be logged later (see func logRequest)
|
||||||
|
rw.Header().Set(loggingUserHeader, session.Email)
|
||||||
|
|
||||||
|
// This user has been OK'd. Allow the request!
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// upstreamTransport is used to ensure that upstreams cannot override the
|
||||||
|
// security headers applied by sso_proxy
|
||||||
|
type upstreamTransport struct {
|
||||||
|
transport *http.Transport
|
||||||
|
}
|
||||||
|
|
||||||
|
// RoundTrip round trips the request and deletes security headers before returning the response.
|
||||||
|
func (t *upstreamTransport) RoundTrip(req *http.Request) (*http.Response, error) {
|
||||||
|
resp, err := t.transport.RoundTrip(req)
|
||||||
|
if err != nil {
|
||||||
|
log.Error().Err(err).Msg("proxy.RoundTrip")
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
for key := range securityHeaders {
|
||||||
|
resp.Header.Del(key)
|
||||||
|
}
|
||||||
|
return resp, err
|
||||||
|
}
|
||||||
|
|
||||||
|
// Handle constructs a route from the given host string and matches it to the provided http.Handler and UpstreamConfig
|
||||||
|
func (p *Proxy) Handle(host string, handler http.Handler) {
|
||||||
|
p.mux[host] = &handler
|
||||||
|
}
|
||||||
|
|
||||||
|
// router attempts to find a route for a request. If a route is successfully matched,
|
||||||
|
// it returns the route information and a bool value of `true`. If a route can not be matched,
|
||||||
|
//a nil value for the route and false bool value is returned.
|
||||||
|
func (p *Proxy) router(req *http.Request) (http.Handler, bool) {
|
||||||
|
route, ok := p.mux[req.Host]
|
||||||
|
if ok {
|
||||||
|
return *route, true
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil, false
|
||||||
|
}
|
||||||
|
|
||||||
|
// GetRedirectURL returns the redirect url for a given Proxy,
|
||||||
|
// setting the scheme to be https if CookieSecure is true.
|
||||||
|
func (p *Proxy) GetRedirectURL(host string) *url.URL {
|
||||||
|
// TODO: Ensure that we only allow valid upstream hosts in redirect URIs
|
||||||
|
u := p.redirectURL
|
||||||
|
// Build redirect URI from request host
|
||||||
|
if u.Scheme == "" {
|
||||||
|
u.Scheme = "https"
|
||||||
|
}
|
||||||
|
u.Host = host
|
||||||
|
return u
|
||||||
|
}
|
||||||
|
|
||||||
|
func (p *Proxy) redeemCode(host, code string) (*sessions.SessionState, error) {
|
||||||
|
if code == "" {
|
||||||
|
return nil, errors.New("missing code")
|
||||||
|
}
|
||||||
|
redirectURL := p.GetRedirectURL(host)
|
||||||
|
s, err := p.authenticateClient.Redeem(redirectURL.String(), code)
|
||||||
|
if err != nil {
|
||||||
|
return s, err
|
||||||
|
}
|
||||||
|
|
||||||
|
if s.Email == "" {
|
||||||
|
return s, errors.New("invalid email address")
|
||||||
|
}
|
||||||
|
|
||||||
|
return s, nil
|
||||||
|
}
|
331
proxy/proxy.go
Executable file
331
proxy/proxy.go
Executable file
|
@ -0,0 +1,331 @@
|
||||||
|
package proxy // import "github.com/pomerium/pomerium/proxy"
|
||||||
|
|
||||||
|
import (
|
||||||
|
"encoding/base64"
|
||||||
|
"errors"
|
||||||
|
"fmt"
|
||||||
|
"html/template"
|
||||||
|
"net"
|
||||||
|
"net/http"
|
||||||
|
"net/http/httputil"
|
||||||
|
"net/url"
|
||||||
|
"strings"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"github.com/pomerium/envconfig"
|
||||||
|
"github.com/pomerium/pomerium/internal/aead"
|
||||||
|
"github.com/pomerium/pomerium/internal/log"
|
||||||
|
"github.com/pomerium/pomerium/internal/sessions"
|
||||||
|
"github.com/pomerium/pomerium/internal/templates"
|
||||||
|
"github.com/pomerium/pomerium/proxy/authenticator"
|
||||||
|
)
|
||||||
|
|
||||||
|
// Options represents the configuration options for the proxy service.
|
||||||
|
type Options struct {
|
||||||
|
// AuthenticateServiceURL specifies the url to the pomerium authenticate http service.
|
||||||
|
AuthenticateServiceURL *url.URL `envconfig:"PROVIDER_URL"`
|
||||||
|
|
||||||
|
// EmailDomains is a string slice of valid domains to proxy
|
||||||
|
EmailDomains []string `envconfig:"EMAIL_DOMAIN"`
|
||||||
|
// todo(bdd): ClientID and ClientSecret are used are a hacky pre shared key
|
||||||
|
// prefer certificates and mutual tls
|
||||||
|
ClientID string `envconfig:"PROXY_CLIENT_ID"`
|
||||||
|
ClientSecret string `envconfig:"PROXY_CLIENT_SECRET"`
|
||||||
|
|
||||||
|
DefaultUpstreamTimeout time.Duration `envconfig:"DEFAULT_UPSTREAM_TIMEOUT"`
|
||||||
|
|
||||||
|
CookieName string `envconfig:"COOKIE_NAME"`
|
||||||
|
CookieSecret string `envconfig:"COOKIE_SECRET"`
|
||||||
|
CookieDomain string `envconfig:"COOKIE_DOMAIN"`
|
||||||
|
CookieExpire time.Duration `envconfig:"COOKIE_EXPIRE"`
|
||||||
|
CookieSecure bool `envconfig:"COOKIE_SECURE" `
|
||||||
|
CookieHTTPOnly bool `envconfig:"COOKIE_HTTP_ONLY"`
|
||||||
|
|
||||||
|
PassAccessToken bool `envconfig:"PASS_ACCESS_TOKEN"`
|
||||||
|
|
||||||
|
// session details
|
||||||
|
SessionValidTTL time.Duration `envconfig:"SESSION_VALID_TTL"`
|
||||||
|
SessionLifetimeTTL time.Duration `envconfig:"SESSION_LIFETIME_TTL"`
|
||||||
|
GracePeriodTTL time.Duration `envconfig:"GRACE_PERIOD_TTL"`
|
||||||
|
|
||||||
|
Routes map[string]string `envconfig:"ROUTES"`
|
||||||
|
}
|
||||||
|
|
||||||
|
// NewOptions returns a new options struct
|
||||||
|
var defaultOptions = &Options{
|
||||||
|
CookieName: "_pomerium_proxy",
|
||||||
|
CookieSecure: true,
|
||||||
|
CookieHTTPOnly: true,
|
||||||
|
CookieExpire: time.Duration(168) * time.Hour,
|
||||||
|
DefaultUpstreamTimeout: time.Duration(10) * time.Second,
|
||||||
|
SessionLifetimeTTL: time.Duration(720) * time.Hour,
|
||||||
|
SessionValidTTL: time.Duration(1) * time.Minute,
|
||||||
|
GracePeriodTTL: time.Duration(3) * time.Hour,
|
||||||
|
PassAccessToken: false,
|
||||||
|
}
|
||||||
|
|
||||||
|
// OptionsFromEnvConfig builds the authentication service's configuration
|
||||||
|
// options from provided environmental variables
|
||||||
|
func OptionsFromEnvConfig() (*Options, error) {
|
||||||
|
o := defaultOptions
|
||||||
|
if err := envconfig.Process("", o); err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
return o, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// Validate checks that proper configuration settings are set to create
|
||||||
|
// a proper Proxy instance
|
||||||
|
func (o *Options) Validate() error {
|
||||||
|
if len(o.Routes) == 0 {
|
||||||
|
return errors.New("missing setting: routes")
|
||||||
|
}
|
||||||
|
for to, from := range o.Routes {
|
||||||
|
if _, err := urlParse(to); err != nil {
|
||||||
|
return fmt.Errorf("could not parse origin %s as url : %q", to, err)
|
||||||
|
}
|
||||||
|
if _, err := urlParse(from); err != nil {
|
||||||
|
return fmt.Errorf("could not parse destination %s as url : %q", to, err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if o.AuthenticateServiceURL == nil {
|
||||||
|
return errors.New("missing setting: provider-url")
|
||||||
|
}
|
||||||
|
if o.CookieSecret == "" {
|
||||||
|
return errors.New("missing setting: cookie-secret")
|
||||||
|
}
|
||||||
|
if o.ClientID == "" {
|
||||||
|
return errors.New("missing setting: client-id")
|
||||||
|
}
|
||||||
|
if o.ClientSecret == "" {
|
||||||
|
return errors.New("missing setting: client-secret")
|
||||||
|
}
|
||||||
|
if len(o.EmailDomains) == 0 {
|
||||||
|
return errors.New("missing setting: email-domain")
|
||||||
|
}
|
||||||
|
|
||||||
|
decodedCookieSecret, err := base64.StdEncoding.DecodeString(o.CookieSecret)
|
||||||
|
if err != nil {
|
||||||
|
return errors.New("cookie secret is invalid (e.g. `head -c33 /dev/urandom | base64`) ")
|
||||||
|
}
|
||||||
|
validCookieSecretLength := false
|
||||||
|
for _, i := range []int{32, 64} {
|
||||||
|
if len(decodedCookieSecret) == i {
|
||||||
|
validCookieSecretLength = true
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if !validCookieSecretLength {
|
||||||
|
return fmt.Errorf("cookie secret is invalid, must be 32 or 64 bytes but got %d bytes (e.g. `head -c33 /dev/urandom | base64`) ", len(decodedCookieSecret))
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// Proxy stores all the information associated with proxying a request.
|
||||||
|
type Proxy struct {
|
||||||
|
CookieCipher aead.Cipher
|
||||||
|
CookieDomain string
|
||||||
|
CookieExpire time.Duration
|
||||||
|
CookieHTTPOnly bool
|
||||||
|
CookieName string
|
||||||
|
CookieSecure bool
|
||||||
|
CookieSeed string
|
||||||
|
CSRFCookieName string
|
||||||
|
EmailValidator func(string) bool
|
||||||
|
|
||||||
|
PassAccessToken bool
|
||||||
|
|
||||||
|
// services
|
||||||
|
authenticateClient *authenticator.AuthenticateClient
|
||||||
|
// session
|
||||||
|
csrfStore sessions.CSRFStore
|
||||||
|
sessionStore sessions.SessionStore
|
||||||
|
cipher aead.Cipher
|
||||||
|
|
||||||
|
redirectURL *url.URL // the url to receive requests at
|
||||||
|
templates *template.Template
|
||||||
|
mux map[string]*http.Handler
|
||||||
|
}
|
||||||
|
|
||||||
|
// StateParameter holds the redirect id along with the session id.
|
||||||
|
type StateParameter struct {
|
||||||
|
SessionID string `json:"session_id"`
|
||||||
|
RedirectURI string `json:"redirect_uri"`
|
||||||
|
}
|
||||||
|
|
||||||
|
// NewProxy takes a Proxy service from options and a validation function.
|
||||||
|
// Function returns an error if options fail to validate.
|
||||||
|
func NewProxy(opts *Options, optFuncs ...func(*Proxy) error) (*Proxy, error) {
|
||||||
|
if opts == nil {
|
||||||
|
return nil, errors.New("options cannot be nil")
|
||||||
|
}
|
||||||
|
if err := opts.Validate(); err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
// error explicitly handled by validate
|
||||||
|
decodedSecret, _ := base64.StdEncoding.DecodeString(opts.CookieSecret)
|
||||||
|
cipher, err := aead.NewMiscreantCipher(decodedSecret)
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("cookie-secret error: %s", err.Error())
|
||||||
|
}
|
||||||
|
|
||||||
|
cookieStore, err := sessions.NewCookieStore(opts.CookieName,
|
||||||
|
sessions.CreateMiscreantCookieCipher(decodedSecret),
|
||||||
|
func(c *sessions.CookieStore) error {
|
||||||
|
c.CookieDomain = opts.CookieDomain
|
||||||
|
c.CookieHTTPOnly = opts.CookieHTTPOnly
|
||||||
|
c.CookieExpire = opts.CookieExpire
|
||||||
|
c.CookieSecure = opts.CookieSecure
|
||||||
|
return nil
|
||||||
|
})
|
||||||
|
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
authClient := authenticator.NewAuthenticateClient(
|
||||||
|
opts.AuthenticateServiceURL,
|
||||||
|
// todo(bdd): fields below can be dropped as Client data provides redudent auth
|
||||||
|
opts.ClientID,
|
||||||
|
opts.ClientSecret,
|
||||||
|
// todo(bdd): fields below should be passed as function args
|
||||||
|
opts.SessionLifetimeTTL,
|
||||||
|
opts.SessionValidTTL,
|
||||||
|
opts.GracePeriodTTL,
|
||||||
|
)
|
||||||
|
|
||||||
|
p := &Proxy{
|
||||||
|
CookieCipher: cipher,
|
||||||
|
CookieDomain: opts.CookieDomain,
|
||||||
|
CookieExpire: opts.CookieExpire,
|
||||||
|
CookieHTTPOnly: opts.CookieHTTPOnly,
|
||||||
|
CookieName: opts.CookieName,
|
||||||
|
CookieSecure: opts.CookieSecure,
|
||||||
|
CookieSeed: string(decodedSecret),
|
||||||
|
CSRFCookieName: fmt.Sprintf("%v_%v", opts.CookieName, "csrf"),
|
||||||
|
|
||||||
|
// these fields make up the routing mechanism
|
||||||
|
mux: make(map[string]*http.Handler),
|
||||||
|
// session state
|
||||||
|
csrfStore: cookieStore,
|
||||||
|
sessionStore: cookieStore,
|
||||||
|
cipher: cipher,
|
||||||
|
|
||||||
|
authenticateClient: authClient,
|
||||||
|
redirectURL: &url.URL{Path: "/.pomerium/callback"},
|
||||||
|
templates: templates.New(),
|
||||||
|
PassAccessToken: opts.PassAccessToken,
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, optFunc := range optFuncs {
|
||||||
|
err := optFunc(p)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
for from, to := range opts.Routes {
|
||||||
|
fromURL, _ := urlParse(from)
|
||||||
|
toURL, _ := urlParse(to)
|
||||||
|
reverseProxy := NewReverseProxy(toURL)
|
||||||
|
handler := NewReverseProxyHandler(opts, reverseProxy, toURL.String())
|
||||||
|
p.Handle(fromURL.Host, handler)
|
||||||
|
log.Info().Str("from", fromURL.Host).Str("to", toURL.String()).Msg("proxy.NewProxy : route")
|
||||||
|
}
|
||||||
|
|
||||||
|
log.Info().
|
||||||
|
Str("CookieName", p.CookieName).
|
||||||
|
Str("redirectURL", p.redirectURL.String()).
|
||||||
|
Str("CSRFCookieName", p.CSRFCookieName).
|
||||||
|
Bool("CookieSecure", p.CookieSecure).
|
||||||
|
Str("CookieDomain", p.CookieDomain).
|
||||||
|
Bool("CookieHTTPOnly", p.CookieHTTPOnly).
|
||||||
|
Dur("CookieExpire", opts.CookieExpire).
|
||||||
|
Msg("proxy.NewProxy")
|
||||||
|
|
||||||
|
return p, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// UpstreamProxy stores information necessary for proxying the request back to the upstream.
|
||||||
|
type UpstreamProxy struct {
|
||||||
|
name string
|
||||||
|
cookieName string
|
||||||
|
handler http.Handler
|
||||||
|
}
|
||||||
|
|
||||||
|
var defaultUpstreamTransport = &http.Transport{
|
||||||
|
Proxy: http.ProxyFromEnvironment,
|
||||||
|
DialContext: (&net.Dialer{
|
||||||
|
Timeout: 30 * time.Second,
|
||||||
|
KeepAlive: 30 * time.Second,
|
||||||
|
DualStack: true,
|
||||||
|
}).DialContext,
|
||||||
|
MaxIdleConns: 100,
|
||||||
|
IdleConnTimeout: 90 * time.Second,
|
||||||
|
TLSHandshakeTimeout: 10 * time.Second,
|
||||||
|
ExpectContinueTimeout: 1 * time.Second,
|
||||||
|
}
|
||||||
|
|
||||||
|
// deleteSSOCookieHeader deletes the session cookie from the request header string.
|
||||||
|
func deleteSSOCookieHeader(req *http.Request, cookieName string) {
|
||||||
|
headers := []string{}
|
||||||
|
for _, cookie := range req.Cookies() {
|
||||||
|
if cookie.Name != cookieName {
|
||||||
|
headers = append(headers, cookie.String())
|
||||||
|
}
|
||||||
|
}
|
||||||
|
req.Header.Set("Cookie", strings.Join(headers, ";"))
|
||||||
|
}
|
||||||
|
|
||||||
|
// ServeHTTP signs the http request and deletes cookie headers
|
||||||
|
// before calling the upstream's ServeHTTP function.
|
||||||
|
func (u *UpstreamProxy) ServeHTTP(w http.ResponseWriter, r *http.Request) {
|
||||||
|
requestLog := log.WithRequest(r, "proxy.ServeHTTP")
|
||||||
|
deleteSSOCookieHeader(r, u.cookieName)
|
||||||
|
start := time.Now()
|
||||||
|
u.handler.ServeHTTP(w, r)
|
||||||
|
duration := time.Since(start)
|
||||||
|
|
||||||
|
requestLog.Debug().Dur("duration", duration).Msg("proxy-request")
|
||||||
|
}
|
||||||
|
|
||||||
|
// NewReverseProxy creates a reverse proxy to a specified url.
|
||||||
|
// It adds an X-Forwarded-Host header that is the request's host.
|
||||||
|
//
|
||||||
|
// todo(bdd): when would we ever want to preserve host?
|
||||||
|
func NewReverseProxy(to *url.URL) *httputil.ReverseProxy {
|
||||||
|
proxy := httputil.NewSingleHostReverseProxy(to)
|
||||||
|
proxy.Transport = defaultUpstreamTransport
|
||||||
|
director := proxy.Director
|
||||||
|
proxy.Director = func(req *http.Request) {
|
||||||
|
req.Header.Add("X-Forwarded-Host", req.Host)
|
||||||
|
director(req)
|
||||||
|
req.Host = to.Host
|
||||||
|
}
|
||||||
|
return proxy
|
||||||
|
}
|
||||||
|
|
||||||
|
// NewReverseProxyHandler applies handler specific options to a given
|
||||||
|
// route.
|
||||||
|
func NewReverseProxyHandler(opts *Options, reverseProxy *httputil.ReverseProxy, serviceName string) http.Handler {
|
||||||
|
upstreamProxy := &UpstreamProxy{
|
||||||
|
name: serviceName,
|
||||||
|
handler: reverseProxy,
|
||||||
|
cookieName: opts.CookieName,
|
||||||
|
}
|
||||||
|
|
||||||
|
timeout := opts.DefaultUpstreamTimeout
|
||||||
|
timeoutMsg := fmt.Sprintf("%s failed to respond within the %s timeout period", serviceName, timeout)
|
||||||
|
return http.TimeoutHandler(upstreamProxy, timeout, timeoutMsg)
|
||||||
|
}
|
||||||
|
|
||||||
|
// urlParse adds a scheme if none-exists, addressesing a quirk in how
|
||||||
|
// one may expect url.Parse to function when a "naked" domain is sent.
|
||||||
|
//
|
||||||
|
// see: https://github.com/golang/go/issues/12585
|
||||||
|
// see: https://golang.org/pkg/net/url/#Parse
|
||||||
|
func urlParse(uri string) (*url.URL, error) {
|
||||||
|
if !strings.Contains(uri, "://") {
|
||||||
|
uri = fmt.Sprintf("https://%s", uri)
|
||||||
|
}
|
||||||
|
return url.Parse(uri)
|
||||||
|
}
|
220
proxy/proxy_test.go
Normal file
220
proxy/proxy_test.go
Normal file
|
@ -0,0 +1,220 @@
|
||||||
|
package proxy // import "github.com/pomerium/pomerium/proxy"
|
||||||
|
|
||||||
|
import (
|
||||||
|
"io/ioutil"
|
||||||
|
"net"
|
||||||
|
"net/http"
|
||||||
|
"net/http/httptest"
|
||||||
|
"net/url"
|
||||||
|
"os"
|
||||||
|
"reflect"
|
||||||
|
"testing"
|
||||||
|
)
|
||||||
|
|
||||||
|
func TestOptionsFromEnvConfig(t *testing.T) {
|
||||||
|
tests := []struct {
|
||||||
|
name string
|
||||||
|
want *Options
|
||||||
|
envKey string
|
||||||
|
envValue string
|
||||||
|
wantErr bool
|
||||||
|
}{
|
||||||
|
{"good default, no env settings", defaultOptions, "", "", false},
|
||||||
|
{"bad url", nil, "PROVIDER_URL", "%.rjlw", true},
|
||||||
|
{"good duration", defaultOptions, "SESSION_VALID_TTL", "1m", false},
|
||||||
|
{"bad duration", nil, "SESSION_VALID_TTL", "1sm", true},
|
||||||
|
}
|
||||||
|
for _, tt := range tests {
|
||||||
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
|
if tt.envKey != "" {
|
||||||
|
os.Setenv(tt.envKey, tt.envValue)
|
||||||
|
defer os.Unsetenv(tt.envKey)
|
||||||
|
}
|
||||||
|
got, err := OptionsFromEnvConfig()
|
||||||
|
if (err != nil) != tt.wantErr {
|
||||||
|
t.Errorf("OptionsFromEnvConfig() error = %v, wantErr %v", err, tt.wantErr)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
if !reflect.DeepEqual(got, tt.want) {
|
||||||
|
t.Errorf("OptionsFromEnvConfig() = %v, want %v", got, tt.want)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func Test_urlParse(t *testing.T) {
|
||||||
|
tests := []struct {
|
||||||
|
name string
|
||||||
|
uri string
|
||||||
|
want *url.URL
|
||||||
|
wantErr bool
|
||||||
|
}{
|
||||||
|
{"good url without schema", "accounts.google.com", &url.URL{Scheme: "https", Host: "accounts.google.com"}, false},
|
||||||
|
{"good url with schema", "https://accounts.google.com", &url.URL{Scheme: "https", Host: "accounts.google.com"}, false},
|
||||||
|
{"bad url, malformed", "https://accounts.google.^", nil, true},
|
||||||
|
}
|
||||||
|
for _, tt := range tests {
|
||||||
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
|
got, err := urlParse(tt.uri)
|
||||||
|
if (err != nil) != tt.wantErr {
|
||||||
|
t.Errorf("urlParse() error = %v, wantErr %v", err, tt.wantErr)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
if !reflect.DeepEqual(got, tt.want) {
|
||||||
|
t.Errorf("urlParse() = %v, want %v", got, tt.want)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestNewReverseProxy(t *testing.T) {
|
||||||
|
backend := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||||
|
w.WriteHeader(200)
|
||||||
|
hostname, _, _ := net.SplitHostPort(r.Host)
|
||||||
|
w.Write([]byte(hostname))
|
||||||
|
}))
|
||||||
|
defer backend.Close()
|
||||||
|
|
||||||
|
backendURL, _ := url.Parse(backend.URL)
|
||||||
|
backendHostname, backendPort, _ := net.SplitHostPort(backendURL.Host)
|
||||||
|
backendHost := net.JoinHostPort(backendHostname, backendPort)
|
||||||
|
proxyURL, _ := url.Parse(backendURL.Scheme + "://" + backendHost + "/")
|
||||||
|
|
||||||
|
proxyHandler := NewReverseProxy(proxyURL)
|
||||||
|
frontend := httptest.NewServer(proxyHandler)
|
||||||
|
defer frontend.Close()
|
||||||
|
|
||||||
|
getReq, _ := http.NewRequest("GET", frontend.URL, nil)
|
||||||
|
res, _ := http.DefaultClient.Do(getReq)
|
||||||
|
bodyBytes, _ := ioutil.ReadAll(res.Body)
|
||||||
|
if g, e := string(bodyBytes), backendHostname; g != e {
|
||||||
|
t.Errorf("got body %q; expected %q", g, e)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestNewReverseProxyHandler(t *testing.T) {
|
||||||
|
backend := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||||
|
w.WriteHeader(200)
|
||||||
|
hostname, _, _ := net.SplitHostPort(r.Host)
|
||||||
|
w.Write([]byte(hostname))
|
||||||
|
}))
|
||||||
|
defer backend.Close()
|
||||||
|
|
||||||
|
backendURL, _ := url.Parse(backend.URL)
|
||||||
|
backendHostname, backendPort, _ := net.SplitHostPort(backendURL.Host)
|
||||||
|
backendHost := net.JoinHostPort(backendHostname, backendPort)
|
||||||
|
proxyURL, _ := url.Parse(backendURL.Scheme + "://" + backendHost + "/")
|
||||||
|
|
||||||
|
proxyHandler := NewReverseProxy(proxyURL)
|
||||||
|
opts := defaultOptions
|
||||||
|
handle := NewReverseProxyHandler(opts, proxyHandler, "name")
|
||||||
|
|
||||||
|
frontend := httptest.NewServer(handle)
|
||||||
|
|
||||||
|
defer frontend.Close()
|
||||||
|
|
||||||
|
getReq, _ := http.NewRequest("GET", frontend.URL, nil)
|
||||||
|
|
||||||
|
res, _ := http.DefaultClient.Do(getReq)
|
||||||
|
bodyBytes, _ := ioutil.ReadAll(res.Body)
|
||||||
|
if g, e := string(bodyBytes), backendHostname; g != e {
|
||||||
|
t.Errorf("got body %q; expected %q", g, e)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func testOptions() *Options {
|
||||||
|
authurl, _ := url.Parse("https://sso-auth.corp.beyondperimeter.com")
|
||||||
|
return &Options{
|
||||||
|
Routes: map[string]string{"corp.example.com": "example.com"},
|
||||||
|
AuthenticateServiceURL: authurl,
|
||||||
|
ClientID: "yksYDhIM7PZTvdFP3Mi3sYt2JXooTi7y0oIClBR46fs=",
|
||||||
|
ClientSecret: "80ldlrU2d7w+wVpKNfevk6fmb8otEx6CqOfshj2LwhQ=",
|
||||||
|
EmailDomains: []string{"*"},
|
||||||
|
CookieSecret: "OromP1gurwGWjQPYb1nNgSxtbVB5NnLzX6z5WOKr0Yw=",
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestOptions_Validate(t *testing.T) {
|
||||||
|
good := testOptions()
|
||||||
|
badFromRoute := testOptions()
|
||||||
|
badFromRoute.Routes = map[string]string{"example.com": "^"}
|
||||||
|
badToRoute := testOptions()
|
||||||
|
badToRoute.Routes = map[string]string{"^": "example.com"}
|
||||||
|
badAuthURL := testOptions()
|
||||||
|
badAuthURL.AuthenticateServiceURL = nil
|
||||||
|
emptyCookieSecret := testOptions()
|
||||||
|
emptyCookieSecret.CookieSecret = ""
|
||||||
|
invalidCookieSecret := testOptions()
|
||||||
|
invalidCookieSecret.CookieSecret = "OromP1gurwGWjQPYb1nNgSxtbVB5NnLzX6z5WOKr0Yw^"
|
||||||
|
shortCookieLength := testOptions()
|
||||||
|
shortCookieLength.CookieSecret = "gN3xnvfsAwfCXxnJorGLKUG4l2wC8sS8nfLMhcStPg==" //head -c31 /dev/urandom | base64
|
||||||
|
|
||||||
|
badClientID := testOptions()
|
||||||
|
badClientID.ClientID = ""
|
||||||
|
badClientSecret := testOptions()
|
||||||
|
badClientSecret.ClientSecret = ""
|
||||||
|
badEmailDomain := testOptions()
|
||||||
|
badEmailDomain.EmailDomains = nil
|
||||||
|
|
||||||
|
tests := []struct {
|
||||||
|
name string
|
||||||
|
o *Options
|
||||||
|
wantErr bool
|
||||||
|
}{
|
||||||
|
{"good - minimum options", good, false},
|
||||||
|
|
||||||
|
{"bad - nil options", &Options{}, true},
|
||||||
|
{"bad - from route", badFromRoute, true},
|
||||||
|
{"bad - to route", badToRoute, true},
|
||||||
|
{"bad - auth service url", badAuthURL, true},
|
||||||
|
{"bad - no cookie secret", emptyCookieSecret, true},
|
||||||
|
{"bad - invalid cookie secret", invalidCookieSecret, true},
|
||||||
|
{"bad - short cookie secret", shortCookieLength, true},
|
||||||
|
{"bad - no client id", badClientID, true},
|
||||||
|
{"bad - no client secret", badClientSecret, true},
|
||||||
|
{"bad - no email domain", badEmailDomain, true},
|
||||||
|
}
|
||||||
|
for _, tt := range tests {
|
||||||
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
|
o := tt.o
|
||||||
|
if err := o.Validate(); (err != nil) != tt.wantErr {
|
||||||
|
t.Errorf("Options.Validate() error = %v, wantErr %v", err, tt.wantErr)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestNewProxy(t *testing.T) {
|
||||||
|
good := testOptions()
|
||||||
|
shortCookieLength := testOptions()
|
||||||
|
shortCookieLength.CookieSecret = "gN3xnvfsAwfCXxnJorGLKUG4l2wC8sS8nfLMhcStPg==" //head -c31 /dev/urandom | base64
|
||||||
|
|
||||||
|
tests := []struct {
|
||||||
|
name string
|
||||||
|
opts *Options
|
||||||
|
optFuncs []func(*Proxy) error
|
||||||
|
wantProxy bool
|
||||||
|
numMuxes int
|
||||||
|
wantErr bool
|
||||||
|
}{
|
||||||
|
{"good - minimum options", good, nil, true, 1, false},
|
||||||
|
{"bad - empty options", &Options{}, nil, false, 0, true},
|
||||||
|
{"bad - nil options", nil, nil, false, 0, true},
|
||||||
|
{"bad - short secret/validate sanity check", shortCookieLength, nil, false, 0, true},
|
||||||
|
}
|
||||||
|
for _, tt := range tests {
|
||||||
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
|
got, err := NewProxy(tt.opts, tt.optFuncs...)
|
||||||
|
if (err != nil) != tt.wantErr {
|
||||||
|
t.Errorf("NewProxy() error = %v, wantErr %v", err, tt.wantErr)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
if got == nil && tt.wantProxy == true {
|
||||||
|
t.Errorf("NewProxy() expected valid proxy struct")
|
||||||
|
}
|
||||||
|
if got != nil && len(got.mux) != tt.numMuxes {
|
||||||
|
t.Errorf("NewProxy() = num muxes %v, want %v", got, tt.numMuxes)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
8
proxy/testdata/public_key.pub
vendored
Normal file
8
proxy/testdata/public_key.pub
vendored
Normal file
|
@ -0,0 +1,8 @@
|
||||||
|
-----BEGIN RSA PUBLIC KEY-----
|
||||||
|
MIIBCgKCAQEAst/CEAh/EMnjRbNcwNF7iMqp03En2GYNJz3wfiv/6Rcu7SDgMJke
|
||||||
|
rYfDcpK8RYREAxyjQpi17eI/FRQx0GbRo1AR0ZgF2VvDTkNBCNb3Pw6bdPbFONCU
|
||||||
|
JV2WXi/vf+4gMRH0hN00K9ZOz18MaY5va7C0p+xaC5713KNJnOvndo48X+HDICSG
|
||||||
|
kCyjne/NylEMy1RLwUCdOSZ6SNsTI0tKt95bTEzBhd0GUDfYuG2SoJyLaJisUyW3
|
||||||
|
8X7TtdRUzSwe6IPeLFppU4QGOf1DI2WlmCdYPPfllCfiqVWMibBzwQZGkBvjWGs3
|
||||||
|
Cw8iKMKcydVlZCJ8rLIaU6sE/lD1eGrfowIDAQAB
|
||||||
|
-----END RSA PUBLIC KEY-----
|
10
proxy/testdata/upstream_configs.yml
vendored
Normal file
10
proxy/testdata/upstream_configs.yml
vendored
Normal file
|
@ -0,0 +1,10 @@
|
||||||
|
- service: foo
|
||||||
|
default:
|
||||||
|
from: foo.{{cluster}}.{{root_domain}}
|
||||||
|
to: foo-internal.{{cluster}}.{{root_domain}}
|
||||||
|
options:
|
||||||
|
allowed_groups:
|
||||||
|
- dev
|
||||||
|
dev:
|
||||||
|
from: foo.{{root_domain}}
|
||||||
|
|
Loading…
Add table
Add a link
Reference in a new issue