authorize: get claims from signed jwt (#954)

authorize: get claims from signed jwt

When doing databroker refactoring, all claims information were moved to
signed JWT instead of raw session JWT. But we are still looking for
claims info in raw session JWT, causes all X-Pomerium-Claim-* headers
being gone.

Fix this by looking for information from signed JWT instead.

Note that even with this fix, the X-Pomerium-Claim-Groups is still not
present, but it's another bug (see #941) and will be fixed later.

Fixes #936
This commit is contained in:
Cuong Manh Le 2020-06-22 09:51:32 +07:00 committed by GitHub
parent fbce3dd359
commit 4a3fb5d44b
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
5 changed files with 70 additions and 31 deletions

View file

@ -18,14 +18,12 @@ import (
"github.com/pomerium/pomerium/internal/urlutil" "github.com/pomerium/pomerium/internal/urlutil"
) )
func (a *Authorize) okResponse( func (a *Authorize) okResponse(reply *evaluator.Result) *envoy_service_auth_v2.CheckResponse {
reply *evaluator.Result, requestHeaders, err := a.getEnvoyRequestHeaders(reply.SignedJWT)
rawSession []byte,
) *envoy_service_auth_v2.CheckResponse {
requestHeaders, err := a.getEnvoyRequestHeaders(rawSession)
if err != nil { if err != nil {
log.Warn().Err(err).Msg("authorize: error generating new request headers") log.Warn().Err(err).Msg("authorize: error generating new request headers")
} }
requestHeaders = append(requestHeaders, requestHeaders = append(requestHeaders,
mkHeader(httputil.HeaderPomeriumJWTAssertion, reply.SignedJWT)) mkHeader(httputil.HeaderPomeriumJWTAssertion, reply.SignedJWT))

View file

@ -4,6 +4,7 @@ package evaluator
import ( import (
"context" "context"
"crypto/ecdsa"
"encoding/base64" "encoding/base64"
"encoding/json" "encoding/json"
"fmt" "fmt"
@ -116,7 +117,7 @@ func (e *Evaluator) Evaluate(ctx context.Context, req *Request) (*Result, error)
return &deny[0], nil return &deny[0], nil
} }
signedJWT, err := e.getSignedJWT(req) signedJWT, err := e.SignedJWT(req)
if err != nil { if err != nil {
return nil, fmt.Errorf("error signing JWT: %w", err) return nil, fmt.Errorf("error signing JWT: %w", err)
} }
@ -145,7 +146,17 @@ func (e *Evaluator) Evaluate(ctx context.Context, req *Request) (*Result, error)
}, nil }, nil
} }
func (e *Evaluator) getSignedJWT(req *Request) (string, error) { // ParseSignedJWT parses the input signature and return its payload.
func (e *Evaluator) ParseSignedJWT(signature string) ([]byte, error) {
object, err := jose.ParseSigned(signature)
if err != nil {
return nil, err
}
return object.Verify(&(e.jwk.(*ecdsa.PrivateKey).PublicKey))
}
// SignedJWT returns the signature of given request.
func (e *Evaluator) SignedJWT(req *Request) (string, error) {
signer, err := jose.NewSigner(jose.SigningKey{ signer, err := jose.NewSigner(jose.SigningKey{
Algorithm: jose.ES256, Algorithm: jose.ES256,
Key: e.jwk, Key: e.jwk,
@ -169,6 +180,7 @@ func (e *Evaluator) getSignedJWT(req *Request) (string, error) {
} }
if u, ok := req.DataBrokerData.Get("type.googleapis.com/user.User", s.GetUserId()).(*user.User); ok { if u, ok := req.DataBrokerData.Get("type.googleapis.com/user.User", s.GetUserId()).(*user.User); ok {
payload["sub"] = u.GetId() payload["sub"] = u.GetId()
payload["user"] = u.GetId()
payload["email"] = u.GetEmail() payload["email"] = u.GetEmail()
} }
if du, ok := req.DataBrokerData.Get("type.googleapis.com/directory.User", s.GetUserId()).(*directory.User); ok { if du, ok := req.DataBrokerData.Get("type.googleapis.com/directory.User", s.GetUserId()).(*directory.User); ok {

View file

@ -2,10 +2,14 @@ package evaluator
import ( import (
"encoding/json" "encoding/json"
"net/http"
"net/url"
"testing" "testing"
"github.com/stretchr/testify/assert" "github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"github.com/pomerium/pomerium/config"
"github.com/pomerium/pomerium/internal/grpc/directory" "github.com/pomerium/pomerium/internal/grpc/directory"
) )
@ -65,3 +69,31 @@ func TestJSONMarshal(t *testing.T) {
"is_valid_client_certificate": true "is_valid_client_certificate": true
}`, string(bs)) }`, string(bs))
} }
func TestEvaluator_SignedJWT(t *testing.T) {
opt := config.NewDefaultOptions()
opt.AuthenticateURL = mustParseURL("https://authenticate.example.com")
e, err := New(opt)
require.NoError(t, err)
req := &Request{
HTTP: RequestHTTP{
Method: http.MethodGet,
URL: "https://example.com",
},
}
signedJWT, err := e.SignedJWT(req)
require.NoError(t, err)
assert.NotEmpty(t, signedJWT)
payload, err := e.ParseSignedJWT(signedJWT)
require.NoError(t, err)
assert.NotEmpty(t, payload)
}
func mustParseURL(str string) *url.URL {
u, err := url.Parse(str)
if err != nil {
panic(err)
}
return u
}

View file

@ -75,7 +75,7 @@ func (a *Authorize) Check(ctx context.Context, in *envoy_service_auth_v2.CheckRe
switch { switch {
case reply.Status == http.StatusOK: case reply.Status == http.StatusOK:
return a.okResponse(reply, rawJWT), nil return a.okResponse(reply), nil
case reply.Status == http.StatusUnauthorized: case reply.Status == http.StatusUnauthorized:
if isForwardAuth { if isForwardAuth {
return a.deniedResponse(in, http.StatusUnauthorized, "Unauthenticated", nil), nil return a.deniedResponse(in, http.StatusUnauthorized, "Unauthenticated", nil), nil
@ -147,10 +147,10 @@ func (a *Authorize) forceSyncUser(ctx context.Context, userID string) *user.User
return s return s
} }
func (a *Authorize) getEnvoyRequestHeaders(rawJWT []byte) ([]*envoy_api_v2_core.HeaderValueOption, error) { func (a *Authorize) getEnvoyRequestHeaders(signedJWT string) ([]*envoy_api_v2_core.HeaderValueOption, error) {
var hvos []*envoy_api_v2_core.HeaderValueOption var hvos []*envoy_api_v2_core.HeaderValueOption
hdrs, err := getJWTClaimHeaders(a.currentOptions.Load(), a.currentEncoder.Load(), rawJWT) hdrs, err := a.getJWTClaimHeaders(a.currentOptions.Load(), signedJWT)
if err != nil { if err != nil {
return nil, err return nil, err
} }

View file

@ -85,41 +85,38 @@ func getJWTSetCookieHeaders(cookieStore sessions.SessionStore, rawjwt []byte) (m
return hdrs, nil return hdrs, nil
} }
func getJWTClaimHeaders(options config.Options, encoder encoding.MarshalUnmarshaler, rawjwt []byte) (map[string]string, error) { func (a *Authorize) getJWTClaimHeaders(options config.Options, signedJWT string) (map[string]string, error) {
if len(rawjwt) == 0 { if len(signedJWT) == 0 {
return make(map[string]string), nil return make(map[string]string), nil
} }
var claims map[string]jwtClaim var claims map[string]interface{}
err := encoder.Unmarshal(rawjwt, &claims) payload, err := a.pe.ParseSignedJWT(signedJWT)
if err != nil { if err != nil {
return nil, err return nil, err
} }
if err := json.Unmarshal(payload, &claims); err != nil {
return nil, err
}
hdrs := make(map[string]string) hdrs := make(map[string]string)
for _, name := range options.JWTClaimsHeaders { for _, name := range options.JWTClaimsHeaders {
if claim, ok := claims[name]; ok { if claim, ok := claims[name]; ok {
hdrs["x-pomerium-claim-"+name] = strings.Join(claim, ",") switch value := claim.(type) {
case string:
hdrs["x-pomerium-claim-"+name] = value
case []interface{}:
hdrs["x-pomerium-claim-"+name] = strings.Join(toSliceStrings(value), ",")
}
} }
} }
return hdrs, nil return hdrs, nil
} }
type jwtClaim []string func toSliceStrings(sliceIfaces []interface{}) []string {
sliceStrings := make([]string, 0, len(sliceIfaces))
func (claim *jwtClaim) UnmarshalJSON(bs []byte) error { for _, e := range sliceIfaces {
var raw interface{} sliceStrings = append(sliceStrings, fmt.Sprint(e))
err := json.Unmarshal(bs, &raw)
if err != nil {
return err
} }
switch obj := raw.(type) { return sliceStrings
case []interface{}:
for _, el := range obj {
*claim = append(*claim, fmt.Sprint(el))
}
default:
*claim = append(*claim, fmt.Sprint(obj))
}
return nil
} }