diff --git a/CHANGELOG.md b/CHANGELOG.md index 426977b87..db9c7d0a7 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -33,6 +33,10 @@ - Pomerium and its services will gracefully shutdown on [interrupt signal](http://man7.org/linux/man-pages/man7/signal.7.html). [GH-230] - [Google](https://developers.google.com/identity/protocols/OpenIDConnect) now prompts the user to select a user account (by adding `select_account` to the sign in url). This allows a user who has multiple accounts at the authorization server to select amongst the multiple accounts that they may have current sessions for. +### FIXED + +- Fixed potential race condition when signing requests. [GH-240] + ## v0.1.0 ### NEW diff --git a/internal/cryptutil/encrypt.go b/internal/cryptutil/encrypt.go index fbfca9a09..b974be7d7 100644 --- a/internal/cryptutil/encrypt.go +++ b/internal/cryptutil/encrypt.go @@ -77,7 +77,7 @@ func (c *XChaCha20Cipher) GenerateNonce() []byte { func (c *XChaCha20Cipher) Encrypt(plaintext []byte) (joined []byte, err error) { defer func() { if r := recover(); r != nil { - err = fmt.Errorf("internal/aead: error encrypting bytes: %v", r) + err = fmt.Errorf("cryptutil: error encrypting bytes: %v", r) } }() nonce := c.GenerateNonce() @@ -92,7 +92,7 @@ func (c *XChaCha20Cipher) Encrypt(plaintext []byte) (joined []byte, err error) { // Decrypt a value using XChaCha20-Poly1305 func (c *XChaCha20Cipher) Decrypt(joined []byte) ([]byte, error) { if len(joined) <= c.aead.NonceSize() { - return nil, fmt.Errorf("internal/aead: invalid input size: %d", len(joined)) + return nil, fmt.Errorf("cryptutil: invalid input size: %d", len(joined)) } // grab out the nonce pivot := len(joined) - c.aead.NonceSize() @@ -161,13 +161,13 @@ func compress(data []byte) ([]byte, error) { var buf bytes.Buffer writer, err := gzip.NewWriterLevel(&buf, gzip.DefaultCompression) if err != nil { - return nil, fmt.Errorf("internal/aead: failed to create a gzip writer: %q", err) + return nil, fmt.Errorf("cryptutil: failed to create a gzip writer: %q", err) } if writer == nil { - return nil, fmt.Errorf("internal/aead: failed to create a gzip writer") + return nil, fmt.Errorf("cryptutil: failed to create a gzip writer") } if _, err = writer.Write(data); err != nil { - return nil, fmt.Errorf("internal/aead: failed to compress data with err: %q", err) + return nil, fmt.Errorf("cryptutil: failed to compress data with err: %q", err) } if err = writer.Close(); err != nil { return nil, err @@ -178,7 +178,7 @@ func compress(data []byte) ([]byte, error) { func decompress(data []byte) ([]byte, error) { reader, err := gzip.NewReader(bytes.NewReader(data)) if err != nil { - return nil, fmt.Errorf("internal/aead: failed to create a gzip reader: %q", err) + return nil, fmt.Errorf("cryptutil: failed to create a gzip reader: %q", err) } defer reader.Close() var buf bytes.Buffer diff --git a/internal/cryptutil/marshal.go b/internal/cryptutil/marshal.go index a2bfcef67..a906531ac 100644 --- a/internal/cryptutil/marshal.go +++ b/internal/cryptutil/marshal.go @@ -14,7 +14,7 @@ import ( func DecodePublicKey(encodedKey []byte) (*ecdsa.PublicKey, error) { block, _ := pem.Decode(encodedKey) if block == nil { - return nil, fmt.Errorf("marshal: decoded nil PEM block") + return nil, fmt.Errorf("cryptutil: decoded nil PEM block") } pub, err := x509.ParsePKIXPublicKey(block.Bytes) if err != nil { @@ -23,7 +23,7 @@ func DecodePublicKey(encodedKey []byte) (*ecdsa.PublicKey, error) { ecdsaPub, ok := pub.(*ecdsa.PublicKey) if !ok { - return nil, errors.New("marshal: data was not an ECDSA public key") + return nil, errors.New("cryptutil: data was not an ECDSA public key") } return ecdsaPub, nil @@ -53,7 +53,7 @@ func DecodePrivateKey(encodedKey []byte) (*ecdsa.PrivateKey, error) { block, encodedKey = pem.Decode(encodedKey) if block == nil { - return nil, fmt.Errorf("failed to find EC PRIVATE KEY in PEM data after skipping types %v", skippedTypes) + return nil, fmt.Errorf("cryptutil: failed to find EC PRIVATE KEY in PEM data after skipping types %v", skippedTypes) } if block.Type == "EC PRIVATE KEY" { diff --git a/internal/cryptutil/mock_cipher_test.go b/internal/cryptutil/mock_cipher_test.go new file mode 100644 index 000000000..037fc2dff --- /dev/null +++ b/internal/cryptutil/mock_cipher_test.go @@ -0,0 +1,44 @@ +package cryptutil // import "github.com/pomerium/pomerium/internal/cryptutil" + +import ( + "errors" + "testing" +) + +func TestMockCipher_Unmarshal(t *testing.T) { + e := errors.New("err") + mc := MockCipher{ + EncryptResponse: []byte("EncryptResponse"), + EncryptError: e, + DecryptResponse: []byte("DecryptResponse"), + DecryptError: e, + MarshalResponse: "MarshalResponse", + MarshalError: e, + UnmarshalError: e, + } + b, err := mc.Encrypt([]byte("test")) + if string(b) != "EncryptResponse" { + t.Error("unexpected encrypt response") + } + if err != e { + t.Error("unexpected encrypt error") + } + b, err = mc.Decrypt([]byte("test")) + if string(b) != "DecryptResponse" { + t.Error("unexpected Decrypt response") + } + if err != e { + t.Error("unexpected Decrypt error") + } + s, err := mc.Marshal("test") + if err != e { + t.Error("unexpected Marshal error") + } + if s != "MarshalResponse" { + t.Error("unexpected MarshalResponse error") + } + err = mc.Unmarshal("s", "s") + if err != e { + t.Error("unexpected Unmarshal error") + } +} diff --git a/internal/cryptutil/sign.go b/internal/cryptutil/sign.go index 6687e7610..aa7ce6ac3 100644 --- a/internal/cryptutil/sign.go +++ b/internal/cryptutil/sign.go @@ -1,6 +1,7 @@ package cryptutil // import "github.com/pomerium/pomerium/internal/cryptutil" import ( "fmt" + "sync" "time" jose "gopkg.in/square/go-jose.v2" @@ -15,6 +16,9 @@ type JWTSigner interface { // ES256Signer is struct containing the required fields to create a ES256 signed JSON Web Tokens type ES256Signer struct { + signer jose.Signer + + mu sync.Mutex // User (sub) is unique, stable identifier for the user. // Use in place of the x-pomerium-authenticated-user-id header. User string `json:"sub,omitempty"` @@ -42,8 +46,6 @@ type ES256Signer struct { // IssuedAt (nbf) is the time is measured in seconds since the UNIX epoch. // Allow 1 minute for skew. NotBefore jwt.NumericDate `json:"nbf,omitempty"` - - signer jose.Signer } // NewES256Signer creates an Elliptic Curve, NIST P-256 (aka secp256r1 aka prime256v1) JWT signer. @@ -56,7 +58,7 @@ type ES256Signer struct { func NewES256Signer(privKey []byte, audience string) (*ES256Signer, error) { key, err := DecodePrivateKey(privKey) if err != nil { - return nil, fmt.Errorf("internal/cryptutil parsing key failed %v", err) + return nil, fmt.Errorf("cryptutil: parsing key failed %v", err) } signer, err := jose.NewSigner( jose.SigningKey{ @@ -65,7 +67,7 @@ func NewES256Signer(privKey []byte, audience string) (*ES256Signer, error) { }, (&jose.SignerOptions{}).WithType("JWT")) if err != nil { - return nil, fmt.Errorf("internal/cryptutil new signer failed %v", err) + return nil, fmt.Errorf("cryptutil: new signer failed %v", err) } return &ES256Signer{ Issuer: "pomerium-proxy", @@ -77,6 +79,8 @@ func NewES256Signer(privKey []byte, audience string) (*ES256Signer, error) { // SignJWT creates a signed JWT containing claims for the logged in // user id (`sub`), email (`email`) and groups (`groups`). func (s *ES256Signer) SignJWT(user, email, groups string) (string, error) { + s.mu.Lock() + defer s.mu.Unlock() s.User = user s.Email = email s.Groups = groups @@ -86,7 +90,7 @@ func (s *ES256Signer) SignJWT(user, email, groups string) (string, error) { s.NotBefore = *jwt.NewNumericDate(now.Add(-1 * jwt.DefaultLeeway)) rawJWT, err := jwt.Signed(s.signer).Claims(s).CompactSerialize() if err != nil { - return "", err + return "", fmt.Errorf("cryptutil: sign failed %v", err) } return rawJWT, nil }