diff --git a/authenticate/handlers_test.go b/authenticate/handlers_test.go index 978e0a5c0..e2903eaf2 100644 --- a/authenticate/handlers_test.go +++ b/authenticate/handlers_test.go @@ -152,11 +152,10 @@ func TestAuthenticate_SignIn(t *testing.T) { return tt.provider, nil })), state: atomicutil.NewValue(&authenticateState{ - sharedCipher: sharedCipher, - sessionStore: tt.session, - redirectURL: uriParseHelper("https://some.example"), - sharedEncoder: tt.encoder, - encryptedEncoder: tt.encoder, + sharedCipher: sharedCipher, + sessionStore: tt.session, + redirectURL: uriParseHelper("https://some.example"), + sharedEncoder: tt.encoder, dataBrokerClient: mockDataBrokerServiceClient{ get: func(ctx context.Context, in *databroker.GetRequest, opts ...grpc.CallOption) (*databroker.GetResponse, error) { return &databroker.GetResponse{ @@ -308,9 +307,8 @@ func TestAuthenticate_SignOut(t *testing.T) { return tt.provider, nil })), state: atomicutil.NewValue(&authenticateState{ - sessionStore: tt.sessionStore, - encryptedEncoder: mock.Encoder{}, - sharedEncoder: mock.Encoder{}, + sessionStore: tt.sessionStore, + sharedEncoder: mock.Encoder{}, dataBrokerClient: mockDataBrokerServiceClient{ get: func(ctx context.Context, in *databroker.GetRequest, opts ...grpc.CallOption) (*databroker.GetResponse, error) { return &databroker.GetResponse{ @@ -411,10 +409,6 @@ func TestAuthenticate_OAuthCallback(t *testing.T) { if err != nil { t.Fatal(err) } - signer, err := jws.NewHS256Signer(nil) - if err != nil { - t.Fatal(err) - } authURL, _ := url.Parse(tt.authenticateURL) a := &Authenticate{ cfg: getAuthenticateConfig(WithGetIdentityProvider(func(options *config.Options, idpID string) (identity.Authenticator, error) { @@ -429,11 +423,10 @@ func TestAuthenticate_OAuthCallback(t *testing.T) { return nil, nil }, }, - directoryClient: new(mockDirectoryServiceClient), - redirectURL: authURL, - sessionStore: tt.session, - cookieCipher: aead, - encryptedEncoder: signer, + directoryClient: new(mockDirectoryServiceClient), + redirectURL: authURL, + sessionStore: tt.session, + cookieCipher: aead, }), options: config.NewAtomicOptions(), } @@ -558,12 +551,11 @@ func TestAuthenticate_SessionValidatorMiddleware(t *testing.T) { return tt.provider, nil })), state: atomicutil.NewValue(&authenticateState{ - cookieSecret: cryptutil.NewKey(), - redirectURL: uriParseHelper("https://authenticate.corp.beyondperimeter.com"), - sessionStore: tt.session, - cookieCipher: aead, - encryptedEncoder: signer, - sharedEncoder: signer, + cookieSecret: cryptutil.NewKey(), + redirectURL: uriParseHelper("https://authenticate.corp.beyondperimeter.com"), + sessionStore: tt.session, + cookieCipher: aead, + sharedEncoder: signer, dataBrokerClient: mockDataBrokerServiceClient{ get: func(ctx context.Context, in *databroker.GetRequest, opts ...grpc.CallOption) (*databroker.GetResponse, error) { return &databroker.GetResponse{ @@ -697,9 +689,8 @@ func TestAuthenticate_userInfo(t *testing.T) { a := &Authenticate{ options: o, state: atomicutil.NewValue(&authenticateState{ - sessionStore: tt.sessionStore, - encryptedEncoder: signer, - sharedEncoder: signer, + sessionStore: tt.sessionStore, + sharedEncoder: signer, dataBrokerClient: mockDataBrokerServiceClient{ get: func(ctx context.Context, in *databroker.GetRequest, opts ...grpc.CallOption) (*databroker.GetResponse, error) { return &databroker.GetResponse{ diff --git a/authenticate/state.go b/authenticate/state.go index 0ef549cb7..0b6b7924a 100644 --- a/authenticate/state.go +++ b/authenticate/state.go @@ -11,11 +11,9 @@ import ( "github.com/pomerium/pomerium/config" "github.com/pomerium/pomerium/internal/encoding" - "github.com/pomerium/pomerium/internal/encoding/ecjson" "github.com/pomerium/pomerium/internal/encoding/jws" "github.com/pomerium/pomerium/internal/sessions" "github.com/pomerium/pomerium/internal/sessions/cookie" - "github.com/pomerium/pomerium/internal/sessions/header" "github.com/pomerium/pomerium/internal/urlutil" "github.com/pomerium/pomerium/pkg/cryptutil" "github.com/pomerium/pomerium/pkg/grpc" @@ -40,8 +38,6 @@ type authenticateState struct { cookieSecret []byte // cookieCipher is the cipher to use to encrypt/decrypt session data cookieCipher cipher.AEAD - // encryptedEncoder is the encoder used to marshal and unmarshal session data - encryptedEncoder encoding.MarshalUnmarshaler // sessionStore is the session store used to persist a user's session sessionStore sessions.SessionStore // sessionLoaders are a collection of session loaders to attempt to pull @@ -110,10 +106,6 @@ func newAuthenticateStateFromConfig(cfg *config.Config) (*authenticateState, err return nil, err } - state.encryptedEncoder = ecjson.New(state.cookieCipher) - - headerStore := header.NewStore(state.encryptedEncoder) - cookieStore, err := cookie.NewStore(func() cookie.Options { return cookie.Options{ Name: cfg.Options.CookieName, @@ -128,7 +120,7 @@ func newAuthenticateStateFromConfig(cfg *config.Config) (*authenticateState, err } state.sessionStore = cookieStore - state.sessionLoaders = []sessions.SessionLoader{headerStore, cookieStore} + state.sessionLoaders = []sessions.SessionLoader{cookieStore} state.jwk = new(jose.JSONWebKeySet) signingKey, err := cfg.Options.GetSigningKey() if err != nil { diff --git a/internal/encoding/ecjson/ecjson.go b/internal/encoding/ecjson/ecjson.go deleted file mode 100644 index 61bb43628..000000000 --- a/internal/encoding/ecjson/ecjson.go +++ /dev/null @@ -1,121 +0,0 @@ -// Package ecjson represents encrypted and compressed content using JSON-based -package ecjson - -import ( - "bytes" - "compress/gzip" - "crypto/cipher" - "encoding/base64" - "encoding/json" - "errors" - "fmt" - "io" - - "github.com/pomerium/pomerium/internal/encoding" - "github.com/pomerium/pomerium/pkg/cryptutil" -) - -// 10mb reasonable default? -const maxMemory = int64(10 << 20) - -// ErrMessageTooLarge is returned if the data is too large to be processed. -var ErrMessageTooLarge = errors.New("ecjson: message too large") - -// EncryptedCompressedJSON implements SecureEncoder for JSON using an AEAD cipher. -// -// See https://en.wikipedia.org/wiki/Authenticated_encryption -type EncryptedCompressedJSON struct { - aead cipher.AEAD -} - -// New takes a base64 encoded secret key and returns a new XChacha20poly1305 cipher. -func New(aead cipher.AEAD) encoding.MarshalUnmarshaler { - return &EncryptedCompressedJSON{aead: aead} -} - -// Marshal marshals the interface state as JSON, encrypts the JSON using the cipher -// and base64 encodes the binary value as a string and returns the result -// -// can panic if source of random entropy is exhausted generating a nonce. -func (c *EncryptedCompressedJSON) Marshal(s interface{}) ([]byte, error) { - // encode json value - plaintext, err := json.Marshal(s) - if err != nil { - return nil, err - } - - // compress the plaintext bytes - compressed, err := compress(plaintext) - if err != nil { - return nil, err - } - // encrypt the compressed JSON bytes - ciphertext := cryptutil.Encrypt(c.aead, compressed, nil) - - // base64-encode the result - encoded := base64.RawURLEncoding.EncodeToString(ciphertext) - return []byte(encoded), nil -} - -// Unmarshal takes the marshaled string, base64-decodes into a byte slice, decrypts the -// byte slice the passed cipher, and unmarshals the resulting JSON into the struct pointer passed -func (c *EncryptedCompressedJSON) Unmarshal(data []byte, s interface{}) error { - // convert base64 string value to bytes - ciphertext, err := base64.RawURLEncoding.DecodeString(string(data)) - if err != nil { - return err - } - // decrypt the bytes - compressed, err := cryptutil.Decrypt(c.aead, ciphertext, nil) - if err != nil { - return err - } - // decompress the unencrypted bytes - plaintext, err := decompress(compressed) - if err != nil { - return err - } - // unmarshal the unencrypted bytes - err = json.Unmarshal(plaintext, s) - if err != nil { - return err - } - return nil -} - -// compress gzips a set of bytes -func compress(data []byte) ([]byte, error) { - var buf bytes.Buffer - writer, err := gzip.NewWriterLevel(&buf, gzip.DefaultCompression) - if err != nil { - return nil, fmt.Errorf("cryptutil: failed to create a gzip writer: %w", err) - } - if writer == nil { - return nil, fmt.Errorf("cryptutil: failed to create a gzip writer") - } - if _, err = writer.Write(data); err != nil { - return nil, fmt.Errorf("cryptutil: failed to compress data with err: %w", err) - } - if err = writer.Close(); err != nil { - return nil, err - } - return buf.Bytes(), nil -} - -// decompress un-gzips a set of bytes -func decompress(data []byte) ([]byte, error) { - reader, err := gzip.NewReader(bytes.NewReader(data)) - if err != nil { - return nil, fmt.Errorf("cryptutil: failed to create a gzip reader: %w", err) - } - defer reader.Close() - var buf bytes.Buffer - n, err := io.CopyN(&buf, reader, maxMemory+1) - if err != nil && err != io.EOF { - return nil, err - } - if n > maxMemory { - return nil, ErrMessageTooLarge - } - return buf.Bytes(), nil -} diff --git a/internal/sessions/cookie/cookie_store_test.go b/internal/sessions/cookie/cookie_store_test.go index ad9c2b5c5..a92256d70 100644 --- a/internal/sessions/cookie/cookie_store_test.go +++ b/internal/sessions/cookie/cookie_store_test.go @@ -9,22 +9,21 @@ import ( "testing" "time" + "github.com/google/go-cmp/cmp" + "github.com/google/go-cmp/cmp/cmpopts" + "github.com/stretchr/testify/require" + "github.com/pomerium/pomerium/internal/encoding" - "github.com/pomerium/pomerium/internal/encoding/ecjson" + "github.com/pomerium/pomerium/internal/encoding/jws" "github.com/pomerium/pomerium/internal/encoding/mock" "github.com/pomerium/pomerium/internal/sessions" "github.com/pomerium/pomerium/pkg/cryptutil" - - "github.com/google/go-cmp/cmp" - "github.com/google/go-cmp/cmp/cmpopts" ) func TestNewStore(t *testing.T) { - cipher, err := cryptutil.NewAEADCipher(cryptutil.NewKey()) - if err != nil { - t.Fatal(err) - } - encoder := ecjson.New(cipher) + key := cryptutil.NewKey() + encoder, err := jws.NewHS256Signer(key) + require.NoError(t, err) tests := []struct { name string opts *Options @@ -58,11 +57,9 @@ func TestNewStore(t *testing.T) { } func TestNewCookieLoader(t *testing.T) { - cipher, err := cryptutil.NewAEADCipher(cryptutil.NewKey()) - if err != nil { - t.Fatal(err) - } - encoder := ecjson.New(cipher) + key := cryptutil.NewKey() + encoder, err := jws.NewHS256Signer(key) + require.NoError(t, err) tests := []struct { name string opts *Options @@ -96,10 +93,9 @@ func TestNewCookieLoader(t *testing.T) { } func TestStore_SaveSession(t *testing.T) { - c, err := cryptutil.NewAEADCipher(cryptutil.NewKey()) - if err != nil { - t.Fatal(err) - } + key := cryptutil.NewKey() + encoder, err := jws.NewHS256Signer(key) + require.NoError(t, err) hugeString := make([]byte, 4097) if _, err := rand.Read(hugeString); err != nil { @@ -113,13 +109,13 @@ func TestStore_SaveSession(t *testing.T) { wantErr bool wantLoadErr bool }{ - {"good", &sessions.State{ID: "xyz"}, ecjson.New(c), ecjson.New(c), false, false}, + {"good", &sessions.State{ID: "xyz"}, encoder, encoder, false, false}, {"bad cipher", &sessions.State{ID: "xyz"}, nil, nil, true, true}, - {"huge cookie", &sessions.State{ID: "xyz", Subject: fmt.Sprintf("%x", hugeString)}, ecjson.New(c), ecjson.New(c), false, false}, - {"marshal error", &sessions.State{ID: "xyz"}, mock.Encoder{MarshalError: errors.New("error")}, ecjson.New(c), true, true}, - {"nil encoder cannot save non string type", &sessions.State{ID: "xyz"}, nil, ecjson.New(c), true, true}, - {"good marshal string directly", cryptutil.NewBase64Key(), nil, ecjson.New(c), false, true}, - {"good marshal bytes directly", cryptutil.NewKey(), nil, ecjson.New(c), false, true}, + {"huge cookie", &sessions.State{ID: "xyz", Subject: fmt.Sprintf("%x", hugeString)}, encoder, encoder, false, false}, + {"marshal error", &sessions.State{ID: "xyz"}, mock.Encoder{MarshalError: errors.New("error")}, encoder, true, true}, + {"nil encoder cannot save non string type", &sessions.State{ID: "xyz"}, nil, encoder, true, true}, + {"good marshal string directly", cryptutil.NewBase64Key(), nil, encoder, false, true}, + {"good marshal bytes directly", cryptutil.NewKey(), nil, encoder, false, true}, } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { @@ -148,14 +144,13 @@ func TestStore_SaveSession(t *testing.T) { r.AddCookie(cookie) } - enc := ecjson.New(c) jwt, err := s.LoadSession(r) if (err != nil) != tt.wantLoadErr { t.Errorf("LoadSession() error = %v, wantErr %v", err, tt.wantLoadErr) return } var state sessions.State - enc.Unmarshal([]byte(jwt), &state) + encoder.Unmarshal([]byte(jwt), &state) cmpOpts := []cmp.Option{ cmpopts.IgnoreUnexported(sessions.State{}), diff --git a/internal/sessions/cookie/middleware_test.go b/internal/sessions/cookie/middleware_test.go index 7cfacc6cc..3353d0965 100644 --- a/internal/sessions/cookie/middleware_test.go +++ b/internal/sessions/cookie/middleware_test.go @@ -7,11 +7,11 @@ import ( "strings" "testing" - "github.com/pomerium/pomerium/internal/sessions" - "github.com/google/go-cmp/cmp" + "github.com/stretchr/testify/require" - "github.com/pomerium/pomerium/internal/encoding/ecjson" + "github.com/pomerium/pomerium/internal/encoding/jws" + "github.com/pomerium/pomerium/internal/sessions" "github.com/pomerium/pomerium/pkg/cryptutil" ) @@ -49,11 +49,9 @@ func TestVerifier(t *testing.T) { } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { - cipher, err := cryptutil.NewAEADCipherFromBase64(cryptutil.NewBase64Key()) - encoder := ecjson.New(cipher) - if err != nil { - t.Fatal(err) - } + key := cryptutil.NewKey() + encoder, err := jws.NewHS256Signer(key) + require.NoError(t, err) encSession, err := encoder.Marshal(&tt.state) if err != nil { t.Fatal(err) diff --git a/internal/sessions/header/middleware_test.go b/internal/sessions/header/middleware_test.go index 066fce646..13e0d2871 100644 --- a/internal/sessions/header/middleware_test.go +++ b/internal/sessions/header/middleware_test.go @@ -7,11 +7,12 @@ import ( "strings" "testing" - "github.com/pomerium/pomerium/internal/encoding/ecjson" + "github.com/google/go-cmp/cmp" + "github.com/stretchr/testify/require" + + "github.com/pomerium/pomerium/internal/encoding/jws" "github.com/pomerium/pomerium/internal/sessions" "github.com/pomerium/pomerium/pkg/cryptutil" - - "github.com/google/go-cmp/cmp" ) func testAuthorizer(next http.Handler) http.Handler { @@ -63,11 +64,9 @@ func TestVerifier(t *testing.T) { } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { - cipher, err := cryptutil.NewAEADCipherFromBase64(cryptutil.NewBase64Key()) - encoder := ecjson.New(cipher) - if err != nil { - t.Fatal(err) - } + key := cryptutil.NewKey() + encoder, err := jws.NewHS256Signer(key) + require.NoError(t, err) encSession, err := encoder.Marshal(&tt.state) if err != nil { t.Fatal(err) diff --git a/internal/sessions/queryparam/middleware_test.go b/internal/sessions/queryparam/middleware_test.go index 54a6c6454..b18f68a21 100644 --- a/internal/sessions/queryparam/middleware_test.go +++ b/internal/sessions/queryparam/middleware_test.go @@ -7,11 +7,12 @@ import ( "strings" "testing" - "github.com/pomerium/pomerium/internal/encoding/ecjson" + "github.com/google/go-cmp/cmp" + "github.com/stretchr/testify/require" + + "github.com/pomerium/pomerium/internal/encoding/jws" "github.com/pomerium/pomerium/internal/sessions" "github.com/pomerium/pomerium/pkg/cryptutil" - - "github.com/google/go-cmp/cmp" ) func testAuthorizer(next http.Handler) http.Handler { @@ -44,11 +45,9 @@ func TestVerifier(t *testing.T) { } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { - cipher, err := cryptutil.NewAEADCipherFromBase64(cryptutil.NewBase64Key()) - encoder := ecjson.New(cipher) - if err != nil { - t.Fatal(err) - } + key := cryptutil.NewKey() + encoder, err := jws.NewHS256Signer(key) + require.NoError(t, err) encSession, err := encoder.Marshal(&tt.state) if err != nil { t.Fatal(err)