authorize: get claims from signed jwt (#954)

authorize: get claims from signed jwt

When doing databroker refactoring, all claims information were moved to
signed JWT instead of raw session JWT. But we are still looking for
claims info in raw session JWT, causes all X-Pomerium-Claim-* headers
being gone.

Fix this by looking for information from signed JWT instead.

Note that even with this fix, the X-Pomerium-Claim-Groups is still not
present, but it's another bug (see #941) and will be fixed later.

Fixes #936
This commit is contained in:
Cuong Manh Le 2020-06-22 09:51:32 +07:00 committed by GitHub
parent fbce3dd359
commit 4a3fb5d44b
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
5 changed files with 70 additions and 31 deletions

View file

@ -18,14 +18,12 @@ import (
"github.com/pomerium/pomerium/internal/urlutil"
)
func (a *Authorize) okResponse(
reply *evaluator.Result,
rawSession []byte,
) *envoy_service_auth_v2.CheckResponse {
requestHeaders, err := a.getEnvoyRequestHeaders(rawSession)
func (a *Authorize) okResponse(reply *evaluator.Result) *envoy_service_auth_v2.CheckResponse {
requestHeaders, err := a.getEnvoyRequestHeaders(reply.SignedJWT)
if err != nil {
log.Warn().Err(err).Msg("authorize: error generating new request headers")
}
requestHeaders = append(requestHeaders,
mkHeader(httputil.HeaderPomeriumJWTAssertion, reply.SignedJWT))

View file

@ -4,6 +4,7 @@ package evaluator
import (
"context"
"crypto/ecdsa"
"encoding/base64"
"encoding/json"
"fmt"
@ -116,7 +117,7 @@ func (e *Evaluator) Evaluate(ctx context.Context, req *Request) (*Result, error)
return &deny[0], nil
}
signedJWT, err := e.getSignedJWT(req)
signedJWT, err := e.SignedJWT(req)
if err != nil {
return nil, fmt.Errorf("error signing JWT: %w", err)
}
@ -145,7 +146,17 @@ func (e *Evaluator) Evaluate(ctx context.Context, req *Request) (*Result, error)
}, nil
}
func (e *Evaluator) getSignedJWT(req *Request) (string, error) {
// ParseSignedJWT parses the input signature and return its payload.
func (e *Evaluator) ParseSignedJWT(signature string) ([]byte, error) {
object, err := jose.ParseSigned(signature)
if err != nil {
return nil, err
}
return object.Verify(&(e.jwk.(*ecdsa.PrivateKey).PublicKey))
}
// SignedJWT returns the signature of given request.
func (e *Evaluator) SignedJWT(req *Request) (string, error) {
signer, err := jose.NewSigner(jose.SigningKey{
Algorithm: jose.ES256,
Key: e.jwk,
@ -169,6 +180,7 @@ func (e *Evaluator) getSignedJWT(req *Request) (string, error) {
}
if u, ok := req.DataBrokerData.Get("type.googleapis.com/user.User", s.GetUserId()).(*user.User); ok {
payload["sub"] = u.GetId()
payload["user"] = u.GetId()
payload["email"] = u.GetEmail()
}
if du, ok := req.DataBrokerData.Get("type.googleapis.com/directory.User", s.GetUserId()).(*directory.User); ok {

View file

@ -2,10 +2,14 @@ package evaluator
import (
"encoding/json"
"net/http"
"net/url"
"testing"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"github.com/pomerium/pomerium/config"
"github.com/pomerium/pomerium/internal/grpc/directory"
)
@ -65,3 +69,31 @@ func TestJSONMarshal(t *testing.T) {
"is_valid_client_certificate": true
}`, string(bs))
}
func TestEvaluator_SignedJWT(t *testing.T) {
opt := config.NewDefaultOptions()
opt.AuthenticateURL = mustParseURL("https://authenticate.example.com")
e, err := New(opt)
require.NoError(t, err)
req := &Request{
HTTP: RequestHTTP{
Method: http.MethodGet,
URL: "https://example.com",
},
}
signedJWT, err := e.SignedJWT(req)
require.NoError(t, err)
assert.NotEmpty(t, signedJWT)
payload, err := e.ParseSignedJWT(signedJWT)
require.NoError(t, err)
assert.NotEmpty(t, payload)
}
func mustParseURL(str string) *url.URL {
u, err := url.Parse(str)
if err != nil {
panic(err)
}
return u
}

View file

@ -75,7 +75,7 @@ func (a *Authorize) Check(ctx context.Context, in *envoy_service_auth_v2.CheckRe
switch {
case reply.Status == http.StatusOK:
return a.okResponse(reply, rawJWT), nil
return a.okResponse(reply), nil
case reply.Status == http.StatusUnauthorized:
if isForwardAuth {
return a.deniedResponse(in, http.StatusUnauthorized, "Unauthenticated", nil), nil
@ -147,10 +147,10 @@ func (a *Authorize) forceSyncUser(ctx context.Context, userID string) *user.User
return s
}
func (a *Authorize) getEnvoyRequestHeaders(rawJWT []byte) ([]*envoy_api_v2_core.HeaderValueOption, error) {
func (a *Authorize) getEnvoyRequestHeaders(signedJWT string) ([]*envoy_api_v2_core.HeaderValueOption, error) {
var hvos []*envoy_api_v2_core.HeaderValueOption
hdrs, err := getJWTClaimHeaders(a.currentOptions.Load(), a.currentEncoder.Load(), rawJWT)
hdrs, err := a.getJWTClaimHeaders(a.currentOptions.Load(), signedJWT)
if err != nil {
return nil, err
}

View file

@ -85,41 +85,38 @@ func getJWTSetCookieHeaders(cookieStore sessions.SessionStore, rawjwt []byte) (m
return hdrs, nil
}
func getJWTClaimHeaders(options config.Options, encoder encoding.MarshalUnmarshaler, rawjwt []byte) (map[string]string, error) {
if len(rawjwt) == 0 {
func (a *Authorize) getJWTClaimHeaders(options config.Options, signedJWT string) (map[string]string, error) {
if len(signedJWT) == 0 {
return make(map[string]string), nil
}
var claims map[string]jwtClaim
err := encoder.Unmarshal(rawjwt, &claims)
var claims map[string]interface{}
payload, err := a.pe.ParseSignedJWT(signedJWT)
if err != nil {
return nil, err
}
if err := json.Unmarshal(payload, &claims); err != nil {
return nil, err
}
hdrs := make(map[string]string)
for _, name := range options.JWTClaimsHeaders {
if claim, ok := claims[name]; ok {
hdrs["x-pomerium-claim-"+name] = strings.Join(claim, ",")
switch value := claim.(type) {
case string:
hdrs["x-pomerium-claim-"+name] = value
case []interface{}:
hdrs["x-pomerium-claim-"+name] = strings.Join(toSliceStrings(value), ",")
}
}
}
return hdrs, nil
}
type jwtClaim []string
func (claim *jwtClaim) UnmarshalJSON(bs []byte) error {
var raw interface{}
err := json.Unmarshal(bs, &raw)
if err != nil {
return err
func toSliceStrings(sliceIfaces []interface{}) []string {
sliceStrings := make([]string, 0, len(sliceIfaces))
for _, e := range sliceIfaces {
sliceStrings = append(sliceStrings, fmt.Sprint(e))
}
switch obj := raw.(type) {
case []interface{}:
for _, el := range obj {
*claim = append(*claim, fmt.Sprint(el))
}
default:
*claim = append(*claim, fmt.Sprint(obj))
}
return nil
return sliceStrings
}