diff --git a/authorize/check_response.go b/authorize/check_response.go index bc1b2aec4..91e71a533 100644 --- a/authorize/check_response.go +++ b/authorize/check_response.go @@ -18,14 +18,12 @@ import ( "github.com/pomerium/pomerium/internal/urlutil" ) -func (a *Authorize) okResponse( - reply *evaluator.Result, - rawSession []byte, -) *envoy_service_auth_v2.CheckResponse { - requestHeaders, err := a.getEnvoyRequestHeaders(rawSession) +func (a *Authorize) okResponse(reply *evaluator.Result) *envoy_service_auth_v2.CheckResponse { + requestHeaders, err := a.getEnvoyRequestHeaders(reply.SignedJWT) if err != nil { log.Warn().Err(err).Msg("authorize: error generating new request headers") } + requestHeaders = append(requestHeaders, mkHeader(httputil.HeaderPomeriumJWTAssertion, reply.SignedJWT)) diff --git a/authorize/evaluator/evaluator.go b/authorize/evaluator/evaluator.go index dd9e4d097..7c07a2671 100644 --- a/authorize/evaluator/evaluator.go +++ b/authorize/evaluator/evaluator.go @@ -4,6 +4,7 @@ package evaluator import ( "context" + "crypto/ecdsa" "encoding/base64" "encoding/json" "fmt" @@ -116,7 +117,7 @@ func (e *Evaluator) Evaluate(ctx context.Context, req *Request) (*Result, error) return &deny[0], nil } - signedJWT, err := e.getSignedJWT(req) + signedJWT, err := e.SignedJWT(req) if err != nil { return nil, fmt.Errorf("error signing JWT: %w", err) } @@ -145,7 +146,17 @@ func (e *Evaluator) Evaluate(ctx context.Context, req *Request) (*Result, error) }, nil } -func (e *Evaluator) getSignedJWT(req *Request) (string, error) { +// ParseSignedJWT parses the input signature and return its payload. +func (e *Evaluator) ParseSignedJWT(signature string) ([]byte, error) { + object, err := jose.ParseSigned(signature) + if err != nil { + return nil, err + } + return object.Verify(&(e.jwk.(*ecdsa.PrivateKey).PublicKey)) +} + +// SignedJWT returns the signature of given request. +func (e *Evaluator) SignedJWT(req *Request) (string, error) { signer, err := jose.NewSigner(jose.SigningKey{ Algorithm: jose.ES256, Key: e.jwk, @@ -169,6 +180,7 @@ func (e *Evaluator) getSignedJWT(req *Request) (string, error) { } if u, ok := req.DataBrokerData.Get("type.googleapis.com/user.User", s.GetUserId()).(*user.User); ok { payload["sub"] = u.GetId() + payload["user"] = u.GetId() payload["email"] = u.GetEmail() } if du, ok := req.DataBrokerData.Get("type.googleapis.com/directory.User", s.GetUserId()).(*directory.User); ok { diff --git a/authorize/evaluator/evaluator_test.go b/authorize/evaluator/evaluator_test.go index 705dc06bf..ab45ed035 100644 --- a/authorize/evaluator/evaluator_test.go +++ b/authorize/evaluator/evaluator_test.go @@ -2,10 +2,14 @@ package evaluator import ( "encoding/json" + "net/http" + "net/url" "testing" "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + "github.com/pomerium/pomerium/config" "github.com/pomerium/pomerium/internal/grpc/directory" ) @@ -65,3 +69,31 @@ func TestJSONMarshal(t *testing.T) { "is_valid_client_certificate": true }`, string(bs)) } + +func TestEvaluator_SignedJWT(t *testing.T) { + opt := config.NewDefaultOptions() + opt.AuthenticateURL = mustParseURL("https://authenticate.example.com") + e, err := New(opt) + require.NoError(t, err) + req := &Request{ + HTTP: RequestHTTP{ + Method: http.MethodGet, + URL: "https://example.com", + }, + } + signedJWT, err := e.SignedJWT(req) + require.NoError(t, err) + assert.NotEmpty(t, signedJWT) + + payload, err := e.ParseSignedJWT(signedJWT) + require.NoError(t, err) + assert.NotEmpty(t, payload) +} + +func mustParseURL(str string) *url.URL { + u, err := url.Parse(str) + if err != nil { + panic(err) + } + return u +} diff --git a/authorize/grpc.go b/authorize/grpc.go index 0198fe876..a15378b45 100644 --- a/authorize/grpc.go +++ b/authorize/grpc.go @@ -75,7 +75,7 @@ func (a *Authorize) Check(ctx context.Context, in *envoy_service_auth_v2.CheckRe switch { case reply.Status == http.StatusOK: - return a.okResponse(reply, rawJWT), nil + return a.okResponse(reply), nil case reply.Status == http.StatusUnauthorized: if isForwardAuth { return a.deniedResponse(in, http.StatusUnauthorized, "Unauthenticated", nil), nil @@ -147,10 +147,10 @@ func (a *Authorize) forceSyncUser(ctx context.Context, userID string) *user.User return s } -func (a *Authorize) getEnvoyRequestHeaders(rawJWT []byte) ([]*envoy_api_v2_core.HeaderValueOption, error) { +func (a *Authorize) getEnvoyRequestHeaders(signedJWT string) ([]*envoy_api_v2_core.HeaderValueOption, error) { var hvos []*envoy_api_v2_core.HeaderValueOption - hdrs, err := getJWTClaimHeaders(a.currentOptions.Load(), a.currentEncoder.Load(), rawJWT) + hdrs, err := a.getJWTClaimHeaders(a.currentOptions.Load(), signedJWT) if err != nil { return nil, err } diff --git a/authorize/session.go b/authorize/session.go index a6f04a6aa..67d663320 100644 --- a/authorize/session.go +++ b/authorize/session.go @@ -85,41 +85,38 @@ func getJWTSetCookieHeaders(cookieStore sessions.SessionStore, rawjwt []byte) (m return hdrs, nil } -func getJWTClaimHeaders(options config.Options, encoder encoding.MarshalUnmarshaler, rawjwt []byte) (map[string]string, error) { - if len(rawjwt) == 0 { +func (a *Authorize) getJWTClaimHeaders(options config.Options, signedJWT string) (map[string]string, error) { + if len(signedJWT) == 0 { return make(map[string]string), nil } - var claims map[string]jwtClaim - err := encoder.Unmarshal(rawjwt, &claims) + var claims map[string]interface{} + payload, err := a.pe.ParseSignedJWT(signedJWT) if err != nil { return nil, err } + if err := json.Unmarshal(payload, &claims); err != nil { + return nil, err + } hdrs := make(map[string]string) for _, name := range options.JWTClaimsHeaders { if claim, ok := claims[name]; ok { - hdrs["x-pomerium-claim-"+name] = strings.Join(claim, ",") + switch value := claim.(type) { + case string: + hdrs["x-pomerium-claim-"+name] = value + case []interface{}: + hdrs["x-pomerium-claim-"+name] = strings.Join(toSliceStrings(value), ",") + } } } return hdrs, nil } -type jwtClaim []string - -func (claim *jwtClaim) UnmarshalJSON(bs []byte) error { - var raw interface{} - err := json.Unmarshal(bs, &raw) - if err != nil { - return err +func toSliceStrings(sliceIfaces []interface{}) []string { + sliceStrings := make([]string, 0, len(sliceIfaces)) + for _, e := range sliceIfaces { + sliceStrings = append(sliceStrings, fmt.Sprint(e)) } - switch obj := raw.(type) { - case []interface{}: - for _, el := range obj { - *claim = append(*claim, fmt.Sprint(el)) - } - default: - *claim = append(*claim, fmt.Sprint(obj)) - } - return nil + return sliceStrings }