authenticate: remove ecjson (#3688)

This commit is contained in:
Caleb Doxsey 2022-10-20 10:37:21 -06:00 committed by GitHub
parent 61506c11b5
commit 75634dfca2
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
7 changed files with 59 additions and 206 deletions

View file

@ -152,11 +152,10 @@ func TestAuthenticate_SignIn(t *testing.T) {
return tt.provider, nil
})),
state: atomicutil.NewValue(&authenticateState{
sharedCipher: sharedCipher,
sessionStore: tt.session,
redirectURL: uriParseHelper("https://some.example"),
sharedEncoder: tt.encoder,
encryptedEncoder: tt.encoder,
sharedCipher: sharedCipher,
sessionStore: tt.session,
redirectURL: uriParseHelper("https://some.example"),
sharedEncoder: tt.encoder,
dataBrokerClient: mockDataBrokerServiceClient{
get: func(ctx context.Context, in *databroker.GetRequest, opts ...grpc.CallOption) (*databroker.GetResponse, error) {
return &databroker.GetResponse{
@ -308,9 +307,8 @@ func TestAuthenticate_SignOut(t *testing.T) {
return tt.provider, nil
})),
state: atomicutil.NewValue(&authenticateState{
sessionStore: tt.sessionStore,
encryptedEncoder: mock.Encoder{},
sharedEncoder: mock.Encoder{},
sessionStore: tt.sessionStore,
sharedEncoder: mock.Encoder{},
dataBrokerClient: mockDataBrokerServiceClient{
get: func(ctx context.Context, in *databroker.GetRequest, opts ...grpc.CallOption) (*databroker.GetResponse, error) {
return &databroker.GetResponse{
@ -411,10 +409,6 @@ func TestAuthenticate_OAuthCallback(t *testing.T) {
if err != nil {
t.Fatal(err)
}
signer, err := jws.NewHS256Signer(nil)
if err != nil {
t.Fatal(err)
}
authURL, _ := url.Parse(tt.authenticateURL)
a := &Authenticate{
cfg: getAuthenticateConfig(WithGetIdentityProvider(func(options *config.Options, idpID string) (identity.Authenticator, error) {
@ -429,11 +423,10 @@ func TestAuthenticate_OAuthCallback(t *testing.T) {
return nil, nil
},
},
directoryClient: new(mockDirectoryServiceClient),
redirectURL: authURL,
sessionStore: tt.session,
cookieCipher: aead,
encryptedEncoder: signer,
directoryClient: new(mockDirectoryServiceClient),
redirectURL: authURL,
sessionStore: tt.session,
cookieCipher: aead,
}),
options: config.NewAtomicOptions(),
}
@ -558,12 +551,11 @@ func TestAuthenticate_SessionValidatorMiddleware(t *testing.T) {
return tt.provider, nil
})),
state: atomicutil.NewValue(&authenticateState{
cookieSecret: cryptutil.NewKey(),
redirectURL: uriParseHelper("https://authenticate.corp.beyondperimeter.com"),
sessionStore: tt.session,
cookieCipher: aead,
encryptedEncoder: signer,
sharedEncoder: signer,
cookieSecret: cryptutil.NewKey(),
redirectURL: uriParseHelper("https://authenticate.corp.beyondperimeter.com"),
sessionStore: tt.session,
cookieCipher: aead,
sharedEncoder: signer,
dataBrokerClient: mockDataBrokerServiceClient{
get: func(ctx context.Context, in *databroker.GetRequest, opts ...grpc.CallOption) (*databroker.GetResponse, error) {
return &databroker.GetResponse{
@ -697,9 +689,8 @@ func TestAuthenticate_userInfo(t *testing.T) {
a := &Authenticate{
options: o,
state: atomicutil.NewValue(&authenticateState{
sessionStore: tt.sessionStore,
encryptedEncoder: signer,
sharedEncoder: signer,
sessionStore: tt.sessionStore,
sharedEncoder: signer,
dataBrokerClient: mockDataBrokerServiceClient{
get: func(ctx context.Context, in *databroker.GetRequest, opts ...grpc.CallOption) (*databroker.GetResponse, error) {
return &databroker.GetResponse{

View file

@ -11,11 +11,9 @@ import (
"github.com/pomerium/pomerium/config"
"github.com/pomerium/pomerium/internal/encoding"
"github.com/pomerium/pomerium/internal/encoding/ecjson"
"github.com/pomerium/pomerium/internal/encoding/jws"
"github.com/pomerium/pomerium/internal/sessions"
"github.com/pomerium/pomerium/internal/sessions/cookie"
"github.com/pomerium/pomerium/internal/sessions/header"
"github.com/pomerium/pomerium/internal/urlutil"
"github.com/pomerium/pomerium/pkg/cryptutil"
"github.com/pomerium/pomerium/pkg/grpc"
@ -40,8 +38,6 @@ type authenticateState struct {
cookieSecret []byte
// cookieCipher is the cipher to use to encrypt/decrypt session data
cookieCipher cipher.AEAD
// encryptedEncoder is the encoder used to marshal and unmarshal session data
encryptedEncoder encoding.MarshalUnmarshaler
// sessionStore is the session store used to persist a user's session
sessionStore sessions.SessionStore
// sessionLoaders are a collection of session loaders to attempt to pull
@ -110,10 +106,6 @@ func newAuthenticateStateFromConfig(cfg *config.Config) (*authenticateState, err
return nil, err
}
state.encryptedEncoder = ecjson.New(state.cookieCipher)
headerStore := header.NewStore(state.encryptedEncoder)
cookieStore, err := cookie.NewStore(func() cookie.Options {
return cookie.Options{
Name: cfg.Options.CookieName,
@ -128,7 +120,7 @@ func newAuthenticateStateFromConfig(cfg *config.Config) (*authenticateState, err
}
state.sessionStore = cookieStore
state.sessionLoaders = []sessions.SessionLoader{headerStore, cookieStore}
state.sessionLoaders = []sessions.SessionLoader{cookieStore}
state.jwk = new(jose.JSONWebKeySet)
signingKey, err := cfg.Options.GetSigningKey()
if err != nil {

View file

@ -1,121 +0,0 @@
// Package ecjson represents encrypted and compressed content using JSON-based
package ecjson
import (
"bytes"
"compress/gzip"
"crypto/cipher"
"encoding/base64"
"encoding/json"
"errors"
"fmt"
"io"
"github.com/pomerium/pomerium/internal/encoding"
"github.com/pomerium/pomerium/pkg/cryptutil"
)
// 10mb reasonable default?
const maxMemory = int64(10 << 20)
// ErrMessageTooLarge is returned if the data is too large to be processed.
var ErrMessageTooLarge = errors.New("ecjson: message too large")
// EncryptedCompressedJSON implements SecureEncoder for JSON using an AEAD cipher.
//
// See https://en.wikipedia.org/wiki/Authenticated_encryption
type EncryptedCompressedJSON struct {
aead cipher.AEAD
}
// New takes a base64 encoded secret key and returns a new XChacha20poly1305 cipher.
func New(aead cipher.AEAD) encoding.MarshalUnmarshaler {
return &EncryptedCompressedJSON{aead: aead}
}
// Marshal marshals the interface state as JSON, encrypts the JSON using the cipher
// and base64 encodes the binary value as a string and returns the result
//
// can panic if source of random entropy is exhausted generating a nonce.
func (c *EncryptedCompressedJSON) Marshal(s interface{}) ([]byte, error) {
// encode json value
plaintext, err := json.Marshal(s)
if err != nil {
return nil, err
}
// compress the plaintext bytes
compressed, err := compress(plaintext)
if err != nil {
return nil, err
}
// encrypt the compressed JSON bytes
ciphertext := cryptutil.Encrypt(c.aead, compressed, nil)
// base64-encode the result
encoded := base64.RawURLEncoding.EncodeToString(ciphertext)
return []byte(encoded), nil
}
// Unmarshal takes the marshaled string, base64-decodes into a byte slice, decrypts the
// byte slice the passed cipher, and unmarshals the resulting JSON into the struct pointer passed
func (c *EncryptedCompressedJSON) Unmarshal(data []byte, s interface{}) error {
// convert base64 string value to bytes
ciphertext, err := base64.RawURLEncoding.DecodeString(string(data))
if err != nil {
return err
}
// decrypt the bytes
compressed, err := cryptutil.Decrypt(c.aead, ciphertext, nil)
if err != nil {
return err
}
// decompress the unencrypted bytes
plaintext, err := decompress(compressed)
if err != nil {
return err
}
// unmarshal the unencrypted bytes
err = json.Unmarshal(plaintext, s)
if err != nil {
return err
}
return nil
}
// compress gzips a set of bytes
func compress(data []byte) ([]byte, error) {
var buf bytes.Buffer
writer, err := gzip.NewWriterLevel(&buf, gzip.DefaultCompression)
if err != nil {
return nil, fmt.Errorf("cryptutil: failed to create a gzip writer: %w", err)
}
if writer == nil {
return nil, fmt.Errorf("cryptutil: failed to create a gzip writer")
}
if _, err = writer.Write(data); err != nil {
return nil, fmt.Errorf("cryptutil: failed to compress data with err: %w", err)
}
if err = writer.Close(); err != nil {
return nil, err
}
return buf.Bytes(), nil
}
// decompress un-gzips a set of bytes
func decompress(data []byte) ([]byte, error) {
reader, err := gzip.NewReader(bytes.NewReader(data))
if err != nil {
return nil, fmt.Errorf("cryptutil: failed to create a gzip reader: %w", err)
}
defer reader.Close()
var buf bytes.Buffer
n, err := io.CopyN(&buf, reader, maxMemory+1)
if err != nil && err != io.EOF {
return nil, err
}
if n > maxMemory {
return nil, ErrMessageTooLarge
}
return buf.Bytes(), nil
}

View file

@ -9,22 +9,21 @@ import (
"testing"
"time"
"github.com/google/go-cmp/cmp"
"github.com/google/go-cmp/cmp/cmpopts"
"github.com/stretchr/testify/require"
"github.com/pomerium/pomerium/internal/encoding"
"github.com/pomerium/pomerium/internal/encoding/ecjson"
"github.com/pomerium/pomerium/internal/encoding/jws"
"github.com/pomerium/pomerium/internal/encoding/mock"
"github.com/pomerium/pomerium/internal/sessions"
"github.com/pomerium/pomerium/pkg/cryptutil"
"github.com/google/go-cmp/cmp"
"github.com/google/go-cmp/cmp/cmpopts"
)
func TestNewStore(t *testing.T) {
cipher, err := cryptutil.NewAEADCipher(cryptutil.NewKey())
if err != nil {
t.Fatal(err)
}
encoder := ecjson.New(cipher)
key := cryptutil.NewKey()
encoder, err := jws.NewHS256Signer(key)
require.NoError(t, err)
tests := []struct {
name string
opts *Options
@ -58,11 +57,9 @@ func TestNewStore(t *testing.T) {
}
func TestNewCookieLoader(t *testing.T) {
cipher, err := cryptutil.NewAEADCipher(cryptutil.NewKey())
if err != nil {
t.Fatal(err)
}
encoder := ecjson.New(cipher)
key := cryptutil.NewKey()
encoder, err := jws.NewHS256Signer(key)
require.NoError(t, err)
tests := []struct {
name string
opts *Options
@ -96,10 +93,9 @@ func TestNewCookieLoader(t *testing.T) {
}
func TestStore_SaveSession(t *testing.T) {
c, err := cryptutil.NewAEADCipher(cryptutil.NewKey())
if err != nil {
t.Fatal(err)
}
key := cryptutil.NewKey()
encoder, err := jws.NewHS256Signer(key)
require.NoError(t, err)
hugeString := make([]byte, 4097)
if _, err := rand.Read(hugeString); err != nil {
@ -113,13 +109,13 @@ func TestStore_SaveSession(t *testing.T) {
wantErr bool
wantLoadErr bool
}{
{"good", &sessions.State{ID: "xyz"}, ecjson.New(c), ecjson.New(c), false, false},
{"good", &sessions.State{ID: "xyz"}, encoder, encoder, false, false},
{"bad cipher", &sessions.State{ID: "xyz"}, nil, nil, true, true},
{"huge cookie", &sessions.State{ID: "xyz", Subject: fmt.Sprintf("%x", hugeString)}, ecjson.New(c), ecjson.New(c), false, false},
{"marshal error", &sessions.State{ID: "xyz"}, mock.Encoder{MarshalError: errors.New("error")}, ecjson.New(c), true, true},
{"nil encoder cannot save non string type", &sessions.State{ID: "xyz"}, nil, ecjson.New(c), true, true},
{"good marshal string directly", cryptutil.NewBase64Key(), nil, ecjson.New(c), false, true},
{"good marshal bytes directly", cryptutil.NewKey(), nil, ecjson.New(c), false, true},
{"huge cookie", &sessions.State{ID: "xyz", Subject: fmt.Sprintf("%x", hugeString)}, encoder, encoder, false, false},
{"marshal error", &sessions.State{ID: "xyz"}, mock.Encoder{MarshalError: errors.New("error")}, encoder, true, true},
{"nil encoder cannot save non string type", &sessions.State{ID: "xyz"}, nil, encoder, true, true},
{"good marshal string directly", cryptutil.NewBase64Key(), nil, encoder, false, true},
{"good marshal bytes directly", cryptutil.NewKey(), nil, encoder, false, true},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
@ -148,14 +144,13 @@ func TestStore_SaveSession(t *testing.T) {
r.AddCookie(cookie)
}
enc := ecjson.New(c)
jwt, err := s.LoadSession(r)
if (err != nil) != tt.wantLoadErr {
t.Errorf("LoadSession() error = %v, wantErr %v", err, tt.wantLoadErr)
return
}
var state sessions.State
enc.Unmarshal([]byte(jwt), &state)
encoder.Unmarshal([]byte(jwt), &state)
cmpOpts := []cmp.Option{
cmpopts.IgnoreUnexported(sessions.State{}),

View file

@ -7,11 +7,11 @@ import (
"strings"
"testing"
"github.com/pomerium/pomerium/internal/sessions"
"github.com/google/go-cmp/cmp"
"github.com/stretchr/testify/require"
"github.com/pomerium/pomerium/internal/encoding/ecjson"
"github.com/pomerium/pomerium/internal/encoding/jws"
"github.com/pomerium/pomerium/internal/sessions"
"github.com/pomerium/pomerium/pkg/cryptutil"
)
@ -49,11 +49,9 @@ func TestVerifier(t *testing.T) {
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
cipher, err := cryptutil.NewAEADCipherFromBase64(cryptutil.NewBase64Key())
encoder := ecjson.New(cipher)
if err != nil {
t.Fatal(err)
}
key := cryptutil.NewKey()
encoder, err := jws.NewHS256Signer(key)
require.NoError(t, err)
encSession, err := encoder.Marshal(&tt.state)
if err != nil {
t.Fatal(err)

View file

@ -7,11 +7,12 @@ import (
"strings"
"testing"
"github.com/pomerium/pomerium/internal/encoding/ecjson"
"github.com/google/go-cmp/cmp"
"github.com/stretchr/testify/require"
"github.com/pomerium/pomerium/internal/encoding/jws"
"github.com/pomerium/pomerium/internal/sessions"
"github.com/pomerium/pomerium/pkg/cryptutil"
"github.com/google/go-cmp/cmp"
)
func testAuthorizer(next http.Handler) http.Handler {
@ -63,11 +64,9 @@ func TestVerifier(t *testing.T) {
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
cipher, err := cryptutil.NewAEADCipherFromBase64(cryptutil.NewBase64Key())
encoder := ecjson.New(cipher)
if err != nil {
t.Fatal(err)
}
key := cryptutil.NewKey()
encoder, err := jws.NewHS256Signer(key)
require.NoError(t, err)
encSession, err := encoder.Marshal(&tt.state)
if err != nil {
t.Fatal(err)

View file

@ -7,11 +7,12 @@ import (
"strings"
"testing"
"github.com/pomerium/pomerium/internal/encoding/ecjson"
"github.com/google/go-cmp/cmp"
"github.com/stretchr/testify/require"
"github.com/pomerium/pomerium/internal/encoding/jws"
"github.com/pomerium/pomerium/internal/sessions"
"github.com/pomerium/pomerium/pkg/cryptutil"
"github.com/google/go-cmp/cmp"
)
func testAuthorizer(next http.Handler) http.Handler {
@ -44,11 +45,9 @@ func TestVerifier(t *testing.T) {
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
cipher, err := cryptutil.NewAEADCipherFromBase64(cryptutil.NewBase64Key())
encoder := ecjson.New(cipher)
if err != nil {
t.Fatal(err)
}
key := cryptutil.NewKey()
encoder, err := jws.NewHS256Signer(key)
require.NoError(t, err)
encSession, err := encoder.Marshal(&tt.state)
if err != nil {
t.Fatal(err)