mirror of
https://github.com/pomerium/pomerium.git
synced 2025-06-25 14:08:09 +02:00
proxy: add JWT request signing support (#19)
- Refactored middleware and request hander logging. - Request refactored to use context.Context. - Add helper (based on Alice) to allow middleware chaining. - Add helper scripts to generate elliptic curve self-signed certificate that can be used to sign JWT. - Changed LetsEncrypt scripts to use acme instead of certbot. - Add script to have LetsEncrypt sign an RSA based certificate. - Add documentation to explain how to verify headers. - Refactored internal/cryptutil signer's code to expect a valid EC priv key. - Changed JWT expiries to use default leeway period. - Update docs and add screenshots. - Replaced logging handler logic to use context.Context. - Removed specific XML error handling. - Refactored handler function signatures to prefer standard go idioms.
This commit is contained in:
parent
98b8c7481f
commit
426e003b03
30 changed files with 1711 additions and 588 deletions
25
3RD-PARTY
25
3RD-PARTY
|
@ -86,3 +86,28 @@ https://github.com/bitly/oauth2_proxy/blob/master/LICENSE
|
||||||
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
|
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
|
||||||
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
|
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
|
||||||
THE SOFTWARE.
|
THE SOFTWARE.
|
||||||
|
|
||||||
|
alice
|
||||||
|
SPDX-License-Identifier: MIT
|
||||||
|
https://github.com/justinas/alice/blob/master/LICENSE
|
||||||
|
|
||||||
|
The MIT License (MIT)
|
||||||
|
|
||||||
|
Copyright (c) 2014 Justinas Stankevicius
|
||||||
|
|
||||||
|
Permission is hereby granted, free of charge, to any person obtaining a copy of
|
||||||
|
this software and associated documentation files (the "Software"), to deal in
|
||||||
|
the Software without restriction, including without limitation the rights to
|
||||||
|
use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of
|
||||||
|
the Software, and to permit persons to whom the Software is furnished to do so,
|
||||||
|
subject to the following conditions:
|
||||||
|
|
||||||
|
The above copyright notice and this permission notice shall be included in all
|
||||||
|
copies or substantial portions of the Software.
|
||||||
|
|
||||||
|
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
|
||||||
|
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS
|
||||||
|
FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR
|
||||||
|
COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER
|
||||||
|
IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN
|
||||||
|
CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
|
|
@ -54,7 +54,7 @@ func (p *Authenticate) Handler() http.Handler {
|
||||||
}
|
}
|
||||||
mux.Handle(host, serviceMux) // setup our service mux to only handle our required host header
|
mux.Handle(host, serviceMux) // setup our service mux to only handle our required host header
|
||||||
|
|
||||||
return m.SetHeaders(mux, securityHeaders)
|
return m.SetHeadersOld(mux, securityHeaders)
|
||||||
}
|
}
|
||||||
|
|
||||||
// validateSignature wraps a common collection of middlewares to validate signatures
|
// validateSignature wraps a common collection of middlewares to validate signatures
|
||||||
|
@ -70,21 +70,21 @@ func (p *Authenticate) validateExisting(f http.HandlerFunc) http.HandlerFunc {
|
||||||
}
|
}
|
||||||
|
|
||||||
// RobotsTxt handles the /robots.txt route.
|
// RobotsTxt handles the /robots.txt route.
|
||||||
func (p *Authenticate) RobotsTxt(rw http.ResponseWriter, req *http.Request) {
|
func (p *Authenticate) RobotsTxt(w http.ResponseWriter, r *http.Request) {
|
||||||
rw.WriteHeader(http.StatusOK)
|
w.WriteHeader(http.StatusOK)
|
||||||
fmt.Fprintf(rw, "User-agent: *\nDisallow: /")
|
fmt.Fprintf(w, "User-agent: *\nDisallow: /")
|
||||||
}
|
}
|
||||||
|
|
||||||
// PingPage handles the /ping route
|
// PingPage handles the /ping route
|
||||||
func (p *Authenticate) PingPage(rw http.ResponseWriter, req *http.Request) {
|
func (p *Authenticate) PingPage(w http.ResponseWriter, r *http.Request) {
|
||||||
rw.WriteHeader(http.StatusOK)
|
w.WriteHeader(http.StatusOK)
|
||||||
fmt.Fprintf(rw, "OK")
|
fmt.Fprintf(w, "OK")
|
||||||
}
|
}
|
||||||
|
|
||||||
// SignInPage directs the user to the sign in page. Takes a `redirect_uri` param.
|
// SignInPage directs the user to the sign in page. Takes a `redirect_uri` param.
|
||||||
func (p *Authenticate) SignInPage(rw http.ResponseWriter, req *http.Request) {
|
func (p *Authenticate) SignInPage(w http.ResponseWriter, r *http.Request) {
|
||||||
requestLog := log.WithRequest(req, "authenticate.SignInPage")
|
// requestLog := log.WithRequest(req, "authenticate.SignInPage")
|
||||||
redirectURL := p.RedirectURL.ResolveReference(req.URL)
|
redirectURL := p.RedirectURL.ResolveReference(r.URL)
|
||||||
// validateRedirectURI middleware already ensures that this is a valid URL
|
// validateRedirectURI middleware already ensures that this is a valid URL
|
||||||
destinationURL, _ := url.Parse(redirectURL.Query().Get("redirect_uri"))
|
destinationURL, _ := url.Parse(redirectURL.Query().Get("redirect_uri"))
|
||||||
t := struct {
|
t := struct {
|
||||||
|
@ -100,72 +100,72 @@ func (p *Authenticate) SignInPage(rw http.ResponseWriter, req *http.Request) {
|
||||||
Destination: destinationURL.Host,
|
Destination: destinationURL.Host,
|
||||||
Version: version.FullVersion(),
|
Version: version.FullVersion(),
|
||||||
}
|
}
|
||||||
requestLog.Info().
|
log.Ctx(r.Context()).Info().
|
||||||
Str("ProviderName", p.provider.Data().ProviderName).
|
Str("ProviderName", p.provider.Data().ProviderName).
|
||||||
Str("Redirect", redirectURL.String()).
|
Str("Redirect", redirectURL.String()).
|
||||||
Str("Destination", destinationURL.Host).
|
Str("Destination", destinationURL.Host).
|
||||||
Str("AllowedDomains", strings.Join(p.AllowedDomains, ", ")).
|
Str("AllowedDomains", strings.Join(p.AllowedDomains, ", ")).
|
||||||
Msg("authenticate.SignInPage")
|
Msg("authenticate.SignInPage")
|
||||||
rw.WriteHeader(http.StatusOK)
|
w.WriteHeader(http.StatusOK)
|
||||||
p.templates.ExecuteTemplate(rw, "sign_in.html", t)
|
p.templates.ExecuteTemplate(w, "sign_in.html", t)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (p *Authenticate) authenticate(rw http.ResponseWriter, req *http.Request) (*sessions.SessionState, error) {
|
func (p *Authenticate) authenticate(w http.ResponseWriter, r *http.Request) (*sessions.SessionState, error) {
|
||||||
requestLog := log.WithRequest(req, "authenticate.authenticate")
|
// requestLog := log.WithRequest(req, "authenticate.authenticate")
|
||||||
session, err := p.sessionStore.LoadSession(req)
|
session, err := p.sessionStore.LoadSession(r)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.Error().Err(err).Msg("authenticate.authenticate")
|
log.Error().Err(err).Msg("authenticate.authenticate")
|
||||||
p.sessionStore.ClearSession(rw, req)
|
p.sessionStore.ClearSession(w, r)
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
// ensure sessions lifetime has not expired
|
// ensure sessions lifetime has not expired
|
||||||
if session.LifetimePeriodExpired() {
|
if session.LifetimePeriodExpired() {
|
||||||
requestLog.Warn().Msg("lifetime expired")
|
log.Ctx(r.Context()).Warn().Msg("lifetime expired")
|
||||||
p.sessionStore.ClearSession(rw, req)
|
p.sessionStore.ClearSession(w, r)
|
||||||
return nil, sessions.ErrLifetimeExpired
|
return nil, sessions.ErrLifetimeExpired
|
||||||
}
|
}
|
||||||
// check if session refresh period is up
|
// check if session refresh period is up
|
||||||
if session.RefreshPeriodExpired() {
|
if session.RefreshPeriodExpired() {
|
||||||
ok, err := p.provider.RefreshSessionIfNeeded(session)
|
ok, err := p.provider.RefreshSessionIfNeeded(session)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
requestLog.Error().Err(err).Msg("failed to refresh session")
|
log.Ctx(r.Context()).Error().Err(err).Msg("failed to refresh session")
|
||||||
p.sessionStore.ClearSession(rw, req)
|
p.sessionStore.ClearSession(w, r)
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
if !ok {
|
if !ok {
|
||||||
requestLog.Error().Msg("user unauthorized after refresh")
|
log.Ctx(r.Context()).Error().Msg("user unauthorized after refresh")
|
||||||
p.sessionStore.ClearSession(rw, req)
|
p.sessionStore.ClearSession(w, r)
|
||||||
return nil, httputil.ErrUserNotAuthorized
|
return nil, httputil.ErrUserNotAuthorized
|
||||||
}
|
}
|
||||||
// update refresh'd session in cookie
|
// update refresh'd session in cookie
|
||||||
err = p.sessionStore.SaveSession(rw, req, session)
|
err = p.sessionStore.SaveSession(w, r, session)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
// We refreshed the session successfully, but failed to save it.
|
// We refreshed the session successfully, but failed to save it.
|
||||||
// This could be from failing to encode the session properly.
|
// This could be from failing to encode the session properly.
|
||||||
// But, we clear the session cookie and reject the request
|
// But, we clear the session cookie and reject the request
|
||||||
requestLog.Error().Err(err).Msg("could not save refreshed session")
|
log.Ctx(r.Context()).Error().Err(err).Msg("could not save refreshed session")
|
||||||
p.sessionStore.ClearSession(rw, req)
|
p.sessionStore.ClearSession(w, r)
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
} else {
|
} else {
|
||||||
// The session has not exceeded it's lifetime or requires refresh
|
// The session has not exceeded it's lifetime or requires refresh
|
||||||
ok := p.provider.ValidateSessionState(session)
|
ok := p.provider.ValidateSessionState(session)
|
||||||
if !ok {
|
if !ok {
|
||||||
requestLog.Error().Msg("invalid session state")
|
log.Ctx(r.Context()).Error().Msg("invalid session state")
|
||||||
p.sessionStore.ClearSession(rw, req)
|
p.sessionStore.ClearSession(w, r)
|
||||||
return nil, httputil.ErrUserNotAuthorized
|
return nil, httputil.ErrUserNotAuthorized
|
||||||
}
|
}
|
||||||
err = p.sessionStore.SaveSession(rw, req, session)
|
err = p.sessionStore.SaveSession(w, r, session)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
requestLog.Error().Err(err).Msg("failed to save valid session")
|
log.Ctx(r.Context()).Error().Err(err).Msg("failed to save valid session")
|
||||||
p.sessionStore.ClearSession(rw, req)
|
p.sessionStore.ClearSession(w, r)
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
if !p.Validator(session.Email) {
|
if !p.Validator(session.Email) {
|
||||||
requestLog.Error().Msg("invalid email user")
|
log.Ctx(r.Context()).Error().Msg("invalid email user")
|
||||||
return nil, httputil.ErrUserNotAuthorized
|
return nil, httputil.ErrUserNotAuthorized
|
||||||
}
|
}
|
||||||
return session, nil
|
return session, nil
|
||||||
|
@ -173,7 +173,7 @@ func (p *Authenticate) authenticate(rw http.ResponseWriter, req *http.Request) (
|
||||||
|
|
||||||
// SignIn handles the /sign_in endpoint. It attempts to authenticate the user,
|
// SignIn handles the /sign_in endpoint. It attempts to authenticate the user,
|
||||||
// and if the user is not authenticated, it renders a sign in page.
|
// and if the user is not authenticated, it renders a sign in page.
|
||||||
func (p *Authenticate) SignIn(rw http.ResponseWriter, req *http.Request) {
|
func (p *Authenticate) SignIn(w http.ResponseWriter, r *http.Request) {
|
||||||
// We attempt to authenticate the user. If they cannot be authenticated, we render a sign-in
|
// We attempt to authenticate the user. If they cannot be authenticated, we render a sign-in
|
||||||
// page.
|
// page.
|
||||||
//
|
//
|
||||||
|
@ -183,35 +183,35 @@ func (p *Authenticate) SignIn(rw http.ResponseWriter, req *http.Request) {
|
||||||
// TODO: It is possible for a user to visit this page without a redirect destination.
|
// TODO: It is possible for a user to visit this page without a redirect destination.
|
||||||
// Should we allow the user to authenticate? If not, what should be the proposed workflow?
|
// Should we allow the user to authenticate? If not, what should be the proposed workflow?
|
||||||
|
|
||||||
session, err := p.authenticate(rw, req)
|
session, err := p.authenticate(w, r)
|
||||||
switch err {
|
switch err {
|
||||||
case nil:
|
case nil:
|
||||||
// User is authenticated, redirect back to the proxy application
|
// User is authenticated, redirect back to the proxy application
|
||||||
// with the necessary state
|
// with the necessary state
|
||||||
p.ProxyOAuthRedirect(rw, req, session)
|
p.ProxyOAuthRedirect(w, r, session)
|
||||||
case http.ErrNoCookie:
|
case http.ErrNoCookie:
|
||||||
log.Error().Err(err).Msg("authenticate.SignIn : err no cookie")
|
log.Error().Err(err).Msg("authenticate.SignIn : err no cookie")
|
||||||
if p.skipProviderButton {
|
if p.skipProviderButton {
|
||||||
p.skipButtonOAuthStart(rw, req)
|
p.skipButtonOAuthStart(w, r)
|
||||||
} else {
|
} else {
|
||||||
p.SignInPage(rw, req)
|
p.SignInPage(w, r)
|
||||||
}
|
}
|
||||||
case sessions.ErrLifetimeExpired, sessions.ErrInvalidSession:
|
case sessions.ErrLifetimeExpired, sessions.ErrInvalidSession:
|
||||||
log.Error().Err(err).Msg("authenticate.SignIn : invalid cookie cookie")
|
log.Error().Err(err).Msg("authenticate.SignIn : invalid cookie cookie")
|
||||||
p.sessionStore.ClearSession(rw, req)
|
p.sessionStore.ClearSession(w, r)
|
||||||
if p.skipProviderButton {
|
if p.skipProviderButton {
|
||||||
p.skipButtonOAuthStart(rw, req)
|
p.skipButtonOAuthStart(w, r)
|
||||||
} else {
|
} else {
|
||||||
p.SignInPage(rw, req)
|
p.SignInPage(w, r)
|
||||||
}
|
}
|
||||||
default:
|
default:
|
||||||
log.Error().Err(err).Msg("authenticate.SignIn : unknown error cookie")
|
log.Error().Err(err).Msg("authenticate.SignIn : unknown error cookie")
|
||||||
httputil.ErrorResponse(rw, req, err.Error(), httputil.CodeForError(err))
|
httputil.ErrorResponse(w, r, err.Error(), httputil.CodeForError(err))
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// ProxyOAuthRedirect redirects the user back to sso proxy's redirection endpoint.
|
// ProxyOAuthRedirect redirects the user back to sso proxy's redirection endpoint.
|
||||||
func (p *Authenticate) ProxyOAuthRedirect(rw http.ResponseWriter, req *http.Request, session *sessions.SessionState) {
|
func (p *Authenticate) ProxyOAuthRedirect(w http.ResponseWriter, r *http.Request, session *sessions.SessionState) {
|
||||||
// This workflow corresponds to Section 3.1.2 of the OAuth2 RFC.
|
// This workflow corresponds to Section 3.1.2 of the OAuth2 RFC.
|
||||||
// See https://tools.ietf.org/html/rfc6749#section-3.1.2 for more specific information.
|
// See https://tools.ietf.org/html/rfc6749#section-3.1.2 for more specific information.
|
||||||
//
|
//
|
||||||
|
@ -223,36 +223,36 @@ func (p *Authenticate) ProxyOAuthRedirect(rw http.ResponseWriter, req *http.Requ
|
||||||
//
|
//
|
||||||
// We must also include the original `state` parameter received from the proxy application.
|
// We must also include the original `state` parameter received from the proxy application.
|
||||||
|
|
||||||
err := req.ParseForm()
|
err := r.ParseForm()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
httputil.ErrorResponse(rw, req, err.Error(), http.StatusInternalServerError)
|
httputil.ErrorResponse(w, r, err.Error(), http.StatusInternalServerError)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
state := req.Form.Get("state")
|
state := r.Form.Get("state")
|
||||||
if state == "" {
|
if state == "" {
|
||||||
httputil.ErrorResponse(rw, req, "no state parameter supplied", http.StatusForbidden)
|
httputil.ErrorResponse(w, r, "no state parameter supplied", http.StatusForbidden)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
redirectURI := req.Form.Get("redirect_uri")
|
redirectURI := r.Form.Get("redirect_uri")
|
||||||
if redirectURI == "" {
|
if redirectURI == "" {
|
||||||
httputil.ErrorResponse(rw, req, "no redirect_uri parameter supplied", http.StatusForbidden)
|
httputil.ErrorResponse(w, r, "no redirect_uri parameter supplied", http.StatusForbidden)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
redirectURL, err := url.Parse(redirectURI)
|
redirectURL, err := url.Parse(redirectURI)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
httputil.ErrorResponse(rw, req, "malformed redirect_uri parameter passed", http.StatusBadRequest)
|
httputil.ErrorResponse(w, r, "malformed redirect_uri parameter passed", http.StatusBadRequest)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
encrypted, err := sessions.MarshalSession(session, p.cipher)
|
encrypted, err := sessions.MarshalSession(session, p.cipher)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
httputil.ErrorResponse(rw, req, err.Error(), http.StatusInternalServerError)
|
httputil.ErrorResponse(w, r, err.Error(), http.StatusInternalServerError)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
http.Redirect(rw, req, getAuthCodeRedirectURL(redirectURL, state, string(encrypted)), http.StatusFound)
|
http.Redirect(w, r, getAuthCodeRedirectURL(redirectURL, state, string(encrypted)), http.StatusFound)
|
||||||
}
|
}
|
||||||
|
|
||||||
func getAuthCodeRedirectURL(redirectURL *url.URL, state, authCode string) string {
|
func getAuthCodeRedirectURL(redirectURL *url.URL, state, authCode string) string {
|
||||||
|
@ -271,55 +271,56 @@ func getAuthCodeRedirectURL(redirectURL *url.URL, state, authCode string) string
|
||||||
}
|
}
|
||||||
|
|
||||||
// SignOut signs the user out.
|
// SignOut signs the user out.
|
||||||
func (p *Authenticate) SignOut(rw http.ResponseWriter, req *http.Request) {
|
func (p *Authenticate) SignOut(w http.ResponseWriter, r *http.Request) {
|
||||||
redirectURI := req.Form.Get("redirect_uri")
|
redirectURI := r.Form.Get("redirect_uri")
|
||||||
if req.Method == "GET" {
|
if r.Method == "GET" {
|
||||||
p.SignOutPage(rw, req, "")
|
p.SignOutPage(w, r, "")
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
session, err := p.sessionStore.LoadSession(req)
|
session, err := p.sessionStore.LoadSession(r)
|
||||||
switch err {
|
switch err {
|
||||||
case nil:
|
case nil:
|
||||||
break
|
break
|
||||||
case http.ErrNoCookie: // if there's no cookie in the session we can just redirect
|
case http.ErrNoCookie: // if there's no cookie in the session we can just redirect
|
||||||
http.Redirect(rw, req, redirectURI, http.StatusFound)
|
http.Redirect(w, r, redirectURI, http.StatusFound)
|
||||||
return
|
return
|
||||||
default:
|
default:
|
||||||
// a different error, clear the session cookie and redirect
|
// a different error, clear the session cookie and redirect
|
||||||
log.Error().Err(err).Msg("authenticate.SignOut : error loading cookie session")
|
log.Error().Err(err).Msg("authenticate.SignOut : error loading cookie session")
|
||||||
p.sessionStore.ClearSession(rw, req)
|
p.sessionStore.ClearSession(w, r)
|
||||||
http.Redirect(rw, req, redirectURI, http.StatusFound)
|
http.Redirect(w, r, redirectURI, http.StatusFound)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
err = p.provider.Revoke(session)
|
err = p.provider.Revoke(session)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.Error().Err(err).Msg("authenticate.SignOut : error revoking session")
|
log.Error().Err(err).Msg("authenticate.SignOut : error revoking session")
|
||||||
p.SignOutPage(rw, req, "An error occurred during sign out. Please try again.")
|
p.SignOutPage(w, r, "An error occurred during sign out. Please try again.")
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
p.sessionStore.ClearSession(rw, req)
|
p.sessionStore.ClearSession(w, r)
|
||||||
http.Redirect(rw, req, redirectURI, http.StatusFound)
|
http.Redirect(w, r, redirectURI, http.StatusFound)
|
||||||
}
|
}
|
||||||
|
|
||||||
// SignOutPage renders a sign out page with a message
|
// SignOutPage renders a sign out page with a message
|
||||||
func (p *Authenticate) SignOutPage(rw http.ResponseWriter, req *http.Request, message string) {
|
func (p *Authenticate) SignOutPage(w http.ResponseWriter, r *http.Request, message string) {
|
||||||
|
log.FromRequest(r).Debug().Msg("This is just a test to make sure signout works")
|
||||||
// validateRedirectURI middleware already ensures that this is a valid URL
|
// validateRedirectURI middleware already ensures that this is a valid URL
|
||||||
redirectURI := req.Form.Get("redirect_uri")
|
redirectURI := r.Form.Get("redirect_uri")
|
||||||
session, err := p.sessionStore.LoadSession(req)
|
session, err := p.sessionStore.LoadSession(r)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
http.Redirect(rw, req, redirectURI, http.StatusFound)
|
http.Redirect(w, r, redirectURI, http.StatusFound)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
signature := req.Form.Get("sig")
|
signature := r.Form.Get("sig")
|
||||||
timestamp := req.Form.Get("ts")
|
timestamp := r.Form.Get("ts")
|
||||||
destinationURL, _ := url.Parse(redirectURI)
|
destinationURL, _ := url.Parse(redirectURI)
|
||||||
|
|
||||||
// An error message indicates that an internal server error occurred
|
// An error message indicates that an internal server error occurred
|
||||||
if message != "" {
|
if message != "" {
|
||||||
rw.WriteHeader(http.StatusInternalServerError)
|
w.WriteHeader(http.StatusInternalServerError)
|
||||||
}
|
}
|
||||||
|
|
||||||
t := struct {
|
t := struct {
|
||||||
|
@ -339,45 +340,45 @@ func (p *Authenticate) SignOutPage(rw http.ResponseWriter, req *http.Request, me
|
||||||
Email: session.Email,
|
Email: session.Email,
|
||||||
Version: version.FullVersion(),
|
Version: version.FullVersion(),
|
||||||
}
|
}
|
||||||
p.templates.ExecuteTemplate(rw, "sign_out.html", t)
|
p.templates.ExecuteTemplate(w, "sign_out.html", t)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
// OAuthStart starts the authentication process by redirecting to the provider. It provides a
|
// OAuthStart starts the authentication process by redirecting to the provider. It provides a
|
||||||
// `redirectURI`, allowing the provider to redirect back to the sso proxy after authentication.
|
// `redirectURI`, allowing the provider to redirect back to the sso proxy after authentication.
|
||||||
func (p *Authenticate) OAuthStart(rw http.ResponseWriter, req *http.Request) {
|
func (p *Authenticate) OAuthStart(w http.ResponseWriter, r *http.Request) {
|
||||||
authRedirectURL, err := url.Parse(req.URL.Query().Get("redirect_uri"))
|
authRedirectURL, err := url.Parse(r.URL.Query().Get("redirect_uri"))
|
||||||
if err != nil {
|
if err != nil {
|
||||||
httputil.ErrorResponse(rw, req, "Invalid redirect parameter", http.StatusBadRequest)
|
httputil.ErrorResponse(w, r, "Invalid redirect parameter", http.StatusBadRequest)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
p.helperOAuthStart(rw, req, authRedirectURL)
|
p.helperOAuthStart(w, r, authRedirectURL)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (p *Authenticate) skipButtonOAuthStart(rw http.ResponseWriter, req *http.Request) {
|
func (p *Authenticate) skipButtonOAuthStart(w http.ResponseWriter, r *http.Request) {
|
||||||
p.helperOAuthStart(rw, req, p.RedirectURL.ResolveReference(req.URL))
|
p.helperOAuthStart(w, r, p.RedirectURL.ResolveReference(r.URL))
|
||||||
}
|
}
|
||||||
|
|
||||||
func (p *Authenticate) helperOAuthStart(rw http.ResponseWriter, req *http.Request, authRedirectURL *url.URL) {
|
func (p *Authenticate) helperOAuthStart(w http.ResponseWriter, r *http.Request, authRedirectURL *url.URL) {
|
||||||
|
|
||||||
nonce := fmt.Sprintf("%x", cryptutil.GenerateKey())
|
nonce := fmt.Sprintf("%x", cryptutil.GenerateKey())
|
||||||
p.csrfStore.SetCSRF(rw, req, nonce)
|
p.csrfStore.SetCSRF(w, r, nonce)
|
||||||
|
|
||||||
if !validRedirectURI(authRedirectURL.String(), p.ProxyRootDomains) {
|
if !validRedirectURI(authRedirectURL.String(), p.ProxyRootDomains) {
|
||||||
httputil.ErrorResponse(rw, req, "Invalid redirect parameter", http.StatusBadRequest)
|
httputil.ErrorResponse(w, r, "Invalid redirect parameter", http.StatusBadRequest)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
proxyRedirectURL, err := url.Parse(authRedirectURL.Query().Get("redirect_uri"))
|
proxyRedirectURL, err := url.Parse(authRedirectURL.Query().Get("redirect_uri"))
|
||||||
if err != nil || !validRedirectURI(proxyRedirectURL.String(), p.ProxyRootDomains) {
|
if err != nil || !validRedirectURI(proxyRedirectURL.String(), p.ProxyRootDomains) {
|
||||||
httputil.ErrorResponse(rw, req, "Invalid redirect parameter", http.StatusBadRequest)
|
httputil.ErrorResponse(w, r, "Invalid redirect parameter", http.StatusBadRequest)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
proxyRedirectSig := authRedirectURL.Query().Get("sig")
|
proxyRedirectSig := authRedirectURL.Query().Get("sig")
|
||||||
ts := authRedirectURL.Query().Get("ts")
|
ts := authRedirectURL.Query().Get("ts")
|
||||||
if !validSignature(proxyRedirectURL.String(), proxyRedirectSig, ts, p.SharedKey) {
|
if !validSignature(proxyRedirectURL.String(), proxyRedirectSig, ts, p.SharedKey) {
|
||||||
httputil.ErrorResponse(rw, req, "Invalid redirect parameter", http.StatusBadRequest)
|
httputil.ErrorResponse(w, r, "Invalid redirect parameter", http.StatusBadRequest)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -385,7 +386,7 @@ func (p *Authenticate) helperOAuthStart(rw http.ResponseWriter, req *http.Reques
|
||||||
|
|
||||||
signInURL := p.provider.GetSignInURL(state)
|
signInURL := p.provider.GetSignInURL(state)
|
||||||
|
|
||||||
http.Redirect(rw, req, signInURL, http.StatusFound)
|
http.Redirect(w, r, signInURL, http.StatusFound)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (p *Authenticate) redeemCode(host, code string) (*sessions.SessionState, error) {
|
func (p *Authenticate) redeemCode(host, code string) (*sessions.SessionState, error) {
|
||||||
|
@ -402,29 +403,29 @@ func (p *Authenticate) redeemCode(host, code string) (*sessions.SessionState, er
|
||||||
}
|
}
|
||||||
|
|
||||||
// getOAuthCallback completes the oauth cycle from an identity provider's callback
|
// getOAuthCallback completes the oauth cycle from an identity provider's callback
|
||||||
func (p *Authenticate) getOAuthCallback(rw http.ResponseWriter, req *http.Request) (string, error) {
|
func (p *Authenticate) getOAuthCallback(w http.ResponseWriter, r *http.Request) (string, error) {
|
||||||
requestLog := log.WithRequest(req, "authenticate.getOAuthCallback")
|
// requestLog := log.WithRequest(req, "authenticate.getOAuthCallback")
|
||||||
// finish the oauth cycle
|
// finish the oauth cycle
|
||||||
err := req.ParseForm()
|
err := r.ParseForm()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return "", httputil.HTTPError{Code: http.StatusInternalServerError, Message: err.Error()}
|
return "", httputil.HTTPError{Code: http.StatusInternalServerError, Message: err.Error()}
|
||||||
}
|
}
|
||||||
errorString := req.Form.Get("error")
|
errorString := r.Form.Get("error")
|
||||||
if errorString != "" {
|
if errorString != "" {
|
||||||
return "", httputil.HTTPError{Code: http.StatusForbidden, Message: errorString}
|
return "", httputil.HTTPError{Code: http.StatusForbidden, Message: errorString}
|
||||||
}
|
}
|
||||||
code := req.Form.Get("code")
|
code := r.Form.Get("code")
|
||||||
if code == "" {
|
if code == "" {
|
||||||
return "", httputil.HTTPError{Code: http.StatusBadRequest, Message: "Missing Code"}
|
return "", httputil.HTTPError{Code: http.StatusBadRequest, Message: "Missing Code"}
|
||||||
}
|
}
|
||||||
|
|
||||||
session, err := p.redeemCode(req.Host, code)
|
session, err := p.redeemCode(r.Host, code)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
requestLog.Error().Err(err).Msg("error redeeming authentication code")
|
log.Ctx(r.Context()).Error().Err(err).Msg("error redeeming authentication code")
|
||||||
return "", err
|
return "", err
|
||||||
}
|
}
|
||||||
|
|
||||||
bytes, err := base64.URLEncoding.DecodeString(req.Form.Get("state"))
|
bytes, err := base64.URLEncoding.DecodeString(r.Form.Get("state"))
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return "", httputil.HTTPError{Code: http.StatusInternalServerError, Message: "Invalid State"}
|
return "", httputil.HTTPError{Code: http.StatusInternalServerError, Message: "Invalid State"}
|
||||||
}
|
}
|
||||||
|
@ -434,13 +435,13 @@ func (p *Authenticate) getOAuthCallback(rw http.ResponseWriter, req *http.Reques
|
||||||
}
|
}
|
||||||
nonce := s[0]
|
nonce := s[0]
|
||||||
redirect := s[1]
|
redirect := s[1]
|
||||||
c, err := p.csrfStore.GetCSRF(req)
|
c, err := p.csrfStore.GetCSRF(r)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return "", httputil.HTTPError{Code: http.StatusForbidden, Message: "Missing CSRF token"}
|
return "", httputil.HTTPError{Code: http.StatusForbidden, Message: "Missing CSRF token"}
|
||||||
}
|
}
|
||||||
p.csrfStore.ClearCSRF(rw, req)
|
p.csrfStore.ClearCSRF(w, r)
|
||||||
if c.Value != nonce {
|
if c.Value != nonce {
|
||||||
requestLog.Error().Err(err).Msg("csrf token mismatch")
|
log.Ctx(r.Context()).Error().Err(err).Msg("csrf token mismatch")
|
||||||
return "", httputil.HTTPError{Code: http.StatusForbidden, Message: "csrf failed"}
|
return "", httputil.HTTPError{Code: http.StatusForbidden, Message: "csrf failed"}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -452,13 +453,13 @@ func (p *Authenticate) getOAuthCallback(rw http.ResponseWriter, req *http.Reques
|
||||||
// - for p.Validator see validator.go#newValidatorImpl for more info
|
// - for p.Validator see validator.go#newValidatorImpl for more info
|
||||||
// - for p.provider.ValidateGroup see providers/google.go#ValidateGroup for more info
|
// - for p.provider.ValidateGroup see providers/google.go#ValidateGroup for more info
|
||||||
if !p.Validator(session.Email) {
|
if !p.Validator(session.Email) {
|
||||||
requestLog.Error().Err(err).Str("email", session.Email).Msg("invalid email permissions denied")
|
log.Ctx(r.Context()).Error().Err(err).Str("email", session.Email).Msg("invalid email permissions denied")
|
||||||
return "", httputil.HTTPError{Code: http.StatusForbidden, Message: "Invalid Account"}
|
return "", httputil.HTTPError{Code: http.StatusForbidden, Message: "Invalid Account"}
|
||||||
}
|
}
|
||||||
requestLog.Info().Str("email", session.Email).Msg("authentication complete")
|
log.Ctx(r.Context()).Info().Str("email", session.Email).Msg("authentication complete")
|
||||||
err = p.sessionStore.SaveSession(rw, req, session)
|
err = p.sessionStore.SaveSession(w, r, session)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
requestLog.Error().Err(err).Msg("internal error")
|
log.Ctx(r.Context()).Error().Err(err).Msg("internal error")
|
||||||
return "", httputil.HTTPError{Code: http.StatusInternalServerError, Message: "Internal Error"}
|
return "", httputil.HTTPError{Code: http.StatusInternalServerError, Message: "Internal Error"}
|
||||||
}
|
}
|
||||||
return redirect, nil
|
return redirect, nil
|
||||||
|
@ -466,49 +467,49 @@ func (p *Authenticate) getOAuthCallback(rw http.ResponseWriter, req *http.Reques
|
||||||
|
|
||||||
// OAuthCallback handles the callback from the provider, and returns an error response if there is an error.
|
// OAuthCallback handles the callback from the provider, and returns an error response if there is an error.
|
||||||
// If there is no error it will redirect to the redirect url.
|
// If there is no error it will redirect to the redirect url.
|
||||||
func (p *Authenticate) OAuthCallback(rw http.ResponseWriter, req *http.Request) {
|
func (p *Authenticate) OAuthCallback(w http.ResponseWriter, r *http.Request) {
|
||||||
redirect, err := p.getOAuthCallback(rw, req)
|
redirect, err := p.getOAuthCallback(w, r)
|
||||||
switch h := err.(type) {
|
switch h := err.(type) {
|
||||||
case nil:
|
case nil:
|
||||||
break
|
break
|
||||||
case httputil.HTTPError:
|
case httputil.HTTPError:
|
||||||
httputil.ErrorResponse(rw, req, h.Message, h.Code)
|
httputil.ErrorResponse(w, r, h.Message, h.Code)
|
||||||
return
|
return
|
||||||
default:
|
default:
|
||||||
httputil.ErrorResponse(rw, req, "Internal Error", http.StatusInternalServerError)
|
httputil.ErrorResponse(w, r, "Internal Error", http.StatusInternalServerError)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
http.Redirect(rw, req, redirect, http.StatusFound)
|
http.Redirect(w, r, redirect, http.StatusFound)
|
||||||
}
|
}
|
||||||
|
|
||||||
// Redeem has a signed access token, and provides the user information associated with the access token.
|
// Redeem has a signed access token, and provides the user information associated with the access token.
|
||||||
func (p *Authenticate) Redeem(rw http.ResponseWriter, req *http.Request) {
|
func (p *Authenticate) Redeem(w http.ResponseWriter, r *http.Request) {
|
||||||
// The auth code is redeemed by the sso proxy for an access token, refresh token,
|
// The auth code is redeemed by the sso proxy for an access token, refresh token,
|
||||||
// expiration, and email.
|
// expiration, and email.
|
||||||
requestLog := log.WithRequest(req, "authenticate.Redeem")
|
// requestLog := log.WithRequest(req, "authenticate.Redeem")
|
||||||
err := req.ParseForm()
|
err := r.ParseForm()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
http.Error(rw, fmt.Sprintf("Bad Request: %s", err.Error()), http.StatusBadRequest)
|
http.Error(w, fmt.Sprintf("Bad Request: %s", err.Error()), http.StatusBadRequest)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
session, err := sessions.UnmarshalSession(req.Form.Get("code"), p.cipher)
|
session, err := sessions.UnmarshalSession(r.Form.Get("code"), p.cipher)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
requestLog.Error().Err(err).Int("http-status", http.StatusUnauthorized).Msg("invalid auth code")
|
log.Ctx(r.Context()).Error().Err(err).Int("http-status", http.StatusUnauthorized).Msg("invalid auth code")
|
||||||
http.Error(rw, fmt.Sprintf("invalid auth code: %s", err.Error()), http.StatusUnauthorized)
|
http.Error(w, fmt.Sprintf("invalid auth code: %s", err.Error()), http.StatusUnauthorized)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
if session == nil {
|
if session == nil {
|
||||||
requestLog.Error().Err(err).Int("http-status", http.StatusUnauthorized).Msg("invalid session")
|
log.Ctx(r.Context()).Error().Err(err).Int("http-status", http.StatusUnauthorized).Msg("invalid session")
|
||||||
http.Error(rw, fmt.Sprintf("invalid session: %s", err.Error()), http.StatusUnauthorized)
|
http.Error(w, fmt.Sprintf("invalid session: %s", err.Error()), http.StatusUnauthorized)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
if session != nil && (session.RefreshPeriodExpired() || session.LifetimePeriodExpired()) {
|
if session != nil && (session.RefreshPeriodExpired() || session.LifetimePeriodExpired()) {
|
||||||
requestLog.Error().Msg("expired session")
|
log.Ctx(r.Context()).Error().Msg("expired session")
|
||||||
p.sessionStore.ClearSession(rw, req)
|
p.sessionStore.ClearSession(w, r)
|
||||||
http.Error(rw, fmt.Sprintf("expired session"), http.StatusUnauthorized)
|
http.Error(w, fmt.Sprintf("expired session"), http.StatusUnauthorized)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -528,32 +529,32 @@ func (p *Authenticate) Redeem(rw http.ResponseWriter, req *http.Request) {
|
||||||
|
|
||||||
jsonBytes, err := json.Marshal(response)
|
jsonBytes, err := json.Marshal(response)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
rw.WriteHeader(http.StatusInternalServerError)
|
w.WriteHeader(http.StatusInternalServerError)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
rw.Header().Set("GAP-Auth", session.Email)
|
w.Header().Set("GAP-Auth", session.Email)
|
||||||
rw.Header().Set("Content-Type", "application/json")
|
w.Header().Set("Content-Type", "application/json")
|
||||||
rw.Write(jsonBytes)
|
w.Write(jsonBytes)
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// Refresh takes a refresh token and returns a new access token
|
// Refresh takes a refresh token and returns a new access token
|
||||||
func (p *Authenticate) Refresh(rw http.ResponseWriter, req *http.Request) {
|
func (p *Authenticate) Refresh(w http.ResponseWriter, r *http.Request) {
|
||||||
err := req.ParseForm()
|
err := r.ParseForm()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
http.Error(rw, fmt.Sprintf("Bad Request: %s", err.Error()), http.StatusBadRequest)
|
http.Error(w, fmt.Sprintf("Bad Request: %s", err.Error()), http.StatusBadRequest)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
refreshToken := req.Form.Get("refresh_token")
|
refreshToken := r.Form.Get("refresh_token")
|
||||||
if refreshToken == "" {
|
if refreshToken == "" {
|
||||||
http.Error(rw, "Bad Request: No Refresh Token", http.StatusBadRequest)
|
http.Error(w, "Bad Request: No Refresh Token", http.StatusBadRequest)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
accessToken, expiresIn, err := p.provider.RefreshAccessToken(refreshToken)
|
accessToken, expiresIn, err := p.provider.RefreshAccessToken(refreshToken)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
httputil.ErrorResponse(rw, req, err.Error(), httputil.CodeForError(err))
|
httputil.ErrorResponse(w, r, err.Error(), httputil.CodeForError(err))
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -567,28 +568,28 @@ func (p *Authenticate) Refresh(rw http.ResponseWriter, req *http.Request) {
|
||||||
|
|
||||||
bytes, err := json.Marshal(response)
|
bytes, err := json.Marshal(response)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
rw.WriteHeader(http.StatusInternalServerError)
|
w.WriteHeader(http.StatusInternalServerError)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
rw.WriteHeader(http.StatusCreated)
|
w.WriteHeader(http.StatusCreated)
|
||||||
rw.Header().Set("Content-Type", "application/json")
|
w.Header().Set("Content-Type", "application/json")
|
||||||
rw.Write(bytes)
|
w.Write(bytes)
|
||||||
}
|
}
|
||||||
|
|
||||||
// GetProfile gets a list of groups of which a user is a member.
|
// GetProfile gets a list of groups of which a user is a member.
|
||||||
func (p *Authenticate) GetProfile(rw http.ResponseWriter, req *http.Request) {
|
func (p *Authenticate) GetProfile(w http.ResponseWriter, r *http.Request) {
|
||||||
// The sso proxy sends the user's email to this endpoint to get a list of Google groups that
|
// The sso proxy sends the user's email to this endpoint to get a list of Google groups that
|
||||||
// the email is a member of. The proxy will compare these groups to the list of allowed
|
// the email is a member of. The proxy will compare these groups to the list of allowed
|
||||||
// groups for the upstream service the user is trying to access.
|
// groups for the upstream service the user is trying to access.
|
||||||
|
|
||||||
email := req.FormValue("email")
|
email := r.FormValue("email")
|
||||||
if email == "" {
|
if email == "" {
|
||||||
http.Error(rw, "no email address included", http.StatusBadRequest)
|
http.Error(w, "no email address included", http.StatusBadRequest)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
// groupsFormValue := req.FormValue("groups")
|
// groupsFormValue := r.FormValue("groups")
|
||||||
// allowedGroups := []string{}
|
// allowedGroups := []string{}
|
||||||
// if groupsFormValue != "" {
|
// if groupsFormValue != "" {
|
||||||
// allowedGroups = strings.Split(groupsFormValue, ",")
|
// allowedGroups = strings.Split(groupsFormValue, ",")
|
||||||
|
@ -597,7 +598,7 @@ func (p *Authenticate) GetProfile(rw http.ResponseWriter, req *http.Request) {
|
||||||
// groups, err := p.provider.ValidateGroupMembership(email, allowedGroups)
|
// groups, err := p.provider.ValidateGroupMembership(email, allowedGroups)
|
||||||
// if err != nil {
|
// if err != nil {
|
||||||
// log.Error().Err(err).Msg("authenticate.GetProfile : error retrieving groups")
|
// log.Error().Err(err).Msg("authenticate.GetProfile : error retrieving groups")
|
||||||
// httputil.ErrorResponse(rw, req, err.Error(), httputil.CodeForError(err))
|
// httputil.ErrorResponse(w, r, err.Error(), httputil.CodeForError(err))
|
||||||
// return
|
// return
|
||||||
// }
|
// }
|
||||||
|
|
||||||
|
@ -609,26 +610,26 @@ func (p *Authenticate) GetProfile(rw http.ResponseWriter, req *http.Request) {
|
||||||
|
|
||||||
jsonBytes, err := json.Marshal(response)
|
jsonBytes, err := json.Marshal(response)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
http.Error(rw, fmt.Sprintf("error marshaling response: %s", err.Error()), http.StatusInternalServerError)
|
http.Error(w, fmt.Sprintf("error marshaling response: %s", err.Error()), http.StatusInternalServerError)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
rw.Header().Set("GAP-Auth", email)
|
w.Header().Set("GAP-Auth", email)
|
||||||
rw.Header().Set("Content-Type", "application/json")
|
w.Header().Set("Content-Type", "application/json")
|
||||||
rw.Write(jsonBytes)
|
w.Write(jsonBytes)
|
||||||
}
|
}
|
||||||
|
|
||||||
// ValidateToken validates the X-Access-Token from the header and returns an error response
|
// ValidateToken validates the X-Access-Token from the header and returns an error response
|
||||||
// if it's invalid
|
// if it's invalid
|
||||||
func (p *Authenticate) ValidateToken(rw http.ResponseWriter, req *http.Request) {
|
func (p *Authenticate) ValidateToken(w http.ResponseWriter, r *http.Request) {
|
||||||
accessToken := req.Header.Get("X-Access-Token")
|
accessToken := r.Header.Get("X-Access-Token")
|
||||||
idToken := req.Header.Get("X-Id-Token")
|
idToken := r.Header.Get("X-Id-Token")
|
||||||
|
|
||||||
if accessToken == "" {
|
if accessToken == "" {
|
||||||
rw.WriteHeader(http.StatusBadRequest)
|
w.WriteHeader(http.StatusBadRequest)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
if idToken == "" {
|
if idToken == "" {
|
||||||
rw.WriteHeader(http.StatusBadRequest)
|
w.WriteHeader(http.StatusBadRequest)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -638,10 +639,10 @@ func (p *Authenticate) ValidateToken(rw http.ResponseWriter, req *http.Request)
|
||||||
})
|
})
|
||||||
|
|
||||||
if !ok {
|
if !ok {
|
||||||
rw.WriteHeader(http.StatusUnauthorized)
|
w.WriteHeader(http.StatusUnauthorized)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
rw.WriteHeader(http.StatusOK)
|
w.WriteHeader(http.StatusOK)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
|
@ -21,7 +21,7 @@ services:
|
||||||
- /var/run/docker.sock:/tmp/docker.sock:ro
|
- /var/run/docker.sock:/tmp/docker.sock:ro
|
||||||
|
|
||||||
pomerium-authenticate:
|
pomerium-authenticate:
|
||||||
image: pomerium/pomerium:latest
|
image: pomerium/pomerium:latest # or `build: .` to build from source
|
||||||
environment:
|
environment:
|
||||||
- SERVICES=authenticate
|
- SERVICES=authenticate
|
||||||
# auth settings
|
# auth settings
|
||||||
|
@ -57,7 +57,7 @@ services:
|
||||||
- 443
|
- 443
|
||||||
|
|
||||||
pomerium-proxy:
|
pomerium-proxy:
|
||||||
image: pomerium/pomerium:latest
|
image: pomerium/pomerium:latest # or `build: .` to build from source
|
||||||
environment:
|
environment:
|
||||||
- SERVICES=proxy
|
- SERVICES=proxy
|
||||||
# proxy settings
|
# proxy settings
|
||||||
|
@ -66,6 +66,8 @@ services:
|
||||||
# Generate 256 bit random keys e.g. `head -c32 /dev/urandom | base64`
|
# Generate 256 bit random keys e.g. `head -c32 /dev/urandom | base64`
|
||||||
- SHARED_SECRET=aDducXQzK2tPY3R4TmdqTGhaYS80eGYxcTUvWWJDb2M=
|
- SHARED_SECRET=aDducXQzK2tPY3R4TmdqTGhaYS80eGYxcTUvWWJDb2M=
|
||||||
- COOKIE_SECRET=V2JBZk0zWGtsL29UcFUvWjVDWWQ2UHExNXJ0b2VhcDI=
|
- COOKIE_SECRET=V2JBZk0zWGtsL29UcFUvWjVDWWQ2UHExNXJ0b2VhcDI=
|
||||||
|
# If set, a JWT based signature is appended to each request header `x-pomerium-jwt-assertion`
|
||||||
|
# - SIGNING_KEY=LS0tLS1CRUdJTiBFQyBQUklWQVRFIEtFWS0tLS0tCk1IY0NBUUVFSU0zbXBaSVdYQ1g5eUVneFU2czU3Q2J0YlVOREJTQ0VBdFFGNWZVV0hwY1FvQW9HQ0NxR1NNNDkKQXdFSG9VUURRZ0FFaFBRditMQUNQVk5tQlRLMHhTVHpicEVQa1JyazFlVXQxQk9hMzJTRWZVUHpOaTRJV2VaLwpLS0lUdDJxMUlxcFYyS01TYlZEeXI5aWp2L1hoOThpeUV3PT0KLS0tLS1FTkQgRUMgUFJJVkFURSBLRVktLS0tLQo=
|
||||||
|
|
||||||
# if passing certs as files
|
# if passing certs as files
|
||||||
# - CERTIFICATE_KEY=corp.beyondperimeter.com.crt
|
# - CERTIFICATE_KEY=corp.beyondperimeter.com.crt
|
||||||
|
|
|
@ -9,7 +9,7 @@ module.exports = {
|
||||||
docsDir: "docs",
|
docsDir: "docs",
|
||||||
editLinkText: "Edit this page on GitHub",
|
editLinkText: "Edit this page on GitHub",
|
||||||
lastUpdated: "Last Updated",
|
lastUpdated: "Last Updated",
|
||||||
nav: [{text: "Guide", link: "/guide/"}],
|
nav: [{ text: "Guide", link: "/guide/" }],
|
||||||
sidebar: {
|
sidebar: {
|
||||||
"/guide/": genSidebarConfig("Guide")
|
"/guide/": genSidebarConfig("Guide")
|
||||||
}
|
}
|
||||||
|
@ -21,7 +21,7 @@ function genSidebarConfig(title) {
|
||||||
{
|
{
|
||||||
title,
|
title,
|
||||||
collapsable: false,
|
collapsable: false,
|
||||||
children: ["", "identity-providers"]
|
children: ["", "identity-providers", "signed-headers"]
|
||||||
}
|
}
|
||||||
];
|
];
|
||||||
}
|
}
|
||||||
|
|
113
docs/guide/signed-headers.md
Normal file
113
docs/guide/signed-headers.md
Normal file
|
@ -0,0 +1,113 @@
|
||||||
|
---
|
||||||
|
title: Signed Headers
|
||||||
|
description: >-
|
||||||
|
This article describes how to secure your app with signed headers. When
|
||||||
|
configured, pomerium uses JSON Web Tokens (JWT) to make sure that a request to
|
||||||
|
your app is authorized.
|
||||||
|
---
|
||||||
|
|
||||||
|
# Securing your app with signed headers
|
||||||
|
|
||||||
|
This page describes how to secure your app with signed headers. When configured, pomerium uses JSON Web Tokens (JWT) to make sure that a request to your app is authorized.
|
||||||
|
|
||||||
|
::: warning
|
||||||
|
|
||||||
|
Health checks don't include JWT headers and pomerium doesn't handle health checks. If your health check returns access errors, make sure that you have it configured correctly and that your JWT header validation whitelists the health check path.
|
||||||
|
|
||||||
|
:::
|
||||||
|
|
||||||
|
## Prerequisites
|
||||||
|
|
||||||
|
To secure your app with signed headers, you'll need the following:
|
||||||
|
|
||||||
|
- An application you want users to connect to.
|
||||||
|
- A [JWT] library that supports the `ES256` algorithm.
|
||||||
|
|
||||||
|
## Rationale
|
||||||
|
|
||||||
|
Signed headers provide **secondary** security in case someone bypasses mTLS and network segmentation. This protects your app from the following kind of risks:
|
||||||
|
|
||||||
|
- Pomerium is accidentally disabled;
|
||||||
|
- Misconfigured firewalls;
|
||||||
|
- Mutually-authenticated TLS;
|
||||||
|
- Access from within the project.
|
||||||
|
|
||||||
|
To properly secure your app, you must use signed headers for all app types.
|
||||||
|
|
||||||
|
## Verification
|
||||||
|
|
||||||
|
To secure your app with JWT, cryptographically verify the header, payload, and signature of the JWT. The JWT is in the HTTP request header `x-pomerium-iap-jwt-assertion`. If an attacker bypasses pomerium, they can forge the unsigned identity headers, `x-pomerium-authenticated-user-{email,id}`. JWT provides a more secure alternative.
|
||||||
|
|
||||||
|
Note that pomerium it strips the `x-pomerium-*` headers provided by the client when the request goes through the serving infrastructure.
|
||||||
|
|
||||||
|
Verify that the JWT's header conforms to the following constraints:
|
||||||
|
|
||||||
|
[JWT] | description
|
||||||
|
:-----: | ---------------------------------------------------------------------------------------------------
|
||||||
|
`exp` | Expiration time in seconds since the UNIX epoch. Allow 1 minute for skew.
|
||||||
|
`iat` | Issued-at time in seconds since the UNIX epoch. Allow 1 minute for skew.
|
||||||
|
`aud` | The client's final domain e.g. `httpbin.corp.example.com`.
|
||||||
|
`iss` | Issuer must be `pomerium-proxy`.
|
||||||
|
`sub` | Subject is the user's id. Can be used instead of the `x-pomerium-authenticated-user-id` header.
|
||||||
|
`email` | Email is the user's email. Can be used instead of the `x-pomerium-authenticated-user-email` header.
|
||||||
|
|
||||||
|
### Manual verification
|
||||||
|
|
||||||
|
Though you will very likely be verifying signed-headers programmatically in your application's middleware, and using a third-party JWT library, if you are new to JWT it may be helpful to show what manual verification looks like. The following guide assumes you are using the provided [docker-compose.yml] as a base and [httpbin]. Httpbin gives us a convienient way of inspecting client headers.
|
||||||
|
|
||||||
|
1. Provide pomerium with a base64 encoded Elliptic Curve ([NIST P-256] aka [secp256r1] aka prime256v1) Private Key. In production, you'd likely want to get these from your KMS.
|
||||||
|
|
||||||
|
```bash
|
||||||
|
# see ./scripts/generate_self_signed_signing_key.sh
|
||||||
|
openssl ecparam -genkey -name prime256v1 -noout -out ec_private.pem
|
||||||
|
openssl req -x509 -new -key ec_private.pem -days 1000000 -out ec_public.pem -subj "/CN=unused"
|
||||||
|
# careful! this will output your private key in terminal
|
||||||
|
cat ec_private.pem | base64
|
||||||
|
```
|
||||||
|
|
||||||
|
Copy the base64 encoded value of your private key to `pomerium-proxy`'s environmental configuration variable `SIGNING_KEY`.
|
||||||
|
|
||||||
|
```bash
|
||||||
|
SIGNING_KEY=ZxqyyIPPX0oWrrOwsxXgl0hHnTx3mBVhQ2kvW1YB4MM=
|
||||||
|
```
|
||||||
|
|
||||||
|
2. Reload `pomerium-proxy`. Navigate to httpbin (by default, `https://httpbin.corp.${YOUR-DOMAIN}.com`), and login as usual. Click **request inspection**. Select `/headers'. Click **try it out** and then **execute**. You should see something like the following.
|
||||||
|
|
||||||
|

|
||||||
|
|
||||||
|
3. `X-Pomerium-Jwt-Assertion` is the signature value. It's less scary than it looks and basically just a compressed, json blob as described above. Navigate to [jwt.io] which provides a helpful GUI to manually verify JWT values.
|
||||||
|
|
||||||
|
4. Paste the value of `X-Pomerium-Jwt-Assertion` header token into the `Encoded` form. You should notice that the decoded values look much more familiar.
|
||||||
|
|
||||||
|

|
||||||
|
|
||||||
|
5. Finally, we want to cryptographically verify the validity of the token. To do this, we will need the signer's public key. You can simply copy and past the output of `cat ec_public.pem`.
|
||||||
|
|
||||||
|

|
||||||
|
|
||||||
|
**Viola!** Hopefully walking through a manual verification has helped give you a better feel for how signed JWT tokens are used as a secondary validation mechanism in pomerium.
|
||||||
|
|
||||||
|
::: warning
|
||||||
|
|
||||||
|
In an actual client, you'll want to ensure that all the other claims values are valid (like expiration, issuer, audience and so on) in the context of your application. You'll also want to make sure you have a safe and reliable mechanism for distributing pomerium-proxy's public signing key to client apps (typically, a [key management service]).
|
||||||
|
|
||||||
|
:::
|
||||||
|
|
||||||
|
### Automatic verification
|
||||||
|
|
||||||
|
In the future, we will be adding example client implementations for:
|
||||||
|
|
||||||
|
- Python
|
||||||
|
- Go
|
||||||
|
- Java
|
||||||
|
- C#
|
||||||
|
- PHP
|
||||||
|
|
||||||
|
[developer tools]: https://developers.google.com/web/tools/chrome-devtools/open
|
||||||
|
[docker-compose.yml]: https://github.com/pomerium/pomerium/blob/master/docker-compose.yml
|
||||||
|
[httpbin]: https://httpbin.org/
|
||||||
|
[jwt]: https://jwt.io/introduction/
|
||||||
|
[jwt.io]: https://jwt.io/
|
||||||
|
[key management service]: https://en.wikipedia.org/wiki/Key_management
|
||||||
|
[nist p-256]: https://csrc.nist.gov/csrc/media/events/workshop-on-elliptic-curve-cryptography-standards/documents/papers/session6-adalier-mehmet.pdf
|
||||||
|
[secp256r1]: https://wiki.openssl.org/index.php/Command_Line_Elliptic_Curve_Operations
|
BIN
docs/guide/signed-headers/inspect-headers.png
Normal file
BIN
docs/guide/signed-headers/inspect-headers.png
Normal file
Binary file not shown.
After Width: | Height: | Size: 450 KiB |
BIN
docs/guide/signed-headers/verifying-headers-1.png
Normal file
BIN
docs/guide/signed-headers/verifying-headers-1.png
Normal file
Binary file not shown.
After Width: | Height: | Size: 164 KiB |
BIN
docs/guide/signed-headers/verifying-headers-2.png
Normal file
BIN
docs/guide/signed-headers/verifying-headers-2.png
Normal file
Binary file not shown.
After Width: | Height: | Size: 386 KiB |
|
@ -19,6 +19,10 @@ export ALLOWED_DOMAINS=*
|
||||||
# Generate 256 bit random keys e.g. `head -c32 /dev/urandom | base64`
|
# Generate 256 bit random keys e.g. `head -c32 /dev/urandom | base64`
|
||||||
export SHARED_SECRET=9wiTZq4qvmS/plYQyvzGKWPlH/UBy0DMYMA2x/zngrM=
|
export SHARED_SECRET=9wiTZq4qvmS/plYQyvzGKWPlH/UBy0DMYMA2x/zngrM=
|
||||||
export COOKIE_SECRET=uPGHo1ujND/k3B9V6yr52Gweq3RRYfFho98jxDG5Br8=
|
export COOKIE_SECRET=uPGHo1ujND/k3B9V6yr52Gweq3RRYfFho98jxDG5Br8=
|
||||||
|
# If set, a JWT based signature is appended to each request header `x-pomerium-jwt-assertion`
|
||||||
|
# export SIGNING_KEY="Replace with base64'd private key from ./scripts/self-signed-sign-key.sh"
|
||||||
|
|
||||||
|
# Identity Provider Settings
|
||||||
|
|
||||||
# OKTA
|
# OKTA
|
||||||
# export IDP_PROVIDER="okta
|
# export IDP_PROVIDER="okta
|
||||||
|
|
3
go.mod
3
go.mod
|
@ -9,11 +9,12 @@ require (
|
||||||
github.com/pquerna/cachecontrol v0.0.0-20180517163645-1555304b9b35 // indirect
|
github.com/pquerna/cachecontrol v0.0.0-20180517163645-1555304b9b35 // indirect
|
||||||
github.com/rs/zerolog v1.11.0
|
github.com/rs/zerolog v1.11.0
|
||||||
github.com/stretchr/testify v1.2.2 // indirect
|
github.com/stretchr/testify v1.2.2 // indirect
|
||||||
|
github.com/zenazn/goji v0.9.0
|
||||||
golang.org/x/crypto v0.0.0-20181203042331-505ab145d0a9
|
golang.org/x/crypto v0.0.0-20181203042331-505ab145d0a9
|
||||||
golang.org/x/net v0.0.0-20181220203305-927f97764cc3 // indirect
|
golang.org/x/net v0.0.0-20181220203305-927f97764cc3 // indirect
|
||||||
golang.org/x/oauth2 v0.0.0-20181203162652-d668ce993890
|
golang.org/x/oauth2 v0.0.0-20181203162652-d668ce993890
|
||||||
golang.org/x/sync v0.0.0-20181221193216-37e7f081c4d4 // indirect
|
golang.org/x/sync v0.0.0-20181221193216-37e7f081c4d4 // indirect
|
||||||
golang.org/x/sys v0.0.0-20190116161447-11f53e031339 // indirect
|
golang.org/x/sys v0.0.0-20190116161447-11f53e031339 // indirect
|
||||||
google.golang.org/appengine v1.4.0 // indirect
|
google.golang.org/appengine v1.4.0 // indirect
|
||||||
gopkg.in/square/go-jose.v2 v2.2.1 // indirect
|
gopkg.in/square/go-jose.v2 v2.2.1
|
||||||
)
|
)
|
||||||
|
|
2
go.sum
2
go.sum
|
@ -16,6 +16,8 @@ github.com/rs/zerolog v1.11.0 h1:DRuq/S+4k52uJzBQciUcofXx45GrMC6yrEbb/CoK6+M=
|
||||||
github.com/rs/zerolog v1.11.0/go.mod h1:YbFCdg8HfsridGWAh22vktObvhZbQsZXe4/zB0OKkWU=
|
github.com/rs/zerolog v1.11.0/go.mod h1:YbFCdg8HfsridGWAh22vktObvhZbQsZXe4/zB0OKkWU=
|
||||||
github.com/stretchr/testify v1.2.2 h1:bSDNvY7ZPG5RlJ8otE/7V6gMiyenm9RtJ7IUVIAoJ1w=
|
github.com/stretchr/testify v1.2.2 h1:bSDNvY7ZPG5RlJ8otE/7V6gMiyenm9RtJ7IUVIAoJ1w=
|
||||||
github.com/stretchr/testify v1.2.2/go.mod h1:a8OnRcib4nhh0OaRAV+Yts87kKdq0PP7pXfy6kDkUVs=
|
github.com/stretchr/testify v1.2.2/go.mod h1:a8OnRcib4nhh0OaRAV+Yts87kKdq0PP7pXfy6kDkUVs=
|
||||||
|
github.com/zenazn/goji v0.9.0 h1:RSQQAbXGArQ0dIDEq+PI6WqN6if+5KHu6x2Cx/GXLTQ=
|
||||||
|
github.com/zenazn/goji v0.9.0/go.mod h1:7S9M489iMyHBNxwZnk9/EHS098H4/F6TATF2mIxtB1Q=
|
||||||
golang.org/x/crypto v0.0.0-20181203042331-505ab145d0a9 h1:mKdxBk7AujPs8kU4m80U72y/zjbZ3UcXC7dClwKbUI0=
|
golang.org/x/crypto v0.0.0-20181203042331-505ab145d0a9 h1:mKdxBk7AujPs8kU4m80U72y/zjbZ3UcXC7dClwKbUI0=
|
||||||
golang.org/x/crypto v0.0.0-20181203042331-505ab145d0a9/go.mod h1:6SG95UA2DQfeDnfUPMdvaQW0Q7yPrPDi9nlGo2tz2b4=
|
golang.org/x/crypto v0.0.0-20181203042331-505ab145d0a9/go.mod h1:6SG95UA2DQfeDnfUPMdvaQW0Q7yPrPDi9nlGo2tz2b4=
|
||||||
golang.org/x/net v0.0.0-20180724234803-3673e40ba225/go.mod h1:mL1N/T3taQHkDXs73rZJwtUhF3w3ftmwwsq0BUmARs4=
|
golang.org/x/net v0.0.0-20180724234803-3673e40ba225/go.mod h1:mL1N/T3taQHkDXs73rZJwtUhF3w3ftmwwsq0BUmARs4=
|
||||||
|
|
102
internal/cryptutil/marshal.go
Normal file
102
internal/cryptutil/marshal.go
Normal file
|
@ -0,0 +1,102 @@
|
||||||
|
// Package cryptutil provides encoding and decoding routines for various cryptographic structures.
|
||||||
|
package cryptutil
|
||||||
|
|
||||||
|
import (
|
||||||
|
"crypto/ecdsa"
|
||||||
|
"crypto/x509"
|
||||||
|
"encoding/base64"
|
||||||
|
"encoding/pem"
|
||||||
|
"errors"
|
||||||
|
"fmt"
|
||||||
|
)
|
||||||
|
|
||||||
|
// DecodePublicKey decodes a PEM-encoded ECDSA public key.
|
||||||
|
func DecodePublicKey(encodedKey []byte) (*ecdsa.PublicKey, error) {
|
||||||
|
block, _ := pem.Decode(encodedKey)
|
||||||
|
if block == nil || block.Type != "PUBLIC KEY" {
|
||||||
|
return nil, fmt.Errorf("marshal: could not decode PEM block type %s", block.Type)
|
||||||
|
|
||||||
|
}
|
||||||
|
|
||||||
|
pub, err := x509.ParsePKIXPublicKey(block.Bytes)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
ecdsaPub, ok := pub.(*ecdsa.PublicKey)
|
||||||
|
if !ok {
|
||||||
|
return nil, errors.New("marshal: data was not an ECDSA public key")
|
||||||
|
}
|
||||||
|
|
||||||
|
return ecdsaPub, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// EncodePublicKey encodes an ECDSA public key to PEM format.
|
||||||
|
func EncodePublicKey(key *ecdsa.PublicKey) ([]byte, error) {
|
||||||
|
derBytes, err := x509.MarshalPKIXPublicKey(key)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
block := &pem.Block{
|
||||||
|
Type: "PUBLIC KEY",
|
||||||
|
Bytes: derBytes,
|
||||||
|
}
|
||||||
|
|
||||||
|
return pem.EncodeToMemory(block), nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// DecodePrivateKey decodes a PEM-encoded ECDSA private key.
|
||||||
|
func DecodePrivateKey(encodedKey []byte) (*ecdsa.PrivateKey, error) {
|
||||||
|
var skippedTypes []string
|
||||||
|
var block *pem.Block
|
||||||
|
|
||||||
|
for {
|
||||||
|
block, encodedKey = pem.Decode(encodedKey)
|
||||||
|
|
||||||
|
if block == nil {
|
||||||
|
return nil, fmt.Errorf("failed to find EC PRIVATE KEY in PEM data after skipping types %v", skippedTypes)
|
||||||
|
}
|
||||||
|
|
||||||
|
if block.Type == "EC PRIVATE KEY" {
|
||||||
|
break
|
||||||
|
} else {
|
||||||
|
skippedTypes = append(skippedTypes, block.Type)
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
privKey, err := x509.ParseECPrivateKey(block.Bytes)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
return privKey, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// EncodePrivateKey encodes an ECDSA private key to PEM format.
|
||||||
|
func EncodePrivateKey(key *ecdsa.PrivateKey) ([]byte, error) {
|
||||||
|
derKey, err := x509.MarshalECPrivateKey(key)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
keyBlock := &pem.Block{
|
||||||
|
Type: "EC PRIVATE KEY",
|
||||||
|
Bytes: derKey,
|
||||||
|
}
|
||||||
|
|
||||||
|
return pem.EncodeToMemory(keyBlock), nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// EncodeSignatureJWT encodes an ECDSA signature according to
|
||||||
|
// https://tools.ietf.org/html/rfc7515#appendix-A.3.1
|
||||||
|
func EncodeSignatureJWT(sig []byte) string {
|
||||||
|
return base64.RawURLEncoding.EncodeToString(sig)
|
||||||
|
}
|
||||||
|
|
||||||
|
// DecodeSignatureJWT decodes an ECDSA signature according to
|
||||||
|
// https://tools.ietf.org/html/rfc7515#appendix-A.3.1
|
||||||
|
func DecodeSignatureJWT(b64sig string) ([]byte, error) {
|
||||||
|
return base64.RawURLEncoding.DecodeString(b64sig)
|
||||||
|
}
|
122
internal/cryptutil/marshal_test.go
Normal file
122
internal/cryptutil/marshal_test.go
Normal file
|
@ -0,0 +1,122 @@
|
||||||
|
package cryptutil
|
||||||
|
|
||||||
|
import (
|
||||||
|
"bytes"
|
||||||
|
"strings"
|
||||||
|
"testing"
|
||||||
|
)
|
||||||
|
|
||||||
|
// A keypair for NIST P-256 / secp256r1
|
||||||
|
// Generated using:
|
||||||
|
// openssl ecparam -genkey -name prime256v1 -outform PEM
|
||||||
|
var pemECPrivateKeyP256 = `-----BEGIN EC PARAMETERS-----
|
||||||
|
BggqhkjOPQMBBw==
|
||||||
|
-----END EC PARAMETERS-----
|
||||||
|
-----BEGIN EC PRIVATE KEY-----
|
||||||
|
MHcCAQEEIOI+EZsjyN3jvWJI/KDihFmqTuDpUe/if6f/pgGTBta/oAoGCCqGSM49
|
||||||
|
AwEHoUQDQgAEhhObKJ1r1PcUw+3REd/TbmSZnDvXnFUSTwqQFo5gbfIlP+gvEYba
|
||||||
|
+Rxj2hhqjfzqxIleRK40IRyEi3fJM/8Qhg==
|
||||||
|
-----END EC PRIVATE KEY-----
|
||||||
|
`
|
||||||
|
|
||||||
|
var pemECPublicKeyP256 = `-----BEGIN PUBLIC KEY-----
|
||||||
|
MFkwEwYHKoZIzj0CAQYIKoZIzj0DAQcDQgAEhhObKJ1r1PcUw+3REd/TbmSZnDvX
|
||||||
|
nFUSTwqQFo5gbfIlP+gvEYba+Rxj2hhqjfzqxIleRK40IRyEi3fJM/8Qhg==
|
||||||
|
-----END PUBLIC KEY-----
|
||||||
|
`
|
||||||
|
|
||||||
|
// A keypair for NIST P-384 / secp384r1
|
||||||
|
// Generated using:
|
||||||
|
// openssl ecparam -genkey -name secp384r1 -outform PEM
|
||||||
|
var pemECPrivateKeyP384 = `-----BEGIN EC PARAMETERS-----
|
||||||
|
BgUrgQQAIg==
|
||||||
|
-----END EC PARAMETERS-----
|
||||||
|
-----BEGIN EC PRIVATE KEY-----
|
||||||
|
MIGkAgEBBDAhA0YPVL1kimIy+FAqzUAtmR3It2Yjv2I++YpcC4oX7wGuEWcWKBYE
|
||||||
|
oOjj7wG/memgBwYFK4EEACKhZANiAAQub8xaaCTTW5rCHJCqUddIXpvq/TxdwViH
|
||||||
|
+tPEQQlJAJciXStM/aNLYA7Q1K1zMjYyzKSWz5kAh/+x4rXQ9Hlm3VAwCQDVVSjP
|
||||||
|
bfiNOXKOWfmyrGyQ7fQfs+ro1lmjLjs=
|
||||||
|
-----END EC PRIVATE KEY-----
|
||||||
|
`
|
||||||
|
|
||||||
|
var pemECPublicKeyP384 = `-----BEGIN PUBLIC KEY-----
|
||||||
|
MHYwEAYHKoZIzj0CAQYFK4EEACIDYgAELm/MWmgk01uawhyQqlHXSF6b6v08XcFY
|
||||||
|
h/rTxEEJSQCXIl0rTP2jS2AO0NStczI2Msykls+ZAIf/seK10PR5Zt1QMAkA1VUo
|
||||||
|
z234jTlyjln5sqxskO30H7Pq6NZZoy47
|
||||||
|
-----END PUBLIC KEY-----
|
||||||
|
`
|
||||||
|
|
||||||
|
var garbagePEM = `-----BEGIN GARBAGE-----
|
||||||
|
TG9yZW0gaXBzdW0gZG9sb3Igc2l0IGFtZXQ=
|
||||||
|
-----END GARBAGE-----
|
||||||
|
`
|
||||||
|
|
||||||
|
func TestPublicKeyMarshaling(t *testing.T) {
|
||||||
|
ecKey, err := DecodePublicKey([]byte(pemECPublicKeyP256))
|
||||||
|
if err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
|
||||||
|
pemBytes, _ := EncodePublicKey(ecKey)
|
||||||
|
if !bytes.Equal(pemBytes, []byte(pemECPublicKeyP256)) {
|
||||||
|
t.Fatal("public key encoding did not match")
|
||||||
|
}
|
||||||
|
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestPrivateKeyBadDecode(t *testing.T) {
|
||||||
|
_, err := DecodePrivateKey([]byte(garbagePEM))
|
||||||
|
if err == nil {
|
||||||
|
t.Fatal("decoded garbage data without complaint")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestPrivateKeyMarshaling(t *testing.T) {
|
||||||
|
ecKey, err := DecodePrivateKey([]byte(pemECPrivateKeyP256))
|
||||||
|
if err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
|
||||||
|
pemBytes, _ := EncodePrivateKey(ecKey)
|
||||||
|
if !strings.HasSuffix(pemECPrivateKeyP256, string(pemBytes)) {
|
||||||
|
t.Fatal("private key encoding did not match")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Test vector from https://tools.ietf.org/html/rfc7515#appendix-A.3.1
|
||||||
|
var jwtTest = []struct {
|
||||||
|
sigBytes []byte
|
||||||
|
b64sig string
|
||||||
|
}{
|
||||||
|
{
|
||||||
|
sigBytes: []byte{14, 209, 33, 83, 121, 99, 108, 72, 60, 47, 127, 21,
|
||||||
|
88, 7, 212, 2, 163, 178, 40, 3, 58, 249, 124, 126, 23, 129, 154, 195, 22, 158,
|
||||||
|
166, 101, 197, 10, 7, 211, 140, 60, 112, 229, 216, 241, 45, 175,
|
||||||
|
8, 74, 84, 128, 166, 101, 144, 197, 242, 147, 80, 154, 143, 63, 127, 138, 131,
|
||||||
|
163, 84, 213},
|
||||||
|
b64sig: "DtEhU3ljbEg8L38VWAfUAqOyKAM6-Xx-F4GawxaepmXFCgfTjDxw5djxLa8ISlSApmWQxfKTUJqPP3-Kg6NU1Q",
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestJWTEncoding(t *testing.T) {
|
||||||
|
for _, tt := range jwtTest {
|
||||||
|
result := EncodeSignatureJWT(tt.sigBytes)
|
||||||
|
|
||||||
|
if strings.Compare(result, tt.b64sig) != 0 {
|
||||||
|
t.Fatalf("expected %s, got %s\n", tt.b64sig, result)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestJWTDecoding(t *testing.T) {
|
||||||
|
for _, tt := range jwtTest {
|
||||||
|
resultSig, err := DecodeSignatureJWT(tt.b64sig)
|
||||||
|
if err != nil {
|
||||||
|
t.Error(err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if !bytes.Equal(resultSig, tt.sigBytes) {
|
||||||
|
t.Fatalf("decoded signature was incorrect")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
84
internal/cryptutil/sign.go
Normal file
84
internal/cryptutil/sign.go
Normal file
|
@ -0,0 +1,84 @@
|
||||||
|
package cryptutil // import "github.com/pomerium/pomerium/internal/cryptutil"
|
||||||
|
import (
|
||||||
|
"fmt"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"gopkg.in/square/go-jose.v2"
|
||||||
|
"gopkg.in/square/go-jose.v2/jwt"
|
||||||
|
)
|
||||||
|
|
||||||
|
// JWTSigner implements JWT signing according to JSON Web Token (JWT) RFC7519
|
||||||
|
// https://tools.ietf.org/html/rfc7519
|
||||||
|
type JWTSigner interface {
|
||||||
|
SignJWT(string, string) (string, error)
|
||||||
|
}
|
||||||
|
|
||||||
|
// ES256Signer is struct containing the required fields to create a ES256 signed JSON Web Tokens
|
||||||
|
type ES256Signer struct {
|
||||||
|
// User (sub) is unique, stable identifier for the user.
|
||||||
|
// Use in place of the x-pomerium-authenticated-user-id header.
|
||||||
|
User string `json:"sub,omitempty"`
|
||||||
|
// Email (sub) is a **private** claim name identifier for the user email address.
|
||||||
|
// Use in place of the x-pomerium-authenticated-user-email header.
|
||||||
|
Email string `json:"email,omitempty"`
|
||||||
|
// Audience (aud) must be the destination of the upstream proxy locations.
|
||||||
|
// e.g. `helloworld.corp.example.com`
|
||||||
|
Audience jwt.Audience `json:"aud,omitempty"`
|
||||||
|
// Issuer (iss) is the URL of the proxy.
|
||||||
|
// e.g. `proxy.corp.example.com`
|
||||||
|
Issuer string `json:"iss,omitempty"`
|
||||||
|
// Expiry (exp) is the expiration time in seconds since the UNIX epoch.
|
||||||
|
// Allow 1 minute for skew. The maximum lifetime of a token is 10 minutes + 2 * skew.
|
||||||
|
Expiry jwt.NumericDate `json:"exp,omitempty"`
|
||||||
|
// IssuedAt (iat) is the time is measured in seconds since the UNIX epoch.
|
||||||
|
// Allow 1 minute for skew.
|
||||||
|
IssuedAt jwt.NumericDate `json:"iat,omitempty"`
|
||||||
|
// IssuedAt (nbf) is the time is measured in seconds since the UNIX epoch.
|
||||||
|
// Allow 1 minute for skew.
|
||||||
|
NotBefore jwt.NumericDate `json:"nbf,omitempty"`
|
||||||
|
|
||||||
|
signer jose.Signer
|
||||||
|
}
|
||||||
|
|
||||||
|
// NewES256Signer creates an Eliptic Curve, NIST P-256 (aka secp256r1 aka prime256v1) JWT signer.
|
||||||
|
//
|
||||||
|
// RSA is not supported due to performance considerations of needing to sign each request.
|
||||||
|
// Go's P-256 is constant-time and SHA-256 is faster on 64-bit machines and immune
|
||||||
|
// to length extension attacks.
|
||||||
|
// See also:
|
||||||
|
// - https://cloud.google.com/iot/docs/how-tos/credentials/keys
|
||||||
|
func NewES256Signer(privKey []byte, audience string) (*ES256Signer, error) {
|
||||||
|
key, err := DecodePrivateKey(privKey)
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("internal/cryptutil parsing key failed %v", err)
|
||||||
|
}
|
||||||
|
signer, err := jose.NewSigner(
|
||||||
|
jose.SigningKey{
|
||||||
|
Algorithm: jose.ES256, // ECDSA using P-256 and SHA-256
|
||||||
|
Key: key,
|
||||||
|
},
|
||||||
|
(&jose.SignerOptions{}).WithType("JWT"))
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("internal/cryptutil new signer failed %v", err)
|
||||||
|
}
|
||||||
|
return &ES256Signer{
|
||||||
|
Issuer: "pomerium-proxy",
|
||||||
|
Audience: jwt.Audience{audience},
|
||||||
|
signer: signer,
|
||||||
|
}, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// SignJWT creates a signed JWT containing claims for the logged in user id (`sub`) and email (`email`).
|
||||||
|
func (s *ES256Signer) SignJWT(user, email string) (string, error) {
|
||||||
|
s.User = user
|
||||||
|
s.Email = email
|
||||||
|
now := time.Now()
|
||||||
|
s.IssuedAt = jwt.NewNumericDate(now)
|
||||||
|
s.Expiry = jwt.NewNumericDate(now.Add(jwt.DefaultLeeway))
|
||||||
|
s.NotBefore = jwt.NewNumericDate(now.Add(-1 * jwt.DefaultLeeway))
|
||||||
|
rawJWT, err := jwt.Signed(s.signer).Claims(s).CompactSerialize()
|
||||||
|
if err != nil {
|
||||||
|
return "", err
|
||||||
|
}
|
||||||
|
return rawJWT, nil
|
||||||
|
}
|
44
internal/cryptutil/sign_test.go
Normal file
44
internal/cryptutil/sign_test.go
Normal file
|
@ -0,0 +1,44 @@
|
||||||
|
package cryptutil
|
||||||
|
|
||||||
|
import (
|
||||||
|
"testing"
|
||||||
|
)
|
||||||
|
|
||||||
|
func TestES256Signer(t *testing.T) {
|
||||||
|
signer, err := NewES256Signer([]byte(pemECPrivateKeyP256), "destination-url")
|
||||||
|
if err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
if signer == nil {
|
||||||
|
t.Fatal("signer should not be nil")
|
||||||
|
}
|
||||||
|
rawJwt, err := signer.SignJWT("joe-user", "joe-user@example.com")
|
||||||
|
if err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
if rawJwt == "" {
|
||||||
|
t.Fatal("jwt should not be nil")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestNewES256Signer(t *testing.T) {
|
||||||
|
|
||||||
|
tests := []struct {
|
||||||
|
name string
|
||||||
|
privKey []byte
|
||||||
|
audience string
|
||||||
|
wantErr bool
|
||||||
|
}{
|
||||||
|
{"working example", []byte(pemECPrivateKeyP256), "some-domain.com", false},
|
||||||
|
{"bad private key", []byte(garbagePEM), "some-domain.com", true},
|
||||||
|
}
|
||||||
|
for _, tt := range tests {
|
||||||
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
|
_, err := NewES256Signer(tt.privKey, tt.audience)
|
||||||
|
if (err != nil) != tt.wantErr {
|
||||||
|
t.Errorf("NewES256Signer() error = %v, wantErr %v", err, tt.wantErr)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
|
@ -130,6 +130,8 @@ func readCertificateFile(certFile, certKeyFile string) (*tls.Certificate, error)
|
||||||
}
|
}
|
||||||
|
|
||||||
// newDefaultTLSConfig creates a new TLS config based on the certificate files given.
|
// newDefaultTLSConfig creates a new TLS config based on the certificate files given.
|
||||||
|
// see also:
|
||||||
|
// https://wiki.mozilla.org/Security/Server_Side_TLS#Recommended_configurations
|
||||||
func newDefaultTLSConfig(cert *tls.Certificate) (*tls.Config, error) {
|
func newDefaultTLSConfig(cert *tls.Certificate) (*tls.Config, error) {
|
||||||
tlsConfig := &tls.Config{
|
tlsConfig := &tls.Config{
|
||||||
CipherSuites: []uint16{
|
CipherSuites: []uint16{
|
||||||
|
|
212
internal/log/handler_log.go
Normal file
212
internal/log/handler_log.go
Normal file
|
@ -0,0 +1,212 @@
|
||||||
|
// Package log provides a set of http.Handler helpers for zerolog.
|
||||||
|
package log // import "github.com/pomerium/pomerium/internal/log"
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"crypto/rand"
|
||||||
|
"fmt"
|
||||||
|
"net"
|
||||||
|
"net/http"
|
||||||
|
"strings"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"github.com/rs/zerolog"
|
||||||
|
"github.com/zenazn/goji/web/mutil"
|
||||||
|
)
|
||||||
|
|
||||||
|
// FromRequest gets the logger in the request's context.
|
||||||
|
// This is a shortcut for log.Ctx(r.Context())
|
||||||
|
func FromRequest(r *http.Request) *zerolog.Logger {
|
||||||
|
return Ctx(r.Context())
|
||||||
|
}
|
||||||
|
|
||||||
|
// NewHandler injects log into requests context.
|
||||||
|
func NewHandler(log zerolog.Logger) func(http.Handler) http.Handler {
|
||||||
|
return func(next http.Handler) http.Handler {
|
||||||
|
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||||
|
// Create a copy of the logger (including internal context slice)
|
||||||
|
// to prevent data race when using UpdateContext.
|
||||||
|
l := log.With().Logger()
|
||||||
|
r = r.WithContext(l.WithContext(r.Context()))
|
||||||
|
next.ServeHTTP(w, r)
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// URLHandler adds the requested URL as a field to the context's logger
|
||||||
|
// using fieldKey as field key.
|
||||||
|
func URLHandler(fieldKey string) func(next http.Handler) http.Handler {
|
||||||
|
return func(next http.Handler) http.Handler {
|
||||||
|
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||||
|
log := zerolog.Ctx(r.Context())
|
||||||
|
log.UpdateContext(func(c zerolog.Context) zerolog.Context {
|
||||||
|
return c.Str(fieldKey, r.URL.String())
|
||||||
|
})
|
||||||
|
next.ServeHTTP(w, r)
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// MethodHandler adds the request method as a field to the context's logger
|
||||||
|
// using fieldKey as field key.
|
||||||
|
func MethodHandler(fieldKey string) func(next http.Handler) http.Handler {
|
||||||
|
return func(next http.Handler) http.Handler {
|
||||||
|
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||||
|
log := zerolog.Ctx(r.Context())
|
||||||
|
log.UpdateContext(func(c zerolog.Context) zerolog.Context {
|
||||||
|
return c.Str(fieldKey, r.Method)
|
||||||
|
})
|
||||||
|
next.ServeHTTP(w, r)
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// RequestHandler adds the request method and URL as a field to the context's logger
|
||||||
|
// using fieldKey as field key.
|
||||||
|
func RequestHandler(fieldKey string) func(next http.Handler) http.Handler {
|
||||||
|
return func(next http.Handler) http.Handler {
|
||||||
|
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||||
|
log := zerolog.Ctx(r.Context())
|
||||||
|
log.UpdateContext(func(c zerolog.Context) zerolog.Context {
|
||||||
|
return c.Str(fieldKey, r.Method+" "+r.URL.String())
|
||||||
|
})
|
||||||
|
next.ServeHTTP(w, r)
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// RemoteAddrHandler adds the request's remote address as a field to the context's logger
|
||||||
|
// using fieldKey as field key.
|
||||||
|
func RemoteAddrHandler(fieldKey string) func(next http.Handler) http.Handler {
|
||||||
|
return func(next http.Handler) http.Handler {
|
||||||
|
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||||
|
if host, _, err := net.SplitHostPort(r.RemoteAddr); err == nil {
|
||||||
|
log := zerolog.Ctx(r.Context())
|
||||||
|
log.UpdateContext(func(c zerolog.Context) zerolog.Context {
|
||||||
|
return c.Str(fieldKey, host)
|
||||||
|
})
|
||||||
|
}
|
||||||
|
next.ServeHTTP(w, r)
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// UserAgentHandler adds the request's user-agent as a field to the context's logger
|
||||||
|
// using fieldKey as field key.
|
||||||
|
func UserAgentHandler(fieldKey string) func(next http.Handler) http.Handler {
|
||||||
|
return func(next http.Handler) http.Handler {
|
||||||
|
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||||
|
if ua := r.Header.Get("User-Agent"); ua != "" {
|
||||||
|
log := zerolog.Ctx(r.Context())
|
||||||
|
log.UpdateContext(func(c zerolog.Context) zerolog.Context {
|
||||||
|
return c.Str(fieldKey, ua)
|
||||||
|
})
|
||||||
|
}
|
||||||
|
next.ServeHTTP(w, r)
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// RefererHandler adds the request's referer as a field to the context's logger
|
||||||
|
// using fieldKey as field key.
|
||||||
|
func RefererHandler(fieldKey string) func(next http.Handler) http.Handler {
|
||||||
|
return func(next http.Handler) http.Handler {
|
||||||
|
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||||
|
if ref := r.Header.Get("Referer"); ref != "" {
|
||||||
|
log := zerolog.Ctx(r.Context())
|
||||||
|
log.UpdateContext(func(c zerolog.Context) zerolog.Context {
|
||||||
|
return c.Str(fieldKey, ref)
|
||||||
|
})
|
||||||
|
}
|
||||||
|
next.ServeHTTP(w, r)
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
type idKey struct{}
|
||||||
|
|
||||||
|
// IDFromRequest returns the unique id associated to the request if any.
|
||||||
|
func IDFromRequest(r *http.Request) (id string, ok bool) {
|
||||||
|
if r == nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
return IDFromCtx(r.Context())
|
||||||
|
}
|
||||||
|
|
||||||
|
// IDFromCtx returns the unique id associated to the context if any.
|
||||||
|
func IDFromCtx(ctx context.Context) (id string, ok bool) {
|
||||||
|
id, ok = ctx.Value(idKey{}).(string)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
// RequestIDHandler returns a handler setting a unique id to the request which can
|
||||||
|
// be gathered using IDFromRequest(req). This generated id is added as a field to the
|
||||||
|
// logger using the passed fieldKey as field name. The id is also added as a response
|
||||||
|
// header if the headerName is not empty.
|
||||||
|
func RequestIDHandler(fieldKey, headerName string) func(next http.Handler) http.Handler {
|
||||||
|
return func(next http.Handler) http.Handler {
|
||||||
|
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||||
|
ctx := r.Context()
|
||||||
|
id, ok := IDFromRequest(r)
|
||||||
|
if !ok {
|
||||||
|
id = uuid()
|
||||||
|
ctx = context.WithValue(ctx, idKey{}, id)
|
||||||
|
r = r.WithContext(ctx)
|
||||||
|
}
|
||||||
|
if fieldKey != "" {
|
||||||
|
log := zerolog.Ctx(ctx)
|
||||||
|
log.UpdateContext(func(c zerolog.Context) zerolog.Context {
|
||||||
|
return c.Str(fieldKey, id)
|
||||||
|
})
|
||||||
|
}
|
||||||
|
if headerName != "" {
|
||||||
|
w.Header().Set(headerName, id)
|
||||||
|
}
|
||||||
|
next.ServeHTTP(w, r)
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// AccessHandler returns a handler that call f after each request.
|
||||||
|
func AccessHandler(f func(r *http.Request, status, size int, duration time.Duration)) func(next http.Handler) http.Handler {
|
||||||
|
return func(next http.Handler) http.Handler {
|
||||||
|
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||||
|
start := time.Now()
|
||||||
|
lw := mutil.WrapWriter(w)
|
||||||
|
next.ServeHTTP(lw, r)
|
||||||
|
f(r, lw.Status(), lw.BytesWritten(), time.Since(start))
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// ForwardedAddrHandler returns the client IP address from a request. If present, the
|
||||||
|
// X-Forwarded-For header is assumed to be set by a load balancer, and its
|
||||||
|
// rightmost entry (the client IP that connected to the LB) is returned.
|
||||||
|
func ForwardedAddrHandler(fieldKey string) func(next http.Handler) http.Handler {
|
||||||
|
return func(next http.Handler) http.Handler {
|
||||||
|
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||||
|
addr := r.RemoteAddr
|
||||||
|
if ra := r.Header.Get("X-Forwarded-For"); ra != "" {
|
||||||
|
forwardedList := strings.Split(ra, ",")
|
||||||
|
forwardedAddr := strings.TrimSpace(forwardedList[len(forwardedList)-1])
|
||||||
|
if forwardedAddr != "" {
|
||||||
|
addr = forwardedAddr
|
||||||
|
}
|
||||||
|
log := zerolog.Ctx(r.Context())
|
||||||
|
log.UpdateContext(func(c zerolog.Context) zerolog.Context {
|
||||||
|
return c.Str(fieldKey, addr)
|
||||||
|
})
|
||||||
|
}
|
||||||
|
next.ServeHTTP(w, r)
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// uuid generates a random 128-bit non-RFC UUID.
|
||||||
|
func uuid() string {
|
||||||
|
buf := make([]byte, 16)
|
||||||
|
if _, err := rand.Read(buf); err != nil {
|
||||||
|
return ""
|
||||||
|
}
|
||||||
|
return fmt.Sprintf("%x-%x-%x-%x-%x", buf[0:4], buf[4:6], buf[6:8], buf[8:10], buf[10:])
|
||||||
|
}
|
260
internal/log/handler_log_test.go
Normal file
260
internal/log/handler_log_test.go
Normal file
|
@ -0,0 +1,260 @@
|
||||||
|
// Package log provides a set of http.Handler helpers for zerolog.
|
||||||
|
package log // import "github.com/pomerium/pomerium/internal/log"
|
||||||
|
|
||||||
|
import (
|
||||||
|
"bytes"
|
||||||
|
"fmt"
|
||||||
|
"io/ioutil"
|
||||||
|
"net/http"
|
||||||
|
"net/http/httptest"
|
||||||
|
"net/url"
|
||||||
|
"reflect"
|
||||||
|
"regexp"
|
||||||
|
"testing"
|
||||||
|
|
||||||
|
"github.com/rs/zerolog"
|
||||||
|
)
|
||||||
|
|
||||||
|
func TestGenerateUUID(t *testing.T) {
|
||||||
|
prev := uuid()
|
||||||
|
for i := 0; i < 100; i++ {
|
||||||
|
id := uuid()
|
||||||
|
if id == "" {
|
||||||
|
t.Fatal("random pool failure")
|
||||||
|
}
|
||||||
|
if prev == id {
|
||||||
|
t.Fatalf("Should get a new ID!")
|
||||||
|
}
|
||||||
|
matched, err := regexp.MatchString("[a-fA-F0-9]{8}-[a-fA-F0-9]{4}-[a-fA-F0-9]{4}-[a-fA-F0-9]{4}-[a-fA-F0-9]{12}", id)
|
||||||
|
if !matched || err != nil {
|
||||||
|
t.Fatalf("expected match %s %v %s", id, matched, err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func decodeIfBinary(out *bytes.Buffer) string {
|
||||||
|
// p := out.Bytes()
|
||||||
|
// if len(p) == 0 || p[0] < 0x7F {
|
||||||
|
// return out.String()
|
||||||
|
// }
|
||||||
|
return out.String() //cbor.DecodeObjectToStr(p) + "\n"
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestNewHandler(t *testing.T) {
|
||||||
|
log := zerolog.New(nil).With().
|
||||||
|
Str("foo", "bar").
|
||||||
|
Logger()
|
||||||
|
lh := NewHandler(log)
|
||||||
|
h := lh(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||||
|
l := FromRequest(r)
|
||||||
|
if !reflect.DeepEqual(*l, log) {
|
||||||
|
t.Fail()
|
||||||
|
}
|
||||||
|
}))
|
||||||
|
h.ServeHTTP(nil, &http.Request{})
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestURLHandler(t *testing.T) {
|
||||||
|
out := &bytes.Buffer{}
|
||||||
|
r := &http.Request{
|
||||||
|
URL: &url.URL{Path: "/path", RawQuery: "foo=bar"},
|
||||||
|
}
|
||||||
|
h := URLHandler("url")(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||||
|
l := FromRequest(r)
|
||||||
|
l.Log().Msg("")
|
||||||
|
}))
|
||||||
|
h = NewHandler(zerolog.New(out))(h)
|
||||||
|
h.ServeHTTP(nil, r)
|
||||||
|
if want, got := `{"url":"/path?foo=bar"}`+"\n", decodeIfBinary(out); want != got {
|
||||||
|
t.Errorf("Invalid log output, got: %s, want: %s", got, want)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestMethodHandler(t *testing.T) {
|
||||||
|
out := &bytes.Buffer{}
|
||||||
|
r := &http.Request{
|
||||||
|
Method: "POST",
|
||||||
|
}
|
||||||
|
h := MethodHandler("method")(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||||
|
l := FromRequest(r)
|
||||||
|
l.Log().Msg("")
|
||||||
|
}))
|
||||||
|
h = NewHandler(zerolog.New(out))(h)
|
||||||
|
h.ServeHTTP(nil, r)
|
||||||
|
if want, got := `{"method":"POST"}`+"\n", decodeIfBinary(out); want != got {
|
||||||
|
t.Errorf("Invalid log output, got: %s, want: %s", got, want)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestRequestHandler(t *testing.T) {
|
||||||
|
out := &bytes.Buffer{}
|
||||||
|
r := &http.Request{
|
||||||
|
Method: "POST",
|
||||||
|
URL: &url.URL{Path: "/path", RawQuery: "foo=bar"},
|
||||||
|
}
|
||||||
|
h := RequestHandler("request")(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||||
|
l := FromRequest(r)
|
||||||
|
l.Log().Msg("")
|
||||||
|
}))
|
||||||
|
h = NewHandler(zerolog.New(out))(h)
|
||||||
|
h.ServeHTTP(nil, r)
|
||||||
|
if want, got := `{"request":"POST /path?foo=bar"}`+"\n", decodeIfBinary(out); want != got {
|
||||||
|
t.Errorf("Invalid log output, got: %s, want: %s", got, want)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestRemoteAddrHandler(t *testing.T) {
|
||||||
|
out := &bytes.Buffer{}
|
||||||
|
r := &http.Request{
|
||||||
|
RemoteAddr: "1.2.3.4:1234",
|
||||||
|
}
|
||||||
|
h := RemoteAddrHandler("ip")(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||||
|
l := FromRequest(r)
|
||||||
|
l.Log().Msg("")
|
||||||
|
}))
|
||||||
|
h = NewHandler(zerolog.New(out))(h)
|
||||||
|
h.ServeHTTP(nil, r)
|
||||||
|
if want, got := `{"ip":"1.2.3.4"}`+"\n", decodeIfBinary(out); want != got {
|
||||||
|
t.Errorf("Invalid log output, got: %s, want: %s", got, want)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestRemoteAddrHandlerIPv6(t *testing.T) {
|
||||||
|
out := &bytes.Buffer{}
|
||||||
|
r := &http.Request{
|
||||||
|
RemoteAddr: "[2001:db8:a0b:12f0::1]:1234",
|
||||||
|
}
|
||||||
|
h := RemoteAddrHandler("ip")(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||||
|
l := FromRequest(r)
|
||||||
|
l.Log().Msg("")
|
||||||
|
}))
|
||||||
|
h = NewHandler(zerolog.New(out))(h)
|
||||||
|
h.ServeHTTP(nil, r)
|
||||||
|
if want, got := `{"ip":"2001:db8:a0b:12f0::1"}`+"\n", decodeIfBinary(out); want != got {
|
||||||
|
t.Errorf("Invalid log output, got: %s, want: %s", got, want)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestUserAgentHandler(t *testing.T) {
|
||||||
|
out := &bytes.Buffer{}
|
||||||
|
r := &http.Request{
|
||||||
|
Header: http.Header{
|
||||||
|
"User-Agent": []string{"some user agent string"},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
h := UserAgentHandler("ua")(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||||
|
l := FromRequest(r)
|
||||||
|
l.Log().Msg("")
|
||||||
|
}))
|
||||||
|
h = NewHandler(zerolog.New(out))(h)
|
||||||
|
h.ServeHTTP(nil, r)
|
||||||
|
if want, got := `{"ua":"some user agent string"}`+"\n", decodeIfBinary(out); want != got {
|
||||||
|
t.Errorf("Invalid log output, got: %s, want: %s", got, want)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestRefererHandler(t *testing.T) {
|
||||||
|
out := &bytes.Buffer{}
|
||||||
|
r := &http.Request{
|
||||||
|
Header: http.Header{
|
||||||
|
"Referer": []string{"http://foo.com/bar"},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
h := RefererHandler("referer")(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||||
|
l := FromRequest(r)
|
||||||
|
l.Log().Msg("")
|
||||||
|
}))
|
||||||
|
h = NewHandler(zerolog.New(out))(h)
|
||||||
|
h.ServeHTTP(nil, r)
|
||||||
|
if want, got := `{"referer":"http://foo.com/bar"}`+"\n", decodeIfBinary(out); want != got {
|
||||||
|
t.Errorf("Invalid log output, got: %s, want: %s", got, want)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestRequestIDHandler(t *testing.T) {
|
||||||
|
out := &bytes.Buffer{}
|
||||||
|
r := &http.Request{
|
||||||
|
Header: http.Header{
|
||||||
|
"Referer": []string{"http://foo.com/bar"},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
h := RequestIDHandler("id", "Request-Id")(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||||
|
id, ok := IDFromRequest(r)
|
||||||
|
if !ok {
|
||||||
|
t.Fatal("Missing id in request")
|
||||||
|
}
|
||||||
|
// if want, got := id.String(), w.Header().Get("Request-Id"); got != want {
|
||||||
|
// t.Errorf("Invalid Request-Id header, got: %s, want: %s", got, want)
|
||||||
|
// }
|
||||||
|
l := FromRequest(r)
|
||||||
|
l.Log().Msg("")
|
||||||
|
if want, got := fmt.Sprintf(`{"id":"%s"}`+"\n", id), decodeIfBinary(out); want != got {
|
||||||
|
t.Errorf("Invalid log output, got: %s, want: %s", got, want)
|
||||||
|
}
|
||||||
|
}))
|
||||||
|
h = NewHandler(zerolog.New(out))(h)
|
||||||
|
h.ServeHTTP(httptest.NewRecorder(), r)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestCombinedHandlers(t *testing.T) {
|
||||||
|
out := &bytes.Buffer{}
|
||||||
|
r := &http.Request{
|
||||||
|
Method: "POST",
|
||||||
|
URL: &url.URL{Path: "/path", RawQuery: "foo=bar"},
|
||||||
|
}
|
||||||
|
h := MethodHandler("method")(RequestHandler("request")(URLHandler("url")(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||||
|
l := FromRequest(r)
|
||||||
|
l.Log().Msg("")
|
||||||
|
}))))
|
||||||
|
h = NewHandler(zerolog.New(out))(h)
|
||||||
|
h.ServeHTTP(nil, r)
|
||||||
|
if want, got := `{"method":"POST","request":"POST /path?foo=bar","url":"/path?foo=bar"}`+"\n", decodeIfBinary(out); want != got {
|
||||||
|
t.Errorf("Invalid log output, got: %s, want: %s", got, want)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func BenchmarkHandlers(b *testing.B) {
|
||||||
|
r := &http.Request{
|
||||||
|
Method: "POST",
|
||||||
|
URL: &url.URL{Path: "/path", RawQuery: "foo=bar"},
|
||||||
|
}
|
||||||
|
h1 := URLHandler("url")(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||||
|
l := FromRequest(r)
|
||||||
|
l.Log().Msg("")
|
||||||
|
}))
|
||||||
|
h2 := MethodHandler("method")(RequestHandler("request")(h1))
|
||||||
|
handlers := map[string]http.Handler{
|
||||||
|
"Single": NewHandler(zerolog.New(ioutil.Discard))(h1),
|
||||||
|
"Combined": NewHandler(zerolog.New(ioutil.Discard))(h2),
|
||||||
|
"SingleDisabled": NewHandler(zerolog.New(ioutil.Discard).Level(zerolog.Disabled))(h1),
|
||||||
|
"CombinedDisabled": NewHandler(zerolog.New(ioutil.Discard).Level(zerolog.Disabled))(h2),
|
||||||
|
}
|
||||||
|
for name := range handlers {
|
||||||
|
h := handlers[name]
|
||||||
|
b.Run(name, func(b *testing.B) {
|
||||||
|
for i := 0; i < b.N; i++ {
|
||||||
|
h.ServeHTTP(nil, r)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func BenchmarkDataRace(b *testing.B) {
|
||||||
|
log := zerolog.New(nil).With().
|
||||||
|
Str("foo", "bar").
|
||||||
|
Logger()
|
||||||
|
lh := NewHandler(log)
|
||||||
|
h := lh(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||||
|
l := FromRequest(r)
|
||||||
|
l.UpdateContext(func(c zerolog.Context) zerolog.Context {
|
||||||
|
return c.Str("bar", "baz")
|
||||||
|
})
|
||||||
|
l.Log().Msg("")
|
||||||
|
}))
|
||||||
|
|
||||||
|
b.RunParallel(func(pb *testing.PB) {
|
||||||
|
for pb.Next() {
|
||||||
|
h.ServeHTTP(nil, &http.Request{})
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
|
@ -2,7 +2,7 @@
|
||||||
package log // import "github.com/pomerium/pomerium/internal/log"
|
package log // import "github.com/pomerium/pomerium/internal/log"
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"net/http"
|
"context"
|
||||||
"os"
|
"os"
|
||||||
|
|
||||||
"github.com/rs/zerolog"
|
"github.com/rs/zerolog"
|
||||||
|
@ -21,19 +21,6 @@ func With() zerolog.Context {
|
||||||
return Logger.With()
|
return Logger.With()
|
||||||
}
|
}
|
||||||
|
|
||||||
// WithRequest creates a child logger with the remote user added to its context.
|
|
||||||
func WithRequest(req *http.Request, function string) zerolog.Logger {
|
|
||||||
remoteUser := getRemoteAddr(req)
|
|
||||||
return Logger.With().
|
|
||||||
Str("function", function).
|
|
||||||
Str("req-remote-user", remoteUser).
|
|
||||||
Str("req-http-method", req.Method).
|
|
||||||
Str("req-host", req.Host).
|
|
||||||
Str("req-url", req.URL.String()).
|
|
||||||
// Str("req-user-agent", req.Header.Get("User-Agent")).
|
|
||||||
Logger()
|
|
||||||
}
|
|
||||||
|
|
||||||
// Level creates a child logger with the minimum accepted level set to level.
|
// Level creates a child logger with the minimum accepted level set to level.
|
||||||
func Level(level zerolog.Level) zerolog.Logger {
|
func Level(level zerolog.Level) zerolog.Logger {
|
||||||
return Logger.Level(level)
|
return Logger.Level(level)
|
||||||
|
@ -109,3 +96,9 @@ func Print(v ...interface{}) {
|
||||||
func Printf(format string, v ...interface{}) {
|
func Printf(format string, v ...interface{}) {
|
||||||
Logger.Printf(format, v...)
|
Logger.Printf(format, v...)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Ctx returns the Logger associated with the ctx. If no logger
|
||||||
|
// is associated, a disabled logger is returned.
|
||||||
|
func Ctx(ctx context.Context) *zerolog.Logger {
|
||||||
|
return zerolog.Ctx(ctx)
|
||||||
|
}
|
||||||
|
|
|
@ -1,145 +0,0 @@
|
||||||
package log // import "github.com/pomerium/pomerium/internal/log"
|
|
||||||
|
|
||||||
import (
|
|
||||||
"net/http"
|
|
||||||
"net/url"
|
|
||||||
"strings"
|
|
||||||
"time"
|
|
||||||
)
|
|
||||||
|
|
||||||
// Used to stash the authenticated user in the response for access when logging requests.
|
|
||||||
const loggingUserHeader = "SSO-Authenticated-User"
|
|
||||||
const gapMetaDataHeader = "GAP-Auth"
|
|
||||||
|
|
||||||
// responseLogger is wrapper of http.ResponseWriter that keeps track of its HTTP status
|
|
||||||
// code and body size
|
|
||||||
type responseLogger struct {
|
|
||||||
w http.ResponseWriter
|
|
||||||
status int
|
|
||||||
size int
|
|
||||||
proxyHost string
|
|
||||||
authInfo string
|
|
||||||
}
|
|
||||||
|
|
||||||
func (l *responseLogger) Header() http.Header {
|
|
||||||
return l.w.Header()
|
|
||||||
}
|
|
||||||
|
|
||||||
func (l *responseLogger) extractUser() {
|
|
||||||
authInfo := l.w.Header().Get(loggingUserHeader)
|
|
||||||
if authInfo != "" {
|
|
||||||
l.authInfo = authInfo
|
|
||||||
l.w.Header().Del(loggingUserHeader)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func (l *responseLogger) ExtractGAPMetadata() {
|
|
||||||
authInfo := l.w.Header().Get(gapMetaDataHeader)
|
|
||||||
if authInfo != "" {
|
|
||||||
l.authInfo = authInfo
|
|
||||||
|
|
||||||
l.w.Header().Del(gapMetaDataHeader)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func (l *responseLogger) Write(b []byte) (int, error) {
|
|
||||||
if l.status == 0 {
|
|
||||||
// The status will be StatusOK if WriteHeader has not been called yet
|
|
||||||
l.status = http.StatusOK
|
|
||||||
}
|
|
||||||
l.extractUser()
|
|
||||||
l.ExtractGAPMetadata()
|
|
||||||
|
|
||||||
size, err := l.w.Write(b)
|
|
||||||
l.size += size
|
|
||||||
return size, err
|
|
||||||
}
|
|
||||||
|
|
||||||
func (l *responseLogger) WriteHeader(s int) {
|
|
||||||
l.extractUser()
|
|
||||||
l.ExtractGAPMetadata()
|
|
||||||
|
|
||||||
l.w.WriteHeader(s)
|
|
||||||
l.status = s
|
|
||||||
}
|
|
||||||
|
|
||||||
func (l *responseLogger) Status() int {
|
|
||||||
return l.status
|
|
||||||
}
|
|
||||||
|
|
||||||
func (l *responseLogger) Size() int {
|
|
||||||
return l.size
|
|
||||||
}
|
|
||||||
|
|
||||||
func (l *responseLogger) Flush() {
|
|
||||||
f := l.w.(http.Flusher)
|
|
||||||
f.Flush()
|
|
||||||
}
|
|
||||||
|
|
||||||
// loggingHandler is the http.Handler implementation for LoggingHandlerTo and its friends
|
|
||||||
type loggingHandler struct {
|
|
||||||
handler http.Handler
|
|
||||||
}
|
|
||||||
|
|
||||||
// NewLoggingHandler returns a new loggingHandler that wraps a handler, and writer.
|
|
||||||
func NewLoggingHandler(h http.Handler) http.Handler {
|
|
||||||
return loggingHandler{
|
|
||||||
handler: h,
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func (h loggingHandler) ServeHTTP(w http.ResponseWriter, req *http.Request) {
|
|
||||||
t := time.Now()
|
|
||||||
url := *req.URL
|
|
||||||
logger := &responseLogger{w: w, proxyHost: getProxyHost(req)}
|
|
||||||
h.handler.ServeHTTP(logger, req)
|
|
||||||
requestDuration := time.Since(t)
|
|
||||||
|
|
||||||
logRequest(logger.proxyHost, logger.authInfo, req, url, requestDuration, logger.Status())
|
|
||||||
}
|
|
||||||
|
|
||||||
// logRequest logs information about a request
|
|
||||||
func logRequest(proxyHost, username string, req *http.Request, url url.URL, requestDuration time.Duration, status int) {
|
|
||||||
uri := req.Host + url.RequestURI()
|
|
||||||
Info().
|
|
||||||
Int("http-status", status).
|
|
||||||
Str("request-method", req.Method).
|
|
||||||
Str("request-uri", uri).
|
|
||||||
Str("proxy-host", proxyHost).
|
|
||||||
// Str("user-agent", req.Header.Get("User-Agent")).
|
|
||||||
Str("remote-address", getRemoteAddr(req)).
|
|
||||||
Dur("duration", requestDuration).
|
|
||||||
Str("user", username).
|
|
||||||
Msg("request")
|
|
||||||
|
|
||||||
}
|
|
||||||
|
|
||||||
// getRemoteAddr returns the client IP address from a request. If present, the
|
|
||||||
// X-Forwarded-For header is assumed to be set by a load balancer, and its
|
|
||||||
// rightmost entry (the client IP that connected to the LB) is returned.
|
|
||||||
func getRemoteAddr(req *http.Request) string {
|
|
||||||
addr := req.RemoteAddr
|
|
||||||
forwardedHeader := req.Header.Get("X-Forwarded-For")
|
|
||||||
if forwardedHeader != "" {
|
|
||||||
forwardedList := strings.Split(forwardedHeader, ",")
|
|
||||||
forwardedAddr := strings.TrimSpace(forwardedList[len(forwardedList)-1])
|
|
||||||
if forwardedAddr != "" {
|
|
||||||
addr = forwardedAddr
|
|
||||||
}
|
|
||||||
}
|
|
||||||
return addr
|
|
||||||
}
|
|
||||||
|
|
||||||
// getProxyHost attempts to get the proxy host from the redirect_uri parameter
|
|
||||||
func getProxyHost(req *http.Request) string {
|
|
||||||
err := req.ParseForm()
|
|
||||||
if err != nil {
|
|
||||||
return ""
|
|
||||||
}
|
|
||||||
redirect := req.Form.Get("redirect_uri")
|
|
||||||
redirectURL, err := url.Parse(redirect)
|
|
||||||
if err != nil {
|
|
||||||
return ""
|
|
||||||
}
|
|
||||||
return redirectURL.Host
|
|
||||||
}
|
|
|
@ -1,72 +0,0 @@
|
||||||
package log // import "github.com/pomerium/pomerium/internal/log"
|
|
||||||
|
|
||||||
import (
|
|
||||||
"net/http/httptest"
|
|
||||||
"testing"
|
|
||||||
)
|
|
||||||
|
|
||||||
func TestGetRemoteAddr(t *testing.T) {
|
|
||||||
testCases := []struct {
|
|
||||||
name string
|
|
||||||
remoteAddr string
|
|
||||||
forwardedHeader string
|
|
||||||
expectedAddr string
|
|
||||||
}{
|
|
||||||
{
|
|
||||||
name: "RemoteAddr used when no X-Forwarded-For header is given",
|
|
||||||
remoteAddr: "1.1.1.1",
|
|
||||||
expectedAddr: "1.1.1.1",
|
|
||||||
},
|
|
||||||
{
|
|
||||||
name: "RemoteAddr used when no X-Forwarded-For header is only whitespace",
|
|
||||||
remoteAddr: "1.1.1.1",
|
|
||||||
forwardedHeader: " ",
|
|
||||||
expectedAddr: "1.1.1.1",
|
|
||||||
},
|
|
||||||
{
|
|
||||||
name: "RemoteAddr used when no X-Forwarded-For header is only comma-separated whitespace",
|
|
||||||
remoteAddr: "1.1.1.1",
|
|
||||||
forwardedHeader: " , , ",
|
|
||||||
expectedAddr: "1.1.1.1",
|
|
||||||
},
|
|
||||||
{
|
|
||||||
name: "X-Forwarded-For header is preferred to RemoteAddr",
|
|
||||||
remoteAddr: "1.1.1.1",
|
|
||||||
forwardedHeader: "9.9.9.9",
|
|
||||||
expectedAddr: "9.9.9.9",
|
|
||||||
},
|
|
||||||
{
|
|
||||||
name: "rightmost entry in X-Forwarded-For header is used",
|
|
||||||
remoteAddr: "1.1.1.1",
|
|
||||||
forwardedHeader: "2.2.2.2, 3.3.3.3, 4.4.4.4.4, 5.5.5.5",
|
|
||||||
expectedAddr: "5.5.5.5",
|
|
||||||
},
|
|
||||||
{
|
|
||||||
name: "RemoteAddr is used if rightmost entry in X-Forwarded-For header is empty",
|
|
||||||
remoteAddr: "1.1.1.1",
|
|
||||||
forwardedHeader: "2.2.2.2, 3.3.3.3, ",
|
|
||||||
expectedAddr: "1.1.1.1",
|
|
||||||
},
|
|
||||||
{
|
|
||||||
name: "X-Forwaded-For header entries are stripped",
|
|
||||||
remoteAddr: "1.1.1.1",
|
|
||||||
forwardedHeader: " 2.2.2.2, 3.3.3.3, 4.4.4.4, 5.5.5.5 ",
|
|
||||||
expectedAddr: "5.5.5.5",
|
|
||||||
},
|
|
||||||
}
|
|
||||||
|
|
||||||
for _, tc := range testCases {
|
|
||||||
t.Run(tc.name, func(t *testing.T) {
|
|
||||||
req := httptest.NewRequest("GET", "/", nil)
|
|
||||||
req.RemoteAddr = tc.remoteAddr
|
|
||||||
if tc.forwardedHeader != "" {
|
|
||||||
req.Header.Set("X-Forwarded-For", tc.forwardedHeader)
|
|
||||||
}
|
|
||||||
|
|
||||||
addr := getRemoteAddr(req)
|
|
||||||
if addr != tc.expectedAddr {
|
|
||||||
t.Errorf("expected remote addr = %q, got %q", tc.expectedAddr, addr)
|
|
||||||
}
|
|
||||||
})
|
|
||||||
}
|
|
||||||
}
|
|
109
internal/middleware/chain.go
Normal file
109
internal/middleware/chain.go
Normal file
|
@ -0,0 +1,109 @@
|
||||||
|
package middleware // import "github.com/pomerium/pomerium/internal/middleware"
|
||||||
|
|
||||||
|
import "net/http"
|
||||||
|
|
||||||
|
// Constructor is a type alias for func(http.Handler) http.Handler
|
||||||
|
type Constructor func(http.Handler) http.Handler
|
||||||
|
|
||||||
|
// Chain acts as a list of http.Handler constructors.
|
||||||
|
// Chain is effectively immutable:
|
||||||
|
// once created, it will always hold
|
||||||
|
// the same set of constructors in the same order.
|
||||||
|
type Chain struct {
|
||||||
|
constructors []Constructor
|
||||||
|
}
|
||||||
|
|
||||||
|
// NewChain creates a new chain,
|
||||||
|
// memorizing the given list of middleware constructors.
|
||||||
|
// New serves no other function,
|
||||||
|
// constructors are only called upon a call to Then().
|
||||||
|
func NewChain(constructors ...Constructor) Chain {
|
||||||
|
return Chain{append(([]Constructor)(nil), constructors...)}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Then chains the middleware and returns the final http.Handler.
|
||||||
|
// NewChain(m1, m2, m3).Then(h)
|
||||||
|
// is equivalent to:
|
||||||
|
// m1(m2(m3(h)))
|
||||||
|
// When the request comes in, it will be passed to m1, then m2, then m3
|
||||||
|
// and finally, the given handler
|
||||||
|
// (assuming every middleware calls the following one).
|
||||||
|
//
|
||||||
|
// A chain can be safely reused by calling Then() several times.
|
||||||
|
// stdStack := middleware.NewChain(ratelimitHandler, csrfHandler)
|
||||||
|
// indexPipe = stdStack.Then(indexHandler)
|
||||||
|
// authPipe = stdStack.Then(authHandler)
|
||||||
|
// Note that constructors are called on every call to Then()
|
||||||
|
// and thus several instances of the same middleware will be created
|
||||||
|
// when a chain is reused in this way.
|
||||||
|
// For proper middleware, this should cause no problems.
|
||||||
|
//
|
||||||
|
// Then() treats nil as http.DefaultServeMux.
|
||||||
|
func (c Chain) Then(h http.Handler) http.Handler {
|
||||||
|
if h == nil {
|
||||||
|
h = http.DefaultServeMux
|
||||||
|
}
|
||||||
|
|
||||||
|
for i := range c.constructors {
|
||||||
|
h = c.constructors[len(c.constructors)-1-i](h)
|
||||||
|
}
|
||||||
|
|
||||||
|
return h
|
||||||
|
}
|
||||||
|
|
||||||
|
// ThenFunc works identically to Then, but takes
|
||||||
|
// a HandlerFunc instead of a Handler.
|
||||||
|
//
|
||||||
|
// The following two statements are equivalent:
|
||||||
|
// c.Then(http.HandlerFunc(fn))
|
||||||
|
// c.ThenFunc(fn)
|
||||||
|
//
|
||||||
|
// ThenFunc provides all the guarantees of Then.
|
||||||
|
func (c Chain) ThenFunc(fn http.HandlerFunc) http.Handler {
|
||||||
|
if fn == nil {
|
||||||
|
return c.Then(nil)
|
||||||
|
}
|
||||||
|
return c.Then(fn)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Append extends a chain, adding the specified constructors
|
||||||
|
// as the last ones in the request flow.
|
||||||
|
//
|
||||||
|
// Append returns a new chain, leaving the original one untouched.
|
||||||
|
//
|
||||||
|
// stdChain := middleware.NewChain(m1, m2)
|
||||||
|
// extChain := stdChain.Append(m3, m4)
|
||||||
|
// // requests in stdChain go m1 -> m2
|
||||||
|
// // requests in extChain go m1 -> m2 -> m3 -> m4
|
||||||
|
func (c Chain) Append(constructors ...Constructor) Chain {
|
||||||
|
newCons := make([]Constructor, 0, len(c.constructors)+len(constructors))
|
||||||
|
newCons = append(newCons, c.constructors...)
|
||||||
|
newCons = append(newCons, constructors...)
|
||||||
|
|
||||||
|
return Chain{newCons}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Extend extends a chain by adding the specified chain
|
||||||
|
// as the last one in the request flow.
|
||||||
|
//
|
||||||
|
// Extend returns a new chain, leaving the original one untouched.
|
||||||
|
//
|
||||||
|
// stdChain := middleware.NewChain(m1, m2)
|
||||||
|
// ext1Chain := middleware.NewChain(m3, m4)
|
||||||
|
// ext2Chain := stdChain.Extend(ext1Chain)
|
||||||
|
// // requests in stdChain go m1 -> m2
|
||||||
|
// // requests in ext1Chain go m3 -> m4
|
||||||
|
// // requests in ext2Chain go m1 -> m2 -> m3 -> m4
|
||||||
|
//
|
||||||
|
// Another example:
|
||||||
|
// aHtmlAfterNosurf := middleware.NewChain(m2)
|
||||||
|
// aHtml := middleware.NewChain(m1, func(h http.Handler) http.Handler {
|
||||||
|
// csrf := nosurf.NewChain(h)
|
||||||
|
// csrf.SetFailureHandler(aHtmlAfterNosurf.ThenFunc(csrfFail))
|
||||||
|
// return csrf
|
||||||
|
// }).Extend(aHtmlAfterNosurf)
|
||||||
|
// // requests to aHtml hitting nosurfs success handler go m1 -> nosurf -> m2 -> target-handler
|
||||||
|
// // requests to aHtml hitting nosurfs failure handler go m1 -> nosurf -> m2 -> csrfFail
|
||||||
|
func (c Chain) Extend(chain Chain) Chain {
|
||||||
|
return c.Append(chain.constructors...)
|
||||||
|
}
|
177
internal/middleware/chain_test.go
Normal file
177
internal/middleware/chain_test.go
Normal file
|
@ -0,0 +1,177 @@
|
||||||
|
package middleware // import "github.com/pomerium/pomerium/internal/middleware"
|
||||||
|
|
||||||
|
import (
|
||||||
|
"net/http"
|
||||||
|
"net/http/httptest"
|
||||||
|
"reflect"
|
||||||
|
"testing"
|
||||||
|
)
|
||||||
|
|
||||||
|
// A constructor for middleware
|
||||||
|
// that writes its own "tag" into the RW and does nothing else.
|
||||||
|
// Useful in checking if a chain is behaving in the right order.
|
||||||
|
func tagMiddleware(tag string) Constructor {
|
||||||
|
return func(h http.Handler) http.Handler {
|
||||||
|
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||||
|
w.Write([]byte(tag))
|
||||||
|
h.ServeHTTP(w, r)
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Not recommended (https://golang.org/pkg/reflect/#Value.Pointer),
|
||||||
|
// but the best we can do.
|
||||||
|
func funcsEqual(f1, f2 interface{}) bool {
|
||||||
|
val1 := reflect.ValueOf(f1)
|
||||||
|
val2 := reflect.ValueOf(f2)
|
||||||
|
return val1.Pointer() == val2.Pointer()
|
||||||
|
}
|
||||||
|
|
||||||
|
var testApp = http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||||
|
w.Write([]byte("app\n"))
|
||||||
|
})
|
||||||
|
|
||||||
|
func TestNew(t *testing.T) {
|
||||||
|
c1 := func(h http.Handler) http.Handler {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
c2 := func(h http.Handler) http.Handler {
|
||||||
|
return http.StripPrefix("potato", nil)
|
||||||
|
}
|
||||||
|
|
||||||
|
slice := []Constructor{c1, c2}
|
||||||
|
|
||||||
|
chain := NewChain(slice...)
|
||||||
|
for k := range slice {
|
||||||
|
if !funcsEqual(chain.constructors[k], slice[k]) {
|
||||||
|
t.Error("New does not add constructors correctly")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestThenWorksWithNoMiddleware(t *testing.T) {
|
||||||
|
if !funcsEqual(NewChain().Then(testApp), testApp) {
|
||||||
|
t.Error("Then does not work with no middleware")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestThenTreatsNilAsDefaultServeMux(t *testing.T) {
|
||||||
|
if NewChain().Then(nil) != http.DefaultServeMux {
|
||||||
|
t.Error("Then does not treat nil as DefaultServeMux")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestThenFuncTreatsNilAsDefaultServeMux(t *testing.T) {
|
||||||
|
if NewChain().ThenFunc(nil) != http.DefaultServeMux {
|
||||||
|
t.Error("ThenFunc does not treat nil as DefaultServeMux")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestThenFuncConstructsHandlerFunc(t *testing.T) {
|
||||||
|
fn := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||||
|
w.WriteHeader(200)
|
||||||
|
})
|
||||||
|
chained := NewChain().ThenFunc(fn)
|
||||||
|
rec := httptest.NewRecorder()
|
||||||
|
|
||||||
|
chained.ServeHTTP(rec, (*http.Request)(nil))
|
||||||
|
|
||||||
|
if reflect.TypeOf(chained) != reflect.TypeOf((http.HandlerFunc)(nil)) {
|
||||||
|
t.Error("ThenFunc does not construct HandlerFunc")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestThenOrdersHandlersCorrectly(t *testing.T) {
|
||||||
|
t1 := tagMiddleware("t1\n")
|
||||||
|
t2 := tagMiddleware("t2\n")
|
||||||
|
t3 := tagMiddleware("t3\n")
|
||||||
|
|
||||||
|
chained := NewChain(t1, t2, t3).Then(testApp)
|
||||||
|
|
||||||
|
w := httptest.NewRecorder()
|
||||||
|
r, err := http.NewRequest("GET", "/", nil)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
|
||||||
|
chained.ServeHTTP(w, r)
|
||||||
|
|
||||||
|
if w.Body.String() != "t1\nt2\nt3\napp\n" {
|
||||||
|
t.Error("Then does not order handlers correctly")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestAppendAddsHandlersCorrectly(t *testing.T) {
|
||||||
|
chain := NewChain(tagMiddleware("t1\n"), tagMiddleware("t2\n"))
|
||||||
|
newChain := chain.Append(tagMiddleware("t3\n"), tagMiddleware("t4\n"))
|
||||||
|
|
||||||
|
if len(chain.constructors) != 2 {
|
||||||
|
t.Error("chain should have 2 constructors")
|
||||||
|
}
|
||||||
|
if len(newChain.constructors) != 4 {
|
||||||
|
t.Error("newChain should have 4 constructors")
|
||||||
|
}
|
||||||
|
|
||||||
|
chained := newChain.Then(testApp)
|
||||||
|
|
||||||
|
w := httptest.NewRecorder()
|
||||||
|
r, err := http.NewRequest("GET", "/", nil)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
|
||||||
|
chained.ServeHTTP(w, r)
|
||||||
|
|
||||||
|
if w.Body.String() != "t1\nt2\nt3\nt4\napp\n" {
|
||||||
|
t.Error("Append does not add handlers correctly")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestAppendRespectsImmutability(t *testing.T) {
|
||||||
|
chain := NewChain(tagMiddleware(""))
|
||||||
|
newChain := chain.Append(tagMiddleware(""))
|
||||||
|
|
||||||
|
if &chain.constructors[0] == &newChain.constructors[0] {
|
||||||
|
t.Error("Apppend does not respect immutability")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestExtendAddsHandlersCorrectly(t *testing.T) {
|
||||||
|
chain1 := NewChain(tagMiddleware("t1\n"), tagMiddleware("t2\n"))
|
||||||
|
chain2 := NewChain(tagMiddleware("t3\n"), tagMiddleware("t4\n"))
|
||||||
|
newChain := chain1.Extend(chain2)
|
||||||
|
|
||||||
|
if len(chain1.constructors) != 2 {
|
||||||
|
t.Error("chain1 should contain 2 constructors")
|
||||||
|
}
|
||||||
|
if len(chain2.constructors) != 2 {
|
||||||
|
t.Error("chain2 should contain 2 constructors")
|
||||||
|
}
|
||||||
|
if len(newChain.constructors) != 4 {
|
||||||
|
t.Error("newChain should contain 4 constructors")
|
||||||
|
}
|
||||||
|
|
||||||
|
chained := newChain.Then(testApp)
|
||||||
|
|
||||||
|
w := httptest.NewRecorder()
|
||||||
|
r, err := http.NewRequest("GET", "/", nil)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
|
||||||
|
chained.ServeHTTP(w, r)
|
||||||
|
|
||||||
|
if w.Body.String() != "t1\nt2\nt3\nt4\napp\n" {
|
||||||
|
t.Error("Extend does not add handlers in correctly")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestExtendRespectsImmutability(t *testing.T) {
|
||||||
|
chain := NewChain(tagMiddleware(""))
|
||||||
|
newChain := chain.Extend(NewChain(tagMiddleware("")))
|
||||||
|
|
||||||
|
if &chain.constructors[0] == &newChain.constructors[0] {
|
||||||
|
t.Error("Extend does not respect immutability")
|
||||||
|
}
|
||||||
|
}
|
|
@ -1,3 +1,4 @@
|
||||||
|
// Package middleware provides a standard set of middleware implementations for pomerium.
|
||||||
package middleware // import "github.com/pomerium/pomerium/internal/middleware"
|
package middleware // import "github.com/pomerium/pomerium/internal/middleware"
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
@ -14,8 +15,8 @@ import (
|
||||||
"github.com/pomerium/pomerium/internal/httputil"
|
"github.com/pomerium/pomerium/internal/httputil"
|
||||||
)
|
)
|
||||||
|
|
||||||
// SetHeaders ensures that every response includes some basic security headers
|
// SetHeadersOld ensures that every response includes some basic security headers
|
||||||
func SetHeaders(h http.Handler, securityHeaders map[string]string) http.Handler {
|
func SetHeadersOld(h http.Handler, securityHeaders map[string]string) http.Handler {
|
||||||
return http.HandlerFunc(func(rw http.ResponseWriter, req *http.Request) {
|
return http.HandlerFunc(func(rw http.ResponseWriter, req *http.Request) {
|
||||||
for key, val := range securityHeaders {
|
for key, val := range securityHeaders {
|
||||||
rw.Header().Set(key, val)
|
rw.Header().Set(key, val)
|
||||||
|
@ -24,6 +25,18 @@ func SetHeaders(h http.Handler, securityHeaders map[string]string) http.Handler
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// SetHeaders ensures that every response includes some basic security headers
|
||||||
|
func SetHeaders(securityHeaders map[string]string) func(next http.Handler) http.Handler {
|
||||||
|
return func(next http.Handler) http.Handler {
|
||||||
|
return http.HandlerFunc(func(rw http.ResponseWriter, req *http.Request) {
|
||||||
|
for key, val := range securityHeaders {
|
||||||
|
rw.Header().Set(key, val)
|
||||||
|
}
|
||||||
|
next.ServeHTTP(rw, req)
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
// WithMethods writes an error response if the method of the request is not included.
|
// WithMethods writes an error response if the method of the request is not included.
|
||||||
func WithMethods(f http.HandlerFunc, methods ...string) http.HandlerFunc {
|
func WithMethods(f http.HandlerFunc, methods ...string) http.HandlerFunc {
|
||||||
methodMap := make(map[string]struct{}, len(methods))
|
methodMap := make(map[string]struct{}, len(methods))
|
||||||
|
@ -116,14 +129,17 @@ func ValidateSignature(f http.HandlerFunc, sharedSecret string) http.HandlerFunc
|
||||||
}
|
}
|
||||||
|
|
||||||
// ValidateHost ensures that each request's host is valid
|
// ValidateHost ensures that each request's host is valid
|
||||||
func ValidateHost(h http.Handler, mux map[string]*http.Handler) http.Handler {
|
func ValidateHost(mux map[string]*http.Handler) func(next http.Handler) http.Handler {
|
||||||
return http.HandlerFunc(func(rw http.ResponseWriter, req *http.Request) {
|
return func(next http.Handler) http.Handler {
|
||||||
if _, ok := mux[req.Host]; !ok {
|
|
||||||
httputil.ErrorResponse(rw, req, "Unknown host to route", http.StatusNotFound)
|
return http.HandlerFunc(func(rw http.ResponseWriter, req *http.Request) {
|
||||||
return
|
if _, ok := mux[req.Host]; !ok {
|
||||||
}
|
httputil.ErrorResponse(rw, req, "Unknown host to route", http.StatusNotFound)
|
||||||
h.ServeHTTP(rw, req)
|
return
|
||||||
})
|
}
|
||||||
|
next.ServeHTTP(rw, req)
|
||||||
|
})
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// RequireHTTPS reroutes a HTTP request to HTTPS
|
// RequireHTTPS reroutes a HTTP request to HTTPS
|
||||||
|
|
|
@ -1,12 +1,12 @@
|
||||||
package proxy // import "github.com/pomerium/pomerium/proxy"
|
package proxy // import "github.com/pomerium/pomerium/proxy"
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"encoding/json"
|
|
||||||
"errors"
|
"errors"
|
||||||
"fmt"
|
"fmt"
|
||||||
"net/http"
|
"net/http"
|
||||||
"net/url"
|
"net/url"
|
||||||
"reflect"
|
"reflect"
|
||||||
|
"time"
|
||||||
|
|
||||||
"github.com/pomerium/pomerium/internal/cryptutil"
|
"github.com/pomerium/pomerium/internal/cryptutil"
|
||||||
"github.com/pomerium/pomerium/internal/httputil"
|
"github.com/pomerium/pomerium/internal/httputil"
|
||||||
|
@ -16,8 +16,6 @@ import (
|
||||||
"github.com/pomerium/pomerium/internal/version"
|
"github.com/pomerium/pomerium/internal/version"
|
||||||
)
|
)
|
||||||
|
|
||||||
const loggingUserHeader = "SSO-Authenticated-User"
|
|
||||||
|
|
||||||
var (
|
var (
|
||||||
//ErrUserNotAuthorized is set when user is not authorized to access a resource
|
//ErrUserNotAuthorized is set when user is not authorized to access a resource
|
||||||
ErrUserNotAuthorized = errors.New("user not authorized")
|
ErrUserNotAuthorized = errors.New("user not authorized")
|
||||||
|
@ -45,92 +43,80 @@ func (p *Proxy) Handler() http.Handler {
|
||||||
var handler http.Handler = mux
|
var handler http.Handler = mux
|
||||||
// todo(bdd) : investigate if setting non-overridable headers makes sense
|
// todo(bdd) : investigate if setting non-overridable headers makes sense
|
||||||
// handler = p.setResponseHeaderOverrides(handler)
|
// handler = p.setResponseHeaderOverrides(handler)
|
||||||
handler = middleware.SetHeaders(handler, securityHeaders)
|
|
||||||
handler = middleware.ValidateHost(handler, p.mux)
|
|
||||||
handler = middleware.RequireHTTPS(handler)
|
|
||||||
handler = log.NewLoggingHandler(handler)
|
|
||||||
|
|
||||||
return http.HandlerFunc(func(rw http.ResponseWriter, req *http.Request) {
|
// Middleware chain
|
||||||
|
c := middleware.NewChain()
|
||||||
|
c = c.Append(log.NewHandler(log.Logger))
|
||||||
|
c = c.Append(log.AccessHandler(func(r *http.Request, status, size int, duration time.Duration) {
|
||||||
|
log.FromRequest(r).Info().
|
||||||
|
Str("method", r.Method).
|
||||||
|
Str("url", r.URL.String()).
|
||||||
|
Int("status", status).
|
||||||
|
Int("size", size).
|
||||||
|
Dur("duration", duration).
|
||||||
|
Str("pomerium-user", r.Header.Get(HeaderUserID)).
|
||||||
|
Str("pomerium-email", r.Header.Get(HeaderEmail)).
|
||||||
|
Msg("request")
|
||||||
|
}))
|
||||||
|
c = c.Append(middleware.SetHeaders(securityHeaders))
|
||||||
|
c = c.Append(middleware.RequireHTTPS)
|
||||||
|
c = c.Append(log.ForwardedAddrHandler("fwd_ip"))
|
||||||
|
c = c.Append(log.RemoteAddrHandler("ip"))
|
||||||
|
c = c.Append(log.UserAgentHandler("user_agent"))
|
||||||
|
c = c.Append(log.RefererHandler("referer"))
|
||||||
|
c = c.Append(log.RequestIDHandler("req_id", "Request-Id"))
|
||||||
|
c = c.Append(middleware.ValidateHost(p.mux))
|
||||||
|
h := c.Then(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||||
// Skip host validation for /ping requests because they hit the LB directly.
|
// Skip host validation for /ping requests because they hit the LB directly.
|
||||||
if req.URL.Path == "/ping" {
|
if r.URL.Path == "/ping" {
|
||||||
p.PingPage(rw, req)
|
p.PingPage(w, r)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
handler.ServeHTTP(rw, req)
|
handler.ServeHTTP(w, r)
|
||||||
})
|
}))
|
||||||
|
return h
|
||||||
}
|
}
|
||||||
|
|
||||||
// RobotsTxt sets the User-Agent header in the response to be "Disallow"
|
// RobotsTxt sets the User-Agent header in the response to be "Disallow"
|
||||||
func (p *Proxy) RobotsTxt(rw http.ResponseWriter, _ *http.Request) {
|
func (p *Proxy) RobotsTxt(w http.ResponseWriter, _ *http.Request) {
|
||||||
rw.WriteHeader(http.StatusOK)
|
w.WriteHeader(http.StatusOK)
|
||||||
fmt.Fprintf(rw, "User-agent: *\nDisallow: /")
|
fmt.Fprintf(w, "User-agent: *\nDisallow: /")
|
||||||
}
|
}
|
||||||
|
|
||||||
// Favicon will proxy the request as usual if the user is already authenticated
|
// Favicon will proxy the request as usual if the user is already authenticated
|
||||||
// but responds with a 404 otherwise, to avoid spurious and confusing
|
// but responds with a 404 otherwise, to avoid spurious and confusing
|
||||||
// authentication attempts when a browser automatically requests the favicon on
|
// authentication attempts when a browser automatically requests the favicon on
|
||||||
// an error page.
|
// an error page.
|
||||||
func (p *Proxy) Favicon(rw http.ResponseWriter, req *http.Request) {
|
func (p *Proxy) Favicon(w http.ResponseWriter, r *http.Request) {
|
||||||
err := p.Authenticate(rw, req)
|
err := p.Authenticate(w, r)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
rw.WriteHeader(http.StatusNotFound)
|
w.WriteHeader(http.StatusNotFound)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
p.Proxy(rw, req)
|
p.Proxy(w, r)
|
||||||
}
|
}
|
||||||
|
|
||||||
// PingPage send back a 200 OK response.
|
// PingPage send back a 200 OK response.
|
||||||
func (p *Proxy) PingPage(rw http.ResponseWriter, _ *http.Request) {
|
func (p *Proxy) PingPage(w http.ResponseWriter, _ *http.Request) {
|
||||||
rw.WriteHeader(http.StatusOK)
|
w.WriteHeader(http.StatusOK)
|
||||||
fmt.Fprintf(rw, "OK")
|
fmt.Fprintf(w, "OK")
|
||||||
}
|
}
|
||||||
|
|
||||||
// SignOut redirects the request to the sign out url.
|
// SignOut redirects the request to the sign out url.
|
||||||
func (p *Proxy) SignOut(rw http.ResponseWriter, req *http.Request) {
|
func (p *Proxy) SignOut(w http.ResponseWriter, r *http.Request) {
|
||||||
p.sessionStore.ClearSession(rw, req)
|
p.sessionStore.ClearSession(w, r)
|
||||||
|
|
||||||
redirectURL := &url.URL{
|
redirectURL := &url.URL{
|
||||||
Scheme: "https",
|
Scheme: "https",
|
||||||
Host: req.Host,
|
Host: r.Host,
|
||||||
Path: "/",
|
Path: "/",
|
||||||
}
|
}
|
||||||
fullURL := p.authenticateClient.GetSignOutURL(redirectURL)
|
fullURL := p.authenticateClient.GetSignOutURL(redirectURL)
|
||||||
http.Redirect(rw, req, fullURL.String(), http.StatusFound)
|
http.Redirect(w, r, fullURL.String(), http.StatusFound)
|
||||||
}
|
|
||||||
|
|
||||||
// XHRError returns a simple error response with an error message to the application if the request is an XML request
|
|
||||||
func (p *Proxy) XHRError(rw http.ResponseWriter, req *http.Request, code int, err error) {
|
|
||||||
jsonError := struct {
|
|
||||||
Error error `json:"error"`
|
|
||||||
}{
|
|
||||||
Error: err,
|
|
||||||
}
|
|
||||||
|
|
||||||
jsonBytes, err := json.Marshal(jsonError)
|
|
||||||
if err != nil {
|
|
||||||
rw.WriteHeader(http.StatusInternalServerError)
|
|
||||||
return
|
|
||||||
}
|
|
||||||
requestLog := log.WithRequest(req, "proxy.ErrorPage")
|
|
||||||
requestLog.Error().Err(err).Int("http-status", code).Msg("proxy.XHRError")
|
|
||||||
rw.Header().Set("Content-Type", "application/json")
|
|
||||||
rw.WriteHeader(code)
|
|
||||||
rw.Write(jsonBytes)
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// ErrorPage renders an error page with a given status code, title, and message.
|
// ErrorPage renders an error page with a given status code, title, and message.
|
||||||
func (p *Proxy) ErrorPage(rw http.ResponseWriter, req *http.Request, code int, title string, message string) {
|
func (p *Proxy) ErrorPage(w http.ResponseWriter, r *http.Request, code int, title string, message string) {
|
||||||
if p.isXHR(req) {
|
w.WriteHeader(code)
|
||||||
p.XHRError(rw, req, code, errors.New(message))
|
|
||||||
return
|
|
||||||
}
|
|
||||||
requestLog := log.WithRequest(req, "proxy.ErrorPage")
|
|
||||||
requestLog.Info().
|
|
||||||
Str("page-title", title).
|
|
||||||
Str("page-message", message).
|
|
||||||
Msg("proxy.ErrorPage")
|
|
||||||
|
|
||||||
rw.WriteHeader(code)
|
|
||||||
t := struct {
|
t := struct {
|
||||||
Code int
|
Code int
|
||||||
Title string
|
Title string
|
||||||
|
@ -142,223 +128,202 @@ func (p *Proxy) ErrorPage(rw http.ResponseWriter, req *http.Request, code int, t
|
||||||
Message: message,
|
Message: message,
|
||||||
Version: version.FullVersion(),
|
Version: version.FullVersion(),
|
||||||
}
|
}
|
||||||
p.templates.ExecuteTemplate(rw, "error.html", t)
|
p.templates.ExecuteTemplate(w, "error.html", t)
|
||||||
}
|
|
||||||
|
|
||||||
func (p *Proxy) isXHR(req *http.Request) bool {
|
|
||||||
return req.Header.Get("X-Requested-With") == "XMLHttpRequest"
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// OAuthStart begins the authentication flow, encrypting the redirect url
|
// OAuthStart begins the authentication flow, encrypting the redirect url
|
||||||
// in a request to the provider's sign in endpoint.
|
// in a request to the provider's sign in endpoint.
|
||||||
func (p *Proxy) OAuthStart(rw http.ResponseWriter, req *http.Request) {
|
func (p *Proxy) OAuthStart(w http.ResponseWriter, r *http.Request) {
|
||||||
// The proxy redirects to the authenticator, and provides it with redirectURI (which points
|
// The proxy redirects to the authenticator, and provides it with redirectURI (which points
|
||||||
// back to the sso proxy).
|
// back to the sso proxy).
|
||||||
requestLog := log.WithRequest(req, "proxy.OAuthStart")
|
requestURI := r.URL.String()
|
||||||
|
callbackURL := p.GetRedirectURL(r.Host)
|
||||||
if p.isXHR(req) {
|
|
||||||
e := errors.New("cannot continue oauth flow on xhr")
|
|
||||||
requestLog.Error().Err(e).Msg("isXHR")
|
|
||||||
p.XHRError(rw, req, http.StatusUnauthorized, e)
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
requestURI := req.URL.String()
|
|
||||||
callbackURL := p.GetRedirectURL(req.Host)
|
|
||||||
|
|
||||||
// generate nonce
|
|
||||||
key := cryptutil.GenerateKey()
|
|
||||||
|
|
||||||
// state prevents cross site forgery and maintain state across the client and server
|
// state prevents cross site forgery and maintain state across the client and server
|
||||||
state := &StateParameter{
|
state := &StateParameter{
|
||||||
SessionID: fmt.Sprintf("%x", key), // nonce
|
SessionID: fmt.Sprintf("%x", cryptutil.GenerateKey()), // nonce
|
||||||
RedirectURI: requestURI, // where to redirect the user back to
|
RedirectURI: requestURI, // where to redirect the user back to
|
||||||
}
|
}
|
||||||
|
|
||||||
// we encrypt this value to be opaque the browser cookie
|
// we encrypt this value to be opaque the browser cookie
|
||||||
// this value will be unique since we always use a randomized nonce as part of marshaling
|
// this value will be unique since we always use a randomized nonce as part of marshaling
|
||||||
encryptedCSRF, err := p.cipher.Marshal(state)
|
encryptedCSRF, err := p.cipher.Marshal(state)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
requestLog.Error().Err(err).Msg("failed to marshal csrf")
|
log.FromRequest(r).Error().Err(err).Msg("failed to marshal csrf")
|
||||||
p.ErrorPage(rw, req, http.StatusInternalServerError, "Internal Error", err.Error())
|
p.ErrorPage(w, r, http.StatusInternalServerError, "Internal Error", err.Error())
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
p.csrfStore.SetCSRF(rw, req, encryptedCSRF)
|
p.csrfStore.SetCSRF(w, r, encryptedCSRF)
|
||||||
|
|
||||||
// we encrypt this value to be opaque the uri query value
|
// we encrypt this value to be opaque the uri query value
|
||||||
// this value will be unique since we always use a randomized nonce as part of marshaling
|
// this value will be unique since we always use a randomized nonce as part of marshaling
|
||||||
encryptedState, err := p.cipher.Marshal(state)
|
encryptedState, err := p.cipher.Marshal(state)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
requestLog.Error().Err(err).Msg("failed to encrypt cookie")
|
log.FromRequest(r).Error().Err(err).Msg("failed to encrypt cookie")
|
||||||
p.ErrorPage(rw, req, http.StatusInternalServerError, "Internal Error", err.Error())
|
p.ErrorPage(w, r, http.StatusInternalServerError, "Internal Error", err.Error())
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
signinURL := p.authenticateClient.GetSignInURL(callbackURL, encryptedState)
|
signinURL := p.authenticateClient.GetSignInURL(callbackURL, encryptedState)
|
||||||
requestLog.Info().Msg("redirecting to begin auth flow")
|
log.FromRequest(r).Info().Msg("redirecting to begin auth flow")
|
||||||
http.Redirect(rw, req, signinURL.String(), http.StatusFound)
|
http.Redirect(w, r, signinURL.String(), http.StatusFound)
|
||||||
}
|
}
|
||||||
|
|
||||||
// OAuthCallback validates the cookie sent back from the provider, then validates
|
// OAuthCallback validates the cookie sent back from the provider, then validates
|
||||||
// the user information, and if authorized, redirects the user back to the original
|
// the user information, and if authorized, redirects the user back to the original
|
||||||
// application.
|
// application.
|
||||||
func (p *Proxy) OAuthCallback(rw http.ResponseWriter, req *http.Request) {
|
func (p *Proxy) OAuthCallback(w http.ResponseWriter, r *http.Request) {
|
||||||
// We receive the callback from the SSO Authenticator. This request will either contain an
|
// We receive the callback from the SSO Authenticator. This request will either contain an
|
||||||
// error, or it will contain a `code`; the code can be used to fetch an access token, and
|
// error, or it will contain a `code`; the code can be used to fetch an access token, and
|
||||||
// other metadata, from the authenticator.
|
// other metadata, from the authenticator.
|
||||||
requestLog := log.WithRequest(req, "proxy.OAuthCallback")
|
|
||||||
// finish the oauth cycle
|
// finish the oauth cycle
|
||||||
err := req.ParseForm()
|
err := r.ParseForm()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
requestLog.Error().Err(err).Msg("failed parsing request form")
|
log.FromRequest(r).Error().Err(err).Msg("failed parsing request form")
|
||||||
p.ErrorPage(rw, req, http.StatusInternalServerError, "Internal Error", err.Error())
|
p.ErrorPage(w, r, http.StatusInternalServerError, "Internal Error", err.Error())
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
errorString := req.Form.Get("error")
|
errorString := r.Form.Get("error")
|
||||||
if errorString != "" {
|
if errorString != "" {
|
||||||
p.ErrorPage(rw, req, http.StatusForbidden, "Permission Denied", errorString)
|
p.ErrorPage(w, r, http.StatusForbidden, "Permission Denied", errorString)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
// We begin the process of redeeming the code for an access token.
|
// We begin the process of redeeming the code for an access token.
|
||||||
session, err := p.redeemCode(req.Host, req.Form.Get("code"))
|
session, err := p.redeemCode(r.Host, r.Form.Get("code"))
|
||||||
if err != nil {
|
if err != nil {
|
||||||
requestLog.Error().Err(err).Msg("error redeeming authorization code")
|
log.FromRequest(r).Error().Err(err).Msg("error redeeming authorization code")
|
||||||
p.ErrorPage(rw, req, http.StatusInternalServerError, "Internal Error", "Internal Error")
|
p.ErrorPage(w, r, http.StatusInternalServerError, "Internal Error", "Internal Error")
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
encryptedState := req.Form.Get("state")
|
encryptedState := r.Form.Get("state")
|
||||||
stateParameter := &StateParameter{}
|
stateParameter := &StateParameter{}
|
||||||
err = p.cipher.Unmarshal(encryptedState, stateParameter)
|
err = p.cipher.Unmarshal(encryptedState, stateParameter)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
requestLog.Error().Err(err).Msg("could not unmarshal state")
|
log.FromRequest(r).Error().Err(err).Msg("could not unmarshal state")
|
||||||
p.ErrorPage(rw, req, http.StatusInternalServerError, "Internal Error", "Internal Error")
|
p.ErrorPage(w, r, http.StatusInternalServerError, "Internal Error", "Internal Error")
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
c, err := p.csrfStore.GetCSRF(req)
|
c, err := p.csrfStore.GetCSRF(r)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
requestLog.Error().Err(err).Msg("failed parsing csrf cookie")
|
log.FromRequest(r).Error().Err(err).Msg("failed parsing csrf cookie")
|
||||||
p.ErrorPage(rw, req, http.StatusBadRequest, "Bad Request", err.Error())
|
p.ErrorPage(w, r, http.StatusBadRequest, "Bad Request", err.Error())
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
p.csrfStore.ClearCSRF(rw, req)
|
p.csrfStore.ClearCSRF(w, r)
|
||||||
|
|
||||||
encryptedCSRF := c.Value
|
encryptedCSRF := c.Value
|
||||||
csrfParameter := &StateParameter{}
|
csrfParameter := &StateParameter{}
|
||||||
err = p.cipher.Unmarshal(encryptedCSRF, csrfParameter)
|
err = p.cipher.Unmarshal(encryptedCSRF, csrfParameter)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
requestLog.Error().Err(err).Msg("couldn't unmarshal CSRF")
|
log.FromRequest(r).Error().Err(err).Msg("couldn't unmarshal CSRF")
|
||||||
p.ErrorPage(rw, req, http.StatusInternalServerError, "Internal Error", "Internal Error")
|
p.ErrorPage(w, r, http.StatusInternalServerError, "Internal Error", "Internal Error")
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
if encryptedState == encryptedCSRF {
|
if encryptedState == encryptedCSRF {
|
||||||
requestLog.Error().Msg("encrypted state and CSRF should not be equal")
|
log.FromRequest(r).Error().Msg("encrypted state and CSRF should not be equal")
|
||||||
p.ErrorPage(rw, req, http.StatusBadRequest, "Bad Request", "Bad Request")
|
p.ErrorPage(w, r, http.StatusBadRequest, "Bad Request", "Bad Request")
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
if !reflect.DeepEqual(stateParameter, csrfParameter) {
|
if !reflect.DeepEqual(stateParameter, csrfParameter) {
|
||||||
requestLog.Error().Msg("state and CSRF should be equal")
|
log.FromRequest(r).Error().Msg("state and CSRF should be equal")
|
||||||
p.ErrorPage(rw, req, http.StatusBadRequest, "Bad Request", "Bad Request")
|
p.ErrorPage(w, r, http.StatusBadRequest, "Bad Request", "Bad Request")
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
// We store the session in a cookie and redirect the user back to the application
|
// We store the session in a cookie and redirect the user back to the application
|
||||||
err = p.sessionStore.SaveSession(rw, req, session)
|
err = p.sessionStore.SaveSession(w, r, session)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
requestLog.Error().Msg("error saving session")
|
log.FromRequest(r).Error().Msg("error saving session")
|
||||||
p.ErrorPage(rw, req, http.StatusInternalServerError, "Internal Error", "Internal Error")
|
p.ErrorPage(w, r, http.StatusInternalServerError, "Internal Error", "Internal Error")
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
// This is the redirect back to the original requested application
|
// This is the redirect back to the original requested application
|
||||||
http.Redirect(rw, req, stateParameter.RedirectURI, http.StatusFound)
|
http.Redirect(w, r, stateParameter.RedirectURI, http.StatusFound)
|
||||||
}
|
}
|
||||||
|
|
||||||
// AuthenticateOnly calls the Authenticate handler.
|
// AuthenticateOnly calls the Authenticate handler.
|
||||||
func (p *Proxy) AuthenticateOnly(rw http.ResponseWriter, req *http.Request) {
|
func (p *Proxy) AuthenticateOnly(w http.ResponseWriter, r *http.Request) {
|
||||||
err := p.Authenticate(rw, req)
|
err := p.Authenticate(w, r)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
http.Error(rw, "unauthorized request", http.StatusUnauthorized)
|
http.Error(w, "unauthorized request", http.StatusUnauthorized)
|
||||||
}
|
}
|
||||||
rw.WriteHeader(http.StatusAccepted)
|
w.WriteHeader(http.StatusAccepted)
|
||||||
}
|
}
|
||||||
|
|
||||||
// Proxy authenticates a request, either proxying the request if it is authenticated, or starting the authentication process if not.
|
// Proxy authenticates a request, either proxying the request if it is authenticated,
|
||||||
func (p *Proxy) Proxy(rw http.ResponseWriter, req *http.Request) {
|
// or starting the authentication process if not.
|
||||||
|
func (p *Proxy) Proxy(w http.ResponseWriter, r *http.Request) {
|
||||||
// Attempts to validate the user and their cookie.
|
// Attempts to validate the user and their cookie.
|
||||||
// start := time.Now()
|
|
||||||
var err error
|
var err error
|
||||||
err = p.Authenticate(rw, req)
|
err = p.Authenticate(w, r)
|
||||||
// If the authentication is not successful we proceed to start the OAuth Flow with
|
// If the authentication is not successful we proceed to start the OAuth Flow with
|
||||||
// OAuthStart. If authentication is successful, we proceed to proxy to the configured
|
// OAuthStart. If successful, we proceed to proxy to the configured upstream.
|
||||||
// upstream.
|
|
||||||
requestLog := log.WithRequest(req, "proxy.Proxy")
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
switch err {
|
switch err {
|
||||||
case http.ErrNoCookie:
|
case http.ErrNoCookie:
|
||||||
// No cookie is set, start the oauth flow
|
// No cookie is set, start the oauth flow
|
||||||
p.OAuthStart(rw, req)
|
p.OAuthStart(w, r)
|
||||||
return
|
return
|
||||||
case ErrUserNotAuthorized:
|
case ErrUserNotAuthorized:
|
||||||
// We know the user is not authorized for the request, we show them a forbidden page
|
// We know the user is not authorized for the request, we show them a forbidden page
|
||||||
p.ErrorPage(rw, req, http.StatusForbidden, "Forbidden", "You're not authorized to view this page")
|
p.ErrorPage(w, r, http.StatusForbidden, "Forbidden", "You're not authorized to view this page")
|
||||||
return
|
return
|
||||||
case sessions.ErrLifetimeExpired:
|
case sessions.ErrLifetimeExpired:
|
||||||
// User's lifetime expired, we trigger the start of the oauth flow
|
// User's lifetime expired, we trigger the start of the oauth flow
|
||||||
p.OAuthStart(rw, req)
|
p.OAuthStart(w, r)
|
||||||
return
|
return
|
||||||
case sessions.ErrInvalidSession:
|
case sessions.ErrInvalidSession:
|
||||||
// The user session is invalid and we can't decode it.
|
// The user session is invalid and we can't decode it.
|
||||||
// This can happen for a variety of reasons but the most common non-malicious
|
// This can happen for a variety of reasons but the most common non-malicious
|
||||||
// case occurs when the session encoding schema changes. We manage this ux
|
// case occurs when the session encoding schema changes. We manage this ux
|
||||||
// by triggering the start of the oauth flow.
|
// by triggering the start of the oauth flow.
|
||||||
p.OAuthStart(rw, req)
|
p.OAuthStart(w, r)
|
||||||
return
|
return
|
||||||
default:
|
default:
|
||||||
requestLog.Error().Err(err).Msg("unknown error")
|
log.FromRequest(r).Error().Err(err).Msg("unknown error")
|
||||||
// We don't know exactly what happened, but authenticating the user failed, show an error
|
// We don't know exactly what happened, but authenticating the user failed, show an error
|
||||||
p.ErrorPage(rw, req, http.StatusInternalServerError, "Internal Error", "An unexpected error occurred")
|
p.ErrorPage(w, r, http.StatusInternalServerError, "Internal Error", "An unexpected error occurred")
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// We have validated the users request and now proxy their request to the provided upstream.
|
// We have validated the users request and now proxy their request to the provided upstream.
|
||||||
route, ok := p.router(req)
|
route, ok := p.router(r)
|
||||||
if !ok {
|
if !ok {
|
||||||
httputil.ErrorResponse(rw, req, "Unknown host to route", http.StatusNotFound)
|
httputil.ErrorResponse(w, r, "unknown route to proxy", http.StatusNotFound)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
route.ServeHTTP(rw, req)
|
route.ServeHTTP(w, r)
|
||||||
}
|
}
|
||||||
|
|
||||||
// Authenticate authenticates a request by checking for a session cookie, and validating its expiration,
|
// Authenticate authenticates a request by checking for a session cookie, and validating its expiration,
|
||||||
// clearing the session cookie if it's invalid and returning an error if necessary..
|
// clearing the session cookie if it's invalid and returning an error if necessary..
|
||||||
func (p *Proxy) Authenticate(rw http.ResponseWriter, req *http.Request) (err error) {
|
func (p *Proxy) Authenticate(w http.ResponseWriter, r *http.Request) (err error) {
|
||||||
|
|
||||||
// Clear the session cookie if anything goes wrong.
|
// Clear the session cookie if anything goes wrong.
|
||||||
defer func() {
|
defer func() {
|
||||||
if err != nil {
|
if err != nil {
|
||||||
p.sessionStore.ClearSession(rw, req)
|
p.sessionStore.ClearSession(w, r)
|
||||||
}
|
}
|
||||||
}()
|
}()
|
||||||
requestLog := log.WithRequest(req, "proxy.Authenticate")
|
|
||||||
|
|
||||||
session, err := p.sessionStore.LoadSession(req)
|
session, err := p.sessionStore.LoadSession(r)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
// We loaded a cookie but it wasn't valid, clear it, and reject the request
|
// We loaded a cookie but it wasn't valid, clear it, and reject the request
|
||||||
requestLog.Error().Err(err).Msg("error authenticating user")
|
log.FromRequest(r).Error().Err(err).Msg("error authenticating user")
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
// Lifetime period is the entire duration in which the session is valid.
|
// Lifetime period is the entire duration in which the session is valid.
|
||||||
// This should be set to something like 14 to 30 days.
|
// This should be set to something like 14 to 30 days.
|
||||||
if session.LifetimePeriodExpired() {
|
if session.LifetimePeriodExpired() {
|
||||||
requestLog.Warn().Str("user", session.Email).Msg("session lifetime has expired")
|
log.FromRequest(r).Warn().Str("user", session.Email).Msg("session lifetime has expired")
|
||||||
return sessions.ErrLifetimeExpired
|
return sessions.ErrLifetimeExpired
|
||||||
} else if session.RefreshPeriodExpired() {
|
} else if session.RefreshPeriodExpired() {
|
||||||
// Refresh period is the period in which the access token is valid. This is ultimately
|
// Refresh period is the period in which the access token is valid. This is ultimately
|
||||||
|
@ -368,24 +333,24 @@ func (p *Proxy) Authenticate(rw http.ResponseWriter, req *http.Request) (err err
|
||||||
// We failed to refresh the session successfully
|
// We failed to refresh the session successfully
|
||||||
// clear the cookie and reject the request
|
// clear the cookie and reject the request
|
||||||
if err != nil {
|
if err != nil {
|
||||||
requestLog.Error().Err(err).Str("user", session.Email).Msg("refreshing session failed")
|
log.FromRequest(r).Error().Err(err).Str("user", session.Email).Msg("refreshing session failed")
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
if !ok {
|
if !ok {
|
||||||
// User is not authorized after refresh
|
// User is not authorized after refresh
|
||||||
// clear the cookie and reject the request
|
// clear the cookie and reject the request
|
||||||
requestLog.Error().Str("user", session.Email).Msg("not authorized after refreshing session")
|
log.FromRequest(r).Error().Str("user", session.Email).Msg("not authorized after refreshing session")
|
||||||
return ErrUserNotAuthorized
|
return ErrUserNotAuthorized
|
||||||
}
|
}
|
||||||
|
|
||||||
err = p.sessionStore.SaveSession(rw, req, session)
|
err = p.sessionStore.SaveSession(w, r, session)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
// We refreshed the session successfully, but failed to save it.
|
// We refreshed the session successfully, but failed to save it.
|
||||||
//
|
//
|
||||||
// This could be from failing to encode the session properly.
|
// This could be from failing to encode the session properly.
|
||||||
// But, we clear the session cookie and reject the request!
|
// But, we clear the session cookie and reject the request!
|
||||||
requestLog.Error().Err(err).Str("user", session.Email).Msg("could not save refresh session")
|
log.FromRequest(r).Error().Err(err).Str("user", session.Email).Msg("could not save refresh session")
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
} else if session.ValidationPeriodExpired() {
|
} else if session.ValidationPeriodExpired() {
|
||||||
|
@ -398,38 +363,23 @@ func (p *Proxy) Authenticate(rw http.ResponseWriter, req *http.Request) (err err
|
||||||
// This user is now no longer authorized, or we failed to
|
// This user is now no longer authorized, or we failed to
|
||||||
// validate the user.
|
// validate the user.
|
||||||
// Clear the cookie and reject the request
|
// Clear the cookie and reject the request
|
||||||
requestLog.Error().Str("user", session.Email).Msg("no longer authorized after validation period")
|
log.FromRequest(r).Error().Str("user", session.Email).Msg("no longer authorized after validation period")
|
||||||
return ErrUserNotAuthorized
|
return ErrUserNotAuthorized
|
||||||
}
|
}
|
||||||
|
|
||||||
err = p.sessionStore.SaveSession(rw, req, session)
|
err = p.sessionStore.SaveSession(w, r, session)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
// We validated the session successfully, but failed to save it.
|
// We validated the session successfully, but failed to save it.
|
||||||
|
|
||||||
// This could be from failing to encode the session properly.
|
// This could be from failing to encode the session properly.
|
||||||
// But, we clear the session cookie and reject the request!
|
// But, we clear the session cookie and reject the request!
|
||||||
requestLog.Error().Err(err).Str("user", session.Email).Msg("could not save validated session")
|
log.FromRequest(r).Error().Err(err).Str("user", session.Email).Msg("could not save validated session")
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// if !p.EmailValidator(session.Email) {
|
r.Header.Set(HeaderUserID, session.User)
|
||||||
// requestLog.Error().Str("user", session.Email).Msg("email failed to validate, unauthorized")
|
r.Header.Set(HeaderEmail, session.Email)
|
||||||
// return ErrUserNotAuthorized
|
|
||||||
// }
|
|
||||||
//
|
|
||||||
// todo(bdd) : handled by authorize package
|
|
||||||
|
|
||||||
req.Header.Set("X-Forwarded-User", session.User)
|
|
||||||
|
|
||||||
if p.PassAccessToken && session.AccessToken != "" {
|
|
||||||
req.Header.Set("X-Forwarded-Access-Token", session.AccessToken)
|
|
||||||
}
|
|
||||||
|
|
||||||
req.Header.Set("X-Forwarded-Email", session.Email)
|
|
||||||
|
|
||||||
// stash authenticated user so that it can be logged later (see func logRequest)
|
|
||||||
rw.Header().Set(loggingUserHeader, session.Email)
|
|
||||||
|
|
||||||
// This user has been OK'd. Allow the request!
|
// This user has been OK'd. Allow the request!
|
||||||
return nil
|
return nil
|
||||||
|
@ -442,13 +392,12 @@ func (p *Proxy) Handle(host string, handler http.Handler) {
|
||||||
|
|
||||||
// router attempts to find a route for a request. If a route is successfully matched,
|
// router attempts to find a route for a request. If a route is successfully matched,
|
||||||
// it returns the route information and a bool value of `true`. If a route can not be matched,
|
// it returns the route information and a bool value of `true`. If a route can not be matched,
|
||||||
//a nil value for the route and false bool value is returned.
|
// a nil value for the route and false bool value is returned.
|
||||||
func (p *Proxy) router(req *http.Request) (http.Handler, bool) {
|
func (p *Proxy) router(r *http.Request) (http.Handler, bool) {
|
||||||
route, ok := p.mux[req.Host]
|
route, ok := p.mux[r.Host]
|
||||||
if ok {
|
if ok {
|
||||||
return *route, true
|
return *route, true
|
||||||
}
|
}
|
||||||
|
|
||||||
return nil, false
|
return nil, false
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
@ -20,12 +20,24 @@ import (
|
||||||
"github.com/pomerium/pomerium/proxy/authenticator"
|
"github.com/pomerium/pomerium/proxy/authenticator"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
const (
|
||||||
|
// HeaderJWT is the header key for pomerium proxy's JWT signature.
|
||||||
|
HeaderJWT = "x-pomerium-jwt-assertion"
|
||||||
|
// HeaderUserID represents the header key for the user that is passed to the client.
|
||||||
|
HeaderUserID = "x-pomerium-authenticated-user-id"
|
||||||
|
// HeaderEmail represents the header key for the email that is passed to the client.
|
||||||
|
HeaderEmail = "x-pomerium-authenticated-user-email"
|
||||||
|
)
|
||||||
|
|
||||||
// Options represents the configuration options for the proxy service.
|
// Options represents the configuration options for the proxy service.
|
||||||
type Options struct {
|
type Options struct {
|
||||||
// AuthenticateServiceURL specifies the url to the pomerium authenticate http service.
|
// AuthenticateServiceURL specifies the url to the pomerium authenticate http service.
|
||||||
AuthenticateServiceURL *url.URL `envconfig:"AUTHENTICATE_SERVICE_URL"`
|
AuthenticateServiceURL *url.URL `envconfig:"AUTHENTICATE_SERVICE_URL"`
|
||||||
|
|
||||||
// todo(bdd) : replace with certificate based mTLS
|
// SigningKey is a base64 encoded private key used to add a JWT-signature to proxied requests.
|
||||||
|
// See : https://www.pomerium.io/guide/signed-headers.html
|
||||||
|
SigningKey string `envconfig:"SIGNING_KEY"`
|
||||||
|
// SharedKey is a 32 byte random key used to authenticate access between services.
|
||||||
SharedKey string `envconfig:"SHARED_SECRET"`
|
SharedKey string `envconfig:"SHARED_SECRET"`
|
||||||
|
|
||||||
DefaultUpstreamTimeout time.Duration `envconfig:"DEFAULT_UPSTREAM_TIMEOUT"`
|
DefaultUpstreamTimeout time.Duration `envconfig:"DEFAULT_UPSTREAM_TIMEOUT"`
|
||||||
|
@ -101,6 +113,12 @@ func (o *Options) Validate() error {
|
||||||
if len(decodedCookieSecret) != 32 {
|
if len(decodedCookieSecret) != 32 {
|
||||||
return fmt.Errorf("cookie secret expects 32 bytes but got %d", len(decodedCookieSecret))
|
return fmt.Errorf("cookie secret expects 32 bytes but got %d", len(decodedCookieSecret))
|
||||||
}
|
}
|
||||||
|
if len(o.SigningKey) != 0 {
|
||||||
|
_, err := base64.StdEncoding.DecodeString(o.SigningKey)
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("signing key is invalid base64: %v", err)
|
||||||
|
}
|
||||||
|
}
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -115,7 +133,7 @@ type Proxy struct {
|
||||||
csrfStore sessions.CSRFStore
|
csrfStore sessions.CSRFStore
|
||||||
sessionStore sessions.SessionStore
|
sessionStore sessions.SessionStore
|
||||||
|
|
||||||
redirectURL *url.URL // the url to receive requests at
|
redirectURL *url.URL
|
||||||
templates *template.Template
|
templates *template.Template
|
||||||
mux map[string]*http.Handler
|
mux map[string]*http.Handler
|
||||||
}
|
}
|
||||||
|
@ -135,7 +153,6 @@ func New(opts *Options) (*Proxy, error) {
|
||||||
if err := opts.Validate(); err != nil {
|
if err := opts.Validate(); err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
// error explicitly handled by validate
|
// error explicitly handled by validate
|
||||||
decodedSecret, _ := base64.StdEncoding.DecodeString(opts.CookieSecret)
|
decodedSecret, _ := base64.StdEncoding.DecodeString(opts.CookieSecret)
|
||||||
cipher, err := cryptutil.NewCipher(decodedSecret)
|
cipher, err := cryptutil.NewCipher(decodedSecret)
|
||||||
|
@ -183,7 +200,10 @@ func New(opts *Options) (*Proxy, error) {
|
||||||
fromURL, _ := urlParse(from)
|
fromURL, _ := urlParse(from)
|
||||||
toURL, _ := urlParse(to)
|
toURL, _ := urlParse(to)
|
||||||
reverseProxy := NewReverseProxy(toURL)
|
reverseProxy := NewReverseProxy(toURL)
|
||||||
handler := NewReverseProxyHandler(opts, reverseProxy, toURL.String())
|
handler, err := NewReverseProxyHandler(opts, reverseProxy, fromURL.Host, toURL.Host)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
p.Handle(fromURL.Host, handler)
|
p.Handle(fromURL.Host, handler)
|
||||||
log.Info().Str("from", fromURL.Host).Str("to", toURL.String()).Msg("proxy.New : route")
|
log.Info().Str("from", fromURL.Host).Str("to", toURL.String()).Msg("proxy.New : route")
|
||||||
}
|
}
|
||||||
|
@ -196,6 +216,7 @@ type UpstreamProxy struct {
|
||||||
name string
|
name string
|
||||||
cookieName string
|
cookieName string
|
||||||
handler http.Handler
|
handler http.Handler
|
||||||
|
signer cryptutil.JWTSigner
|
||||||
}
|
}
|
||||||
|
|
||||||
var defaultUpstreamTransport = &http.Transport{
|
var defaultUpstreamTransport = &http.Transport{
|
||||||
|
@ -211,8 +232,8 @@ var defaultUpstreamTransport = &http.Transport{
|
||||||
ExpectContinueTimeout: 1 * time.Second,
|
ExpectContinueTimeout: 1 * time.Second,
|
||||||
}
|
}
|
||||||
|
|
||||||
// deleteSSOCookieHeader deletes the session cookie from the request header string.
|
// deleteUpstreamCookies deletes the session cookie from the request header string.
|
||||||
func deleteSSOCookieHeader(req *http.Request, cookieName string) {
|
func deleteUpstreamCookies(req *http.Request, cookieName string) {
|
||||||
headers := []string{}
|
headers := []string{}
|
||||||
for _, cookie := range req.Cookies() {
|
for _, cookie := range req.Cookies() {
|
||||||
if cookie.Name != cookieName {
|
if cookie.Name != cookieName {
|
||||||
|
@ -222,10 +243,23 @@ func deleteSSOCookieHeader(req *http.Request, cookieName string) {
|
||||||
req.Header.Set("Cookie", strings.Join(headers, ";"))
|
req.Header.Set("Cookie", strings.Join(headers, ";"))
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// signRequest signs a g
|
||||||
|
func (u *UpstreamProxy) signRequest(req *http.Request) {
|
||||||
|
if u.signer != nil {
|
||||||
|
jwt, err := u.signer.SignJWT(req.Header.Get(HeaderUserID), req.Header.Get(HeaderEmail))
|
||||||
|
if err == nil {
|
||||||
|
req.Header.Set(HeaderJWT, jwt)
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
// ServeHTTP signs the http request and deletes cookie headers
|
// ServeHTTP signs the http request and deletes cookie headers
|
||||||
// before calling the upstream's ServeHTTP function.
|
// before calling the upstream's ServeHTTP function.
|
||||||
func (u *UpstreamProxy) ServeHTTP(w http.ResponseWriter, r *http.Request) {
|
func (u *UpstreamProxy) ServeHTTP(w http.ResponseWriter, r *http.Request) {
|
||||||
deleteSSOCookieHeader(r, u.cookieName)
|
deleteUpstreamCookies(r, u.cookieName)
|
||||||
|
u.signRequest(r)
|
||||||
u.handler.ServeHTTP(w, r)
|
u.handler.ServeHTTP(w, r)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -237,6 +271,8 @@ func NewReverseProxy(to *url.URL) *httputil.ReverseProxy {
|
||||||
|
|
||||||
director := proxy.Director
|
director := proxy.Director
|
||||||
proxy.Director = func(req *http.Request) {
|
proxy.Director = func(req *http.Request) {
|
||||||
|
// Identifies the originating IP addresses of a client connecting to
|
||||||
|
// a web server through an HTTP proxy or a load balancer.
|
||||||
req.Header.Add("X-Forwarded-Host", req.Host)
|
req.Header.Add("X-Forwarded-Host", req.Host)
|
||||||
director(req)
|
director(req)
|
||||||
req.Host = to.Host
|
req.Host = to.Host
|
||||||
|
@ -245,16 +281,26 @@ func NewReverseProxy(to *url.URL) *httputil.ReverseProxy {
|
||||||
}
|
}
|
||||||
|
|
||||||
// NewReverseProxyHandler applies handler specific options to a given route.
|
// NewReverseProxyHandler applies handler specific options to a given route.
|
||||||
func NewReverseProxyHandler(opts *Options, reverseProxy *httputil.ReverseProxy, serviceName string) http.Handler {
|
func NewReverseProxyHandler(opts *Options, reverseProxy *httputil.ReverseProxy, from, to string) (http.Handler, error) {
|
||||||
upstreamProxy := &UpstreamProxy{
|
up := &UpstreamProxy{
|
||||||
name: serviceName,
|
name: to,
|
||||||
handler: reverseProxy,
|
handler: reverseProxy,
|
||||||
cookieName: opts.CookieName,
|
cookieName: opts.CookieName,
|
||||||
}
|
}
|
||||||
|
if len(opts.SigningKey) != 0 {
|
||||||
|
decodedSigningKey, err := base64.StdEncoding.DecodeString(opts.SigningKey)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
signer, err := cryptutil.NewES256Signer(decodedSigningKey, from)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
up.signer = signer
|
||||||
|
}
|
||||||
timeout := opts.DefaultUpstreamTimeout
|
timeout := opts.DefaultUpstreamTimeout
|
||||||
timeoutMsg := fmt.Sprintf("%s failed to respond within the %s timeout period", serviceName, timeout)
|
timeoutMsg := fmt.Sprintf("%s failed to respond within the %s timeout period", to, timeout)
|
||||||
return http.TimeoutHandler(upstreamProxy, timeout, timeoutMsg)
|
return http.TimeoutHandler(up, timeout, timeoutMsg), nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// urlParse adds a scheme if none-exists, addressesing a quirk in how
|
// urlParse adds a scheme if none-exists, addressesing a quirk in how
|
||||||
|
|
|
@ -110,7 +110,10 @@ func TestNewReverseProxyHandler(t *testing.T) {
|
||||||
|
|
||||||
proxyHandler := NewReverseProxy(proxyURL)
|
proxyHandler := NewReverseProxy(proxyURL)
|
||||||
opts := defaultOptions
|
opts := defaultOptions
|
||||||
handle := NewReverseProxyHandler(opts, proxyHandler, "name")
|
handle, err := NewReverseProxyHandler(opts, proxyHandler, "from", "to")
|
||||||
|
if err != nil {
|
||||||
|
t.Errorf("got %q", err)
|
||||||
|
}
|
||||||
|
|
||||||
frontend := httptest.NewServer(handle)
|
frontend := httptest.NewServer(handle)
|
||||||
|
|
||||||
|
@ -152,7 +155,8 @@ func TestOptions_Validate(t *testing.T) {
|
||||||
invalidCookieSecret.CookieSecret = "OromP1gurwGWjQPYb1nNgSxtbVB5NnLzX6z5WOKr0Yw^"
|
invalidCookieSecret.CookieSecret = "OromP1gurwGWjQPYb1nNgSxtbVB5NnLzX6z5WOKr0Yw^"
|
||||||
shortCookieLength := testOptions()
|
shortCookieLength := testOptions()
|
||||||
shortCookieLength.CookieSecret = "gN3xnvfsAwfCXxnJorGLKUG4l2wC8sS8nfLMhcStPg=="
|
shortCookieLength.CookieSecret = "gN3xnvfsAwfCXxnJorGLKUG4l2wC8sS8nfLMhcStPg=="
|
||||||
|
invalidSignKey := testOptions()
|
||||||
|
invalidSignKey.SigningKey = "OromP1gurwGWjQPYb1nNgSxtbVB5NnLzX6z5WOKr0Yw^"
|
||||||
badSharedKey := testOptions()
|
badSharedKey := testOptions()
|
||||||
badSharedKey.SharedKey = ""
|
badSharedKey.SharedKey = ""
|
||||||
|
|
||||||
|
@ -162,7 +166,6 @@ func TestOptions_Validate(t *testing.T) {
|
||||||
wantErr bool
|
wantErr bool
|
||||||
}{
|
}{
|
||||||
{"good - minimum options", good, false},
|
{"good - minimum options", good, false},
|
||||||
|
|
||||||
{"nil options", &Options{}, true},
|
{"nil options", &Options{}, true},
|
||||||
{"from route", badFromRoute, true},
|
{"from route", badFromRoute, true},
|
||||||
{"to route", badToRoute, true},
|
{"to route", badToRoute, true},
|
||||||
|
@ -172,6 +175,7 @@ func TestOptions_Validate(t *testing.T) {
|
||||||
{"invalid cookie secret", invalidCookieSecret, true},
|
{"invalid cookie secret", invalidCookieSecret, true},
|
||||||
{"short cookie secret", shortCookieLength, true},
|
{"short cookie secret", shortCookieLength, true},
|
||||||
{"no shared secret", badSharedKey, true},
|
{"no shared secret", badSharedKey, true},
|
||||||
|
{"invalid signing key", invalidSignKey, true},
|
||||||
}
|
}
|
||||||
for _, tt := range tests {
|
for _, tt := range tests {
|
||||||
t.Run(tt.name, func(t *testing.T) {
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
|
@ -187,6 +191,8 @@ func TestNew(t *testing.T) {
|
||||||
good := testOptions()
|
good := testOptions()
|
||||||
shortCookieLength := testOptions()
|
shortCookieLength := testOptions()
|
||||||
shortCookieLength.CookieSecret = "gN3xnvfsAwfCXxnJorGLKUG4l2wC8sS8nfLMhcStPg=="
|
shortCookieLength.CookieSecret = "gN3xnvfsAwfCXxnJorGLKUG4l2wC8sS8nfLMhcStPg=="
|
||||||
|
badRoutedProxy := testOptions()
|
||||||
|
badRoutedProxy.SigningKey = "YmFkIGtleQo="
|
||||||
|
|
||||||
tests := []struct {
|
tests := []struct {
|
||||||
name string
|
name string
|
||||||
|
@ -197,9 +203,10 @@ func TestNew(t *testing.T) {
|
||||||
wantErr bool
|
wantErr bool
|
||||||
}{
|
}{
|
||||||
{"good - minimum options", good, nil, true, 1, false},
|
{"good - minimum options", good, nil, true, 1, false},
|
||||||
{"bad - empty options", &Options{}, nil, false, 0, true},
|
{"empty options", &Options{}, nil, false, 0, true},
|
||||||
{"bad - nil options", nil, nil, false, 0, true},
|
{"nil options", nil, nil, false, 0, true},
|
||||||
{"bad - short secret/validate sanity check", shortCookieLength, nil, false, 0, true},
|
{"short secret/validate sanity check", shortCookieLength, nil, false, 0, true},
|
||||||
|
{"invalid ec key, valid base64 though", badRoutedProxy, nil, false, 0, true},
|
||||||
}
|
}
|
||||||
for _, tt := range tests {
|
for _, tt := range tests {
|
||||||
t.Run(tt.name, func(t *testing.T) {
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
|
|
17
scripts/generate_self_signed_signing_key.sh
Executable file
17
scripts/generate_self_signed_signing_key.sh
Executable file
|
@ -0,0 +1,17 @@
|
||||||
|
#!/bin/bash
|
||||||
|
# See: https://cloud.google.com/iot/docs/how-tos/credentials/keys#generating_an_es256_key_with_a_self-signed_x509_certificate
|
||||||
|
# To generate an ES256 key with a self-signed X.509 certificate that expires far in the future, run the following commands:
|
||||||
|
|
||||||
|
openssl ecparam \
|
||||||
|
-genkey \
|
||||||
|
-name prime256v1 \
|
||||||
|
-noout \
|
||||||
|
-out ec_private.pem
|
||||||
|
|
||||||
|
openssl req \
|
||||||
|
-x509 \
|
||||||
|
-new \
|
||||||
|
-key ec_private.pem \
|
||||||
|
-days 1000000 \
|
||||||
|
-out ec_public.pem \
|
||||||
|
-subj "/CN=unused"
|
|
@ -1,11 +1,16 @@
|
||||||
#!/bin/bash
|
#!/bin/bash
|
||||||
|
|
||||||
# requires certbot
|
# requires acme.sh
|
||||||
certbot certonly --manual \
|
# see : https://github.com/Neilpang/acme.sh
|
||||||
--agree-tos \
|
# uncomment below to install
|
||||||
-d *.corp.example.com \
|
# curl https://get.acme.sh | sh
|
||||||
--preferred-challenges dns-01 \
|
|
||||||
--server https://acme-v02.api.letsencrypt.org/directory \
|
# assumes cloudflare, but many DNS providers are supported
|
||||||
--config-dir le/config \
|
|
||||||
--logs-dir le/work \
|
export CF_Key="x"
|
||||||
--work-dir le/work
|
export CF_Email="x@x.com"
|
||||||
|
|
||||||
|
$HOME/.acme.sh/acme.sh \
|
||||||
|
--issue \
|
||||||
|
-d '*.corp.beyondperimeter.com' \
|
||||||
|
--dns dns_cf
|
||||||
|
|
47
scripts/self-signed-sign-key.sh
Executable file
47
scripts/self-signed-sign-key.sh
Executable file
|
@ -0,0 +1,47 @@
|
||||||
|
#!/bin/bash
|
||||||
|
# Thank you @ https://medium.com/@benjamin.black/how-to-obtain-an-ecdsa-wildcard-certificate-from-lets-encrypt-be217c737cfe
|
||||||
|
# See also:
|
||||||
|
# https://cloud.google.com/iot/docs/how-tos/credentials/keys#generating_an_es256_key_with_a_self-signed_x509_certificate
|
||||||
|
# https://community.letsencrypt.org/t/ecc-certificates/46729
|
||||||
|
#
|
||||||
|
# Let’s Encrypt currently generates RSA certificates, but not yet ECDSA certificates.
|
||||||
|
# Support for generating ECDSA certificates is on the horizon, but is not here yet.
|
||||||
|
# However, Let’s Encrypt does support *signing* ECDSA certificates when presented with a
|
||||||
|
# Certificate Signing Request. So we can generate the appropriate CSR on the client,
|
||||||
|
# and send it to Let’s Encrypt using the --csr option of the certbot client for Let’s Encrypt to sign.
|
||||||
|
|
||||||
|
# The following generates a NIST P-256 (aka secp256r1 aka prime256v1) EC Key Pair
|
||||||
|
openssl ecparam \
|
||||||
|
-genkey \
|
||||||
|
-name prime256v1 \
|
||||||
|
-noout \
|
||||||
|
-out ec_private.pem
|
||||||
|
|
||||||
|
openssl req -x509 -new \
|
||||||
|
-key ec_private.pem \
|
||||||
|
-days 365 \
|
||||||
|
-out ec_public.pem \
|
||||||
|
-subj "/CN=unused"
|
||||||
|
|
||||||
|
openssl req -new \
|
||||||
|
-sha512 \
|
||||||
|
-key privkey.pem \
|
||||||
|
-nodes \
|
||||||
|
-subj "/CN=beyondperimeter.com" \
|
||||||
|
-reqexts SAN \
|
||||||
|
-extensions SAN \
|
||||||
|
-config <(cat /etc/ssl/openssl.cnf <(printf '[SAN]\nsubjectAltName=DNS:*.corp.beyondperimeter.com')) \
|
||||||
|
-out csr.pem \
|
||||||
|
-outform pem
|
||||||
|
|
||||||
|
openssl req -in csr.pem -noout -text
|
||||||
|
|
||||||
|
certbot certonly \
|
||||||
|
--preferred-challenges dns-01 \
|
||||||
|
--work-dir le/work \
|
||||||
|
--config-dir le/config \
|
||||||
|
--logs-dir le/logs \
|
||||||
|
--agree-tos \
|
||||||
|
--email bobbydesimone@gmail.com \
|
||||||
|
-d *.corp.beyondperimeter.com \
|
||||||
|
--csr csr.pem
|
Loading…
Add table
Add a link
Reference in a new issue