diff --git a/authenticate/handlers.go b/authenticate/handlers.go index 849424924..c833e728a 100644 --- a/authenticate/handlers.go +++ b/authenticate/handlers.go @@ -3,7 +3,6 @@ package authenticate import ( "context" "encoding/base64" - "encoding/json" "errors" "fmt" "net/http" @@ -66,8 +65,8 @@ func (a *Authenticate) Mount(r *mux.Router) { v := r.PathPrefix("/.pomerium").Subrouter() c := cors.New(cors.Options{ AllowOriginRequestFunc: func(r *http.Request, _ string) bool { - options := a.options.Load() - err := middleware.ValidateRequestURL(r, options.SharedKey) + state := a.state.Load() + err := middleware.ValidateRequestURL(r, string(state.sharedSecret)) if err != nil { log.FromRequest(r).Info().Err(err).Msg("authenticate: origin blocked") } @@ -89,12 +88,6 @@ func (a *Authenticate) Mount(r *mux.Router) { wk := r.PathPrefix("/.well-known/pomerium").Subrouter() wk.Path("/jwks.json").Handler(httputil.HandlerFunc(a.jwks)).Methods(http.MethodGet) wk.Path("/").Handler(httputil.HandlerFunc(a.wellKnown)).Methods(http.MethodGet) - - // programmatic access api endpoint - api := r.PathPrefix("/api").Subrouter() - api.Use(func(h http.Handler) http.Handler { - return sessions.RetrieveSession(a.state.Load().sessionLoaders...)(h) - }) } // Well-Known Uniform Resource Identifiers (URIs) @@ -111,28 +104,12 @@ func (a *Authenticate) wellKnown(w http.ResponseWriter, r *http.Request) error { state.redirectURL.ResolveReference(&url.URL{Path: "/.well-known/pomerium/jwks.json"}).String(), state.redirectURL.ResolveReference(&url.URL{Path: "/oauth2/callback"}).String(), } - w.Header().Set("Content-Type", "application/json") - w.Header().Set("X-Content-Type-Options", "nosniff") - jBytes, err := json.Marshal(wellKnownURLS) - if err != nil { - return err - } - w.WriteHeader(http.StatusOK) - fmt.Fprintf(w, "%s", jBytes) + httputil.RenderJSON(w, http.StatusOK, wellKnownURLS) return nil } func (a *Authenticate) jwks(w http.ResponseWriter, r *http.Request) error { - state := a.state.Load() - - w.Header().Set("Content-Type", "application/json") - w.Header().Set("X-Content-Type-Options", "nosniff") - jBytes, err := json.Marshal(state.jwk) - if err != nil { - return err - } - w.WriteHeader(http.StatusOK) - fmt.Fprintf(w, "%s", jBytes) + httputil.RenderJSON(w, http.StatusOK, a.state.Load().jwk) return nil } @@ -151,12 +128,12 @@ func (a *Authenticate) VerifySession(next http.Handler) http.Handler { return a.reauthenticateOrFail(w, r, err) } - if state.dataBrokerClient != nil { - _, err = session.Get(ctx, state.dataBrokerClient, sessionState.ID) - if err != nil { - log.FromRequest(r).Info().Err(err).Str("id", sessionState.ID).Msg("authenticate: session not found in databroker") - return a.reauthenticateOrFail(w, r, err) - } + if state.dataBrokerClient == nil { + return errors.New("authenticate: databroker client cannot be nil") + } + if _, err = session.Get(ctx, state.dataBrokerClient, sessionState.ID); err != nil { + log.FromRequest(r).Info().Err(err).Str("id", sessionState.ID).Msg("authenticate: session not found in databroker") + return a.reauthenticateOrFail(w, r, err) } next.ServeHTTP(w, r.WithContext(ctx)) @@ -179,11 +156,6 @@ func (a *Authenticate) SignIn(w http.ResponseWriter, r *http.Request) error { options := a.options.Load() state := a.state.Load() - sharedCipher, err := cryptutil.NewAEADCipherFromBase64(options.SharedKey) - if err != nil { - return httputil.NewError(http.StatusBadRequest, err) - } - redirectURL, err := urlutil.ParseAndValidateURL(r.FormValue(urlutil.QueryRedirectURI)) if err != nil { return httputil.NewError(http.StatusBadRequest, err) @@ -241,8 +213,8 @@ func (a *Authenticate) SignIn(w http.ResponseWriter, r *http.Request) error { return httputil.NewError(http.StatusBadRequest, err) } - // encrypt our route-based token JWT avoiding any accidental logging - encryptedJWT := cryptutil.Encrypt(sharedCipher, signedJWT, nil) + // encrypt our route-scoped JWT to avoid accidental logging of queryparams + encryptedJWT := cryptutil.Encrypt(a.state.Load().sharedCipher, signedJWT, nil) // base64 our encrypted payload for URL-friendlyness encodedJWT := base64.URLEncoding.EncodeToString(encryptedJWT) @@ -413,7 +385,8 @@ func (a *Authenticate) getOAuthCallback(w http.ResponseWriter, r *http.Request) return nil, fmt.Errorf("error redeeming authenticate code: %w", err) } - err = a.saveSessionToDataBroker(r.Context(), &s, accessToken) + // save the session and access token to the databroker + err = a.saveSessionToDataBroker(ctx, &s, accessToken) if err != nil { return nil, httputil.NewError(http.StatusInternalServerError, err) } @@ -550,11 +523,7 @@ func (a *Authenticate) Dashboard(w http.ResponseWriter, r *http.Request) error { input["SignOutURL"] = "/.pomerium/sign_out" } - err = a.templates.ExecuteTemplate(w, "dashboard.html", input) - if err != nil { - log.Warn().Err(err).Interface("input", input).Msg("proxy: error rendering dashboard") - } - return nil + return a.templates.ExecuteTemplate(w, "dashboard.html", input) } func (a *Authenticate) saveSessionToDataBroker(ctx context.Context, sessionState *sessions.State, accessToken *oauth2.Token) error { diff --git a/authenticate/handlers_test.go b/authenticate/handlers_test.go index d6882d4fd..542945f56 100644 --- a/authenticate/handlers_test.go +++ b/authenticate/handlers_test.go @@ -150,8 +150,11 @@ func TestAuthenticate_SignIn(t *testing.T) { ctrl := gomock.NewController(t) defer ctrl.Finish() + sharedCipher, _ := cryptutil.NewAEADCipherFromBase64(cryptutil.NewBase64Key()) + a := &Authenticate{ state: newAtomicAuthenticateState(&authenticateState{ + sharedCipher: sharedCipher, sessionStore: tt.session, redirectURL: uriParseHelper("https://some.example"), sharedEncoder: tt.encoder, @@ -566,7 +569,7 @@ func TestWellKnownEndpoint(t *testing.T) { rr := httptest.NewRecorder() h.ServeHTTP(rr, req) body := rr.Body.String() - expected := `{"jwks_uri":"https://auth.example.com/.well-known/pomerium/jwks.json","authentication_callback_endpoint":"https://auth.example.com/oauth2/callback"}` + expected := "{\"jwks_uri\":\"https://auth.example.com/.well-known/pomerium/jwks.json\",\"authentication_callback_endpoint\":\"https://auth.example.com/oauth2/callback\"}\n" assert.Equal(t, body, expected) } @@ -587,7 +590,7 @@ func TestJwksEndpoint(t *testing.T) { rr := httptest.NewRecorder() h.ServeHTTP(rr, req) body := rr.Body.String() - expected := `{"keys":[{"use":"sig","kty":"EC","kid":"5b419ade1895fec2d2def6cd33b1b9a018df60db231dc5ecb85cbed6d942813c","crv":"P-256","alg":"ES256","x":"UG5xCP0JTT1H6Iol8jKuTIPVLM04CgW9PlEypNRmWlo","y":"KChF0fR09zm884ymInM29PtSsFdnzExNfLsP-ta1AgQ"}]}` + expected := "{\"keys\":[{\"use\":\"sig\",\"kty\":\"EC\",\"kid\":\"5b419ade1895fec2d2def6cd33b1b9a018df60db231dc5ecb85cbed6d942813c\",\"crv\":\"P-256\",\"alg\":\"ES256\",\"x\":\"UG5xCP0JTT1H6Iol8jKuTIPVLM04CgW9PlEypNRmWlo\",\"y\":\"KChF0fR09zm884ymInM29PtSsFdnzExNfLsP-ta1AgQ\"}]}\n" assert.Equal(t, expected, body) } func TestAuthenticate_Dashboard(t *testing.T) { diff --git a/authenticate/state.go b/authenticate/state.go index 74d3ec199..d03ae781e 100644 --- a/authenticate/state.go +++ b/authenticate/state.go @@ -17,7 +17,6 @@ import ( "github.com/pomerium/pomerium/internal/sessions" "github.com/pomerium/pomerium/internal/sessions/cookie" "github.com/pomerium/pomerium/internal/sessions/header" - "github.com/pomerium/pomerium/internal/sessions/queryparam" "github.com/pomerium/pomerium/internal/urlutil" "github.com/pomerium/pomerium/pkg/cryptutil" "github.com/pomerium/pomerium/pkg/grpc" @@ -32,6 +31,10 @@ type authenticateState struct { // sharedEncoder is the encoder to use to serialize data to be consumed // by other services sharedEncoder encoding.MarshalUnmarshaler + // sharedSecret is the secret to encrypt and authenticate data shared between services + sharedSecret []byte + // sharedCipher is the cipher to use to encrypt/decrypt data shared between services + sharedCipher cipher.AEAD // cookieSecret is the secret to encrypt and authenticate session data cookieSecret []byte // cookieCipher is the cipher to use to encrypt/decrypt session data @@ -79,12 +82,15 @@ func newAuthenticateStateFromConfig(cfg *config.Config) (*authenticateState, err return nil, err } + // shared cipher to encrypt data before passing data between services + state.sharedSecret, _ = base64.StdEncoding.DecodeString(cfg.Options.SharedKey) + state.sharedCipher, _ = cryptutil.NewAEADCipher(state.sharedSecret) + // private state encoder setup, used to encrypt oauth2 tokens state.cookieSecret, _ = base64.StdEncoding.DecodeString(cfg.Options.CookieSecret) state.cookieCipher, _ = cryptutil.NewAEADCipher(state.cookieSecret) state.encryptedEncoder = ecjson.New(state.cookieCipher) - qpStore := queryparam.NewStore(state.encryptedEncoder, urlutil.QueryProgrammaticToken) headerStore := header.NewStore(state.encryptedEncoder, httputil.AuthorizationTypePomerium) cookieStore, err := cookie.NewStore(func() cookie.Options { @@ -101,7 +107,7 @@ func newAuthenticateStateFromConfig(cfg *config.Config) (*authenticateState, err } state.sessionStore = cookieStore - state.sessionLoaders = []sessions.SessionLoader{qpStore, headerStore, cookieStore} + state.sessionLoaders = []sessions.SessionLoader{headerStore, cookieStore} state.jwk = new(jose.JSONWebKeySet) if cfg.Options.SigningKey != "" { diff --git a/internal/httputil/errors.go b/internal/httputil/errors.go index e99b5b48d..9654bfad0 100644 --- a/internal/httputil/errors.go +++ b/internal/httputil/errors.go @@ -1,7 +1,6 @@ package httputil import ( - "encoding/json" "html/template" "net/http" @@ -68,10 +67,7 @@ func (e *HTTPError) ErrorResponse(w http.ResponseWriter, r *http.Request) { } if r.Header.Get("Accept") == "application/json" { - w.Header().Set("Content-Type", "application/json") - if err := json.NewEncoder(w).Encode(response); err != nil { - http.Error(w, err.Error(), http.StatusInternalServerError) - } + RenderJSON(w, e.Status, response) return } w.Header().Set("Content-Type", "text/html; charset=UTF-8") diff --git a/internal/httputil/handlers.go b/internal/httputil/handlers.go index 14d0ce832..bd86d23ea 100644 --- a/internal/httputil/handlers.go +++ b/internal/httputil/handlers.go @@ -1,6 +1,8 @@ package httputil import ( + "bytes" + "encoding/json" "errors" "fmt" "net/http" @@ -27,6 +29,23 @@ func Redirect(w http.ResponseWriter, r *http.Request, url string, code int) { http.Redirect(w, r, url, code) } +// RenderJSON replies to the request with the specified struct as JSON and HTTP code. +// It does not otherwise end the request; the caller should ensure no further +// writes are done to w. +// The error message should be application/json. +func RenderJSON(w http.ResponseWriter, code int, v interface{}) { + w.Header().Set("Content-Type", "application/json") + w.Header().Set("X-Content-Type-Options", "nosniff") + b := new(bytes.Buffer) + if err := json.NewEncoder(b).Encode(v); err != nil { + w.WriteHeader(http.StatusInternalServerError) + fmt.Fprintf(b, `{"error":"%s"}`, err) + } else { + w.WriteHeader(code) + } + fmt.Fprint(w, b) +} + // The HandlerFunc type is an adapter to allow the use of // ordinary functions as HTTP handlers. If f is a function // with the appropriate signature, HandlerFunc(f) is a diff --git a/internal/httputil/handlers_test.go b/internal/httputil/handlers_test.go index 2b7cb48be..d5a04f2e8 100644 --- a/internal/httputil/handlers_test.go +++ b/internal/httputil/handlers_test.go @@ -2,6 +2,7 @@ package httputil import ( "errors" + "math" "net/http" "net/http/httptest" "testing" @@ -92,3 +93,58 @@ func TestHandlerFunc_ServeHTTP(t *testing.T) { }) } } + +func TestRenderJSON(t *testing.T) { + + tests := []struct { + name string + code int + v interface{} + wantBody string + wantCode int + }{ + {"simple", + http.StatusTeapot, + struct { + A string + B string + C int + }{ + A: "A", + B: "B", + C: 1, + }, + "{\"A\":\"A\",\"B\":\"B\",\"C\":1}\n", + http.StatusTeapot, + }, + {"map", + http.StatusOK, + map[string]interface{}{ + "C": 1, // notice order does not matter + "A": "A", + "B": "B", + }, + // alphabetical + "{\"A\":\"A\",\"B\":\"B\",\"C\":1}\n", http.StatusOK, + }, + {"bad!", + http.StatusOK, + map[string]interface{}{ + "BAD BOI": math.Inf(1), + }, + `{"error":"json: unsupported value: +Inf"}`, http.StatusInternalServerError}, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + w := httptest.NewRecorder() + + RenderJSON(w, tt.code, tt.v) + if diff := cmp.Diff(tt.wantBody, w.Body.String()); diff != "" { + t.Errorf("TestRenderJSON:\n %s", diff) + } + if diff := cmp.Diff(tt.wantCode, w.Result().StatusCode); diff != "" { + t.Errorf("TestRenderJSON:\n %s", diff) + } + }) + } +} diff --git a/internal/urlutil/query_params.go b/internal/urlutil/query_params.go index c56b92590..003c0dbc1 100644 --- a/internal/urlutil/query_params.go +++ b/internal/urlutil/query_params.go @@ -14,7 +14,6 @@ const ( QuerySession = "pomerium_session" QuerySessionEncrypted = "pomerium_session_encrypted" QueryRedirectURI = "pomerium_redirect_uri" - QueryProgrammaticToken = "pomerium_programmatic_token" QueryForwardAuthURI = "uri" )