mirror of
https://github.com/pomerium/pomerium.git
synced 2025-04-29 18:36:30 +02:00
internal/cryputil: combines aead and cryptutil packages.
- Refactored encrypt / decrypt methods to use aead's NonceSize() interface method. - Add explicit GenerateKey function. - Remove mutex on XChaCha20.
This commit is contained in:
parent
131810ccfe
commit
24b11b0428
11 changed files with 44 additions and 89 deletions
|
@ -12,7 +12,7 @@ import (
|
||||||
"github.com/pomerium/envconfig"
|
"github.com/pomerium/envconfig"
|
||||||
|
|
||||||
"github.com/pomerium/pomerium/authenticate/providers"
|
"github.com/pomerium/pomerium/authenticate/providers"
|
||||||
"github.com/pomerium/pomerium/internal/aead"
|
"github.com/pomerium/pomerium/internal/cryptutil"
|
||||||
"github.com/pomerium/pomerium/internal/sessions"
|
"github.com/pomerium/pomerium/internal/sessions"
|
||||||
"github.com/pomerium/pomerium/internal/templates"
|
"github.com/pomerium/pomerium/internal/templates"
|
||||||
)
|
)
|
||||||
|
@ -132,7 +132,7 @@ type Authenticator struct {
|
||||||
// sesion related
|
// sesion related
|
||||||
csrfStore sessions.CSRFStore
|
csrfStore sessions.CSRFStore
|
||||||
sessionStore sessions.SessionStore
|
sessionStore sessions.SessionStore
|
||||||
cipher aead.Cipher
|
cipher cryptutil.Cipher
|
||||||
|
|
||||||
provider providers.Provider
|
provider providers.Provider
|
||||||
}
|
}
|
||||||
|
@ -149,7 +149,7 @@ func NewAuthenticator(opts *Options, optionFuncs ...func(*Authenticator) error)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
cipher, err := aead.New([]byte(decodedAuthCodeSecret))
|
cipher, err := cryptutil.NewCipher([]byte(decodedAuthCodeSecret))
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
|
@ -9,7 +9,7 @@ import (
|
||||||
"strings"
|
"strings"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
"github.com/pomerium/pomerium/internal/aead"
|
"github.com/pomerium/pomerium/internal/cryptutil"
|
||||||
"github.com/pomerium/pomerium/internal/httputil"
|
"github.com/pomerium/pomerium/internal/httputil"
|
||||||
"github.com/pomerium/pomerium/internal/log"
|
"github.com/pomerium/pomerium/internal/log"
|
||||||
m "github.com/pomerium/pomerium/internal/middleware"
|
m "github.com/pomerium/pomerium/internal/middleware"
|
||||||
|
@ -339,7 +339,7 @@ func (p *Authenticator) SignOutPage(rw http.ResponseWriter, req *http.Request, m
|
||||||
// `redirectURI`, allowing the provider to redirect back to the sso proxy after authentication.
|
// `redirectURI`, allowing the provider to redirect back to the sso proxy after authentication.
|
||||||
func (p *Authenticator) OAuthStart(rw http.ResponseWriter, req *http.Request) {
|
func (p *Authenticator) OAuthStart(rw http.ResponseWriter, req *http.Request) {
|
||||||
|
|
||||||
nonce := fmt.Sprintf("%x", aead.GenerateKey())
|
nonce := fmt.Sprintf("%x", cryptutil.GenerateKey())
|
||||||
p.csrfStore.SetCSRF(rw, req, nonce)
|
p.csrfStore.SetCSRF(rw, req, nonce)
|
||||||
|
|
||||||
authRedirectURL, err := url.Parse(req.URL.Query().Get("redirect_uri"))
|
authRedirectURL, err := url.Parse(req.URL.Query().Get("redirect_uri"))
|
||||||
|
|
|
@ -1,34 +0,0 @@
|
||||||
package aead // import "github.com/pomerium/pomerium/internal/aead"
|
|
||||||
|
|
||||||
import (
|
|
||||||
"encoding/json"
|
|
||||||
)
|
|
||||||
|
|
||||||
// MockCipher is a mock of the cipher interface
|
|
||||||
type MockCipher struct {
|
|
||||||
MarshalError error
|
|
||||||
MarshalString string
|
|
||||||
UnmarshalError error
|
|
||||||
UnmarshalBytes []byte
|
|
||||||
}
|
|
||||||
|
|
||||||
// Encrypt returns an empty byte array and nil
|
|
||||||
func (mc *MockCipher) Encrypt([]byte) ([]byte, error) {
|
|
||||||
return []byte{}, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// Decrypt returns an empty byte array and nil
|
|
||||||
func (mc *MockCipher) Decrypt([]byte) ([]byte, error) {
|
|
||||||
return []byte{}, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// Marshal returns the marshal string and marsha error
|
|
||||||
func (mc *MockCipher) Marshal(interface{}) (string, error) {
|
|
||||||
return mc.MarshalString, mc.MarshalError
|
|
||||||
}
|
|
||||||
|
|
||||||
// Unmarshal unmarshals the unmarshal bytes to be set in s and returns the unmarshal error
|
|
||||||
func (mc *MockCipher) Unmarshal(b string, s interface{}) error {
|
|
||||||
json.Unmarshal(mc.UnmarshalBytes, s)
|
|
||||||
return mc.UnmarshalError
|
|
||||||
}
|
|
|
@ -1,4 +1,4 @@
|
||||||
package aead // import "github.com/pomerium/pomerium/internal/aead"
|
package cryptutil // import "github.com/pomerium/pomerium/internal/cryptutil"
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"bytes"
|
"bytes"
|
||||||
|
@ -9,11 +9,20 @@ import (
|
||||||
"encoding/json"
|
"encoding/json"
|
||||||
"fmt"
|
"fmt"
|
||||||
"io"
|
"io"
|
||||||
"sync"
|
|
||||||
|
|
||||||
"golang.org/x/crypto/chacha20poly1305"
|
"golang.org/x/crypto/chacha20poly1305"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
// GenerateKey generates a random 32-byte key.
|
||||||
|
// Panics if source of randomness fails.
|
||||||
|
func GenerateKey() []byte {
|
||||||
|
key := make([]byte, 32)
|
||||||
|
if _, err := rand.Read(key); err != nil {
|
||||||
|
panic(err)
|
||||||
|
}
|
||||||
|
return key
|
||||||
|
}
|
||||||
|
|
||||||
// Cipher provides methods to encrypt and decrypt values.
|
// Cipher provides methods to encrypt and decrypt values.
|
||||||
type Cipher interface {
|
type Cipher interface {
|
||||||
Encrypt([]byte) ([]byte, error)
|
Encrypt([]byte) ([]byte, error)
|
||||||
|
@ -27,12 +36,10 @@ type Cipher interface {
|
||||||
// For a description of the methodology, see https://en.wikipedia.org/wiki/Authenticated_encryption
|
// For a description of the methodology, see https://en.wikipedia.org/wiki/Authenticated_encryption
|
||||||
type XChaCha20Cipher struct {
|
type XChaCha20Cipher struct {
|
||||||
aead cipher.AEAD
|
aead cipher.AEAD
|
||||||
|
|
||||||
mu sync.Mutex
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// New returns a new AES Cipher for encrypting values
|
// NewCipher returns a new XChacha20poly1305 cipher.
|
||||||
func New(secret []byte) (*XChaCha20Cipher, error) {
|
func NewCipher(secret []byte) (*XChaCha20Cipher, error) {
|
||||||
aead, err := chacha20poly1305.NewX(secret)
|
aead, err := chacha20poly1305.NewX(secret)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
|
@ -42,20 +49,10 @@ func New(secret []byte) (*XChaCha20Cipher, error) {
|
||||||
}, nil
|
}, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// GenerateKey generates a random 32-byte encryption key.
|
// GenerateNonce generates a random nonce.
|
||||||
// Panics if the key size is unsupported or source of randomness fails.
|
// Panics if source of randomness fails.
|
||||||
func GenerateKey() []byte {
|
func (c *XChaCha20Cipher) GenerateNonce() []byte {
|
||||||
nonce := make([]byte, chacha20poly1305.KeySize)
|
nonce := make([]byte, c.aead.NonceSize())
|
||||||
if _, err := rand.Read(nonce); err != nil {
|
|
||||||
panic(err)
|
|
||||||
}
|
|
||||||
return nonce
|
|
||||||
}
|
|
||||||
|
|
||||||
// GenerateNonce generates a random 24-byte nonce for XChaCha20-Poly1305.
|
|
||||||
// Panics if the key size is unsupported or source of randomness fails.
|
|
||||||
func GenerateNonce() []byte {
|
|
||||||
nonce := make([]byte, chacha20poly1305.NonceSizeX)
|
|
||||||
if _, err := rand.Read(nonce); err != nil {
|
if _, err := rand.Read(nonce); err != nil {
|
||||||
panic(err)
|
panic(err)
|
||||||
}
|
}
|
||||||
|
@ -64,15 +61,12 @@ func GenerateNonce() []byte {
|
||||||
|
|
||||||
// Encrypt a value using XChaCha20-Poly1305
|
// Encrypt a value using XChaCha20-Poly1305
|
||||||
func (c *XChaCha20Cipher) Encrypt(plaintext []byte) (joined []byte, err error) {
|
func (c *XChaCha20Cipher) Encrypt(plaintext []byte) (joined []byte, err error) {
|
||||||
c.mu.Lock()
|
|
||||||
defer c.mu.Unlock()
|
|
||||||
|
|
||||||
defer func() {
|
defer func() {
|
||||||
if r := recover(); r != nil {
|
if r := recover(); r != nil {
|
||||||
err = fmt.Errorf("internal/aead: error encrypting bytes: %v", r)
|
err = fmt.Errorf("internal/aead: error encrypting bytes: %v", r)
|
||||||
}
|
}
|
||||||
}()
|
}()
|
||||||
nonce := GenerateNonce()
|
nonce := c.GenerateNonce()
|
||||||
|
|
||||||
ciphertext := c.aead.Seal(nil, nonce, plaintext, nil)
|
ciphertext := c.aead.Seal(nil, nonce, plaintext, nil)
|
||||||
|
|
||||||
|
@ -83,14 +77,11 @@ func (c *XChaCha20Cipher) Encrypt(plaintext []byte) (joined []byte, err error) {
|
||||||
|
|
||||||
// Decrypt a value using XChaCha20-Poly1305
|
// Decrypt a value using XChaCha20-Poly1305
|
||||||
func (c *XChaCha20Cipher) Decrypt(joined []byte) ([]byte, error) {
|
func (c *XChaCha20Cipher) Decrypt(joined []byte) ([]byte, error) {
|
||||||
c.mu.Lock()
|
if len(joined) <= c.aead.NonceSize() {
|
||||||
defer c.mu.Unlock()
|
|
||||||
|
|
||||||
if len(joined) <= chacha20poly1305.NonceSizeX {
|
|
||||||
return nil, fmt.Errorf("internal/aead: invalid input size: %d", len(joined))
|
return nil, fmt.Errorf("internal/aead: invalid input size: %d", len(joined))
|
||||||
}
|
}
|
||||||
// grab out the nonce
|
// grab out the nonce
|
||||||
pivot := len(joined) - chacha20poly1305.NonceSizeX
|
pivot := len(joined) - c.aead.NonceSize()
|
||||||
ciphertext := joined[:pivot]
|
ciphertext := joined[:pivot]
|
||||||
nonce := joined[pivot:]
|
nonce := joined[pivot:]
|
||||||
|
|
|
@ -1,4 +1,4 @@
|
||||||
package aead // import "github.com/pomerium/pomerium/internal/aead"
|
package cryptutil // import "github.com/pomerium/pomerium/internal/cryptutil"
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"crypto/rand"
|
"crypto/rand"
|
||||||
|
@ -13,7 +13,7 @@ func TestEncodeAndDecodeAccessToken(t *testing.T) {
|
||||||
plaintext := []byte("my plain text value")
|
plaintext := []byte("my plain text value")
|
||||||
|
|
||||||
key := GenerateKey()
|
key := GenerateKey()
|
||||||
c, err := New([]byte(key))
|
c, err := NewCipher([]byte(key))
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatalf("unexpected err: %v", err)
|
t.Fatalf("unexpected err: %v", err)
|
||||||
}
|
}
|
||||||
|
@ -42,7 +42,7 @@ func TestEncodeAndDecodeAccessToken(t *testing.T) {
|
||||||
func TestMarshalAndUnmarshalStruct(t *testing.T) {
|
func TestMarshalAndUnmarshalStruct(t *testing.T) {
|
||||||
key := GenerateKey()
|
key := GenerateKey()
|
||||||
|
|
||||||
c, err := New([]byte(key))
|
c, err := NewCipher([]byte(key))
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatalf("unexpected err: %v", err)
|
t.Fatalf("unexpected err: %v", err)
|
||||||
}
|
}
|
||||||
|
@ -95,7 +95,7 @@ func TestMarshalAndUnmarshalStruct(t *testing.T) {
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestCipherDataRace(t *testing.T) {
|
func TestCipherDataRace(t *testing.T) {
|
||||||
miscreantCipher, err := New(GenerateKey())
|
cipher, err := NewCipher(GenerateKey())
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatalf("unexpected generating cipher err: %v", err)
|
t.Fatalf("unexpected generating cipher err: %v", err)
|
||||||
}
|
}
|
||||||
|
@ -158,7 +158,7 @@ func TestCipherDataRace(t *testing.T) {
|
||||||
t.Fatalf("expected structs to be equal")
|
t.Fatalf("expected structs to be equal")
|
||||||
}
|
}
|
||||||
|
|
||||||
}(miscreantCipher, wg)
|
}(cipher, wg)
|
||||||
}
|
}
|
||||||
wg.Wait()
|
wg.Wait()
|
||||||
}
|
}
|
|
@ -7,7 +7,7 @@ import (
|
||||||
"net/http"
|
"net/http"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
"github.com/pomerium/pomerium/internal/aead"
|
"github.com/pomerium/pomerium/internal/cryptutil"
|
||||||
)
|
)
|
||||||
|
|
||||||
// ErrInvalidSession is an error for invalid sessions.
|
// ErrInvalidSession is an error for invalid sessions.
|
||||||
|
@ -36,14 +36,14 @@ type CookieStore struct {
|
||||||
CookieSecure bool
|
CookieSecure bool
|
||||||
CookieHTTPOnly bool
|
CookieHTTPOnly bool
|
||||||
CookieDomain string
|
CookieDomain string
|
||||||
CookieCipher aead.Cipher
|
CookieCipher cryptutil.Cipher
|
||||||
SessionLifetimeTTL time.Duration
|
SessionLifetimeTTL time.Duration
|
||||||
}
|
}
|
||||||
|
|
||||||
// CreateMiscreantCookieCipher creates a new miscreant cipher with the cookie secret
|
// CreateMiscreantCookieCipher creates a new miscreant cipher with the cookie secret
|
||||||
func CreateMiscreantCookieCipher(cookieSecret []byte) func(s *CookieStore) error {
|
func CreateMiscreantCookieCipher(cookieSecret []byte) func(s *CookieStore) error {
|
||||||
return func(s *CookieStore) error {
|
return func(s *CookieStore) error {
|
||||||
cipher, err := aead.New(cookieSecret)
|
cipher, err := cryptutil.NewCipher(cookieSecret)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return fmt.Errorf("miscreant cookie-secret error: %s", err.Error())
|
return fmt.Errorf("miscreant cookie-secret error: %s", err.Error())
|
||||||
}
|
}
|
||||||
|
|
|
@ -4,7 +4,7 @@ import (
|
||||||
"errors"
|
"errors"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
"github.com/pomerium/pomerium/internal/aead"
|
"github.com/pomerium/pomerium/internal/cryptutil"
|
||||||
)
|
)
|
||||||
|
|
||||||
var (
|
var (
|
||||||
|
@ -48,13 +48,13 @@ func isExpired(t time.Time) bool {
|
||||||
|
|
||||||
// MarshalSession marshals the session state as JSON, encrypts the JSON using the
|
// MarshalSession marshals the session state as JSON, encrypts the JSON using the
|
||||||
// given cipher, and base64-encodes the result
|
// given cipher, and base64-encodes the result
|
||||||
func MarshalSession(s *SessionState, c aead.Cipher) (string, error) {
|
func MarshalSession(s *SessionState, c cryptutil.Cipher) (string, error) {
|
||||||
return c.Marshal(s)
|
return c.Marshal(s)
|
||||||
}
|
}
|
||||||
|
|
||||||
// UnmarshalSession takes the marshaled string, base64-decodes into a byte slice, decrypts the
|
// UnmarshalSession takes the marshaled string, base64-decodes into a byte slice, decrypts the
|
||||||
// byte slice using the passed cipher, and unmarshals the resulting JSON into a session state struct
|
// byte slice using the passed cipher, and unmarshals the resulting JSON into a session state struct
|
||||||
func UnmarshalSession(value string, c aead.Cipher) (*SessionState, error) {
|
func UnmarshalSession(value string, c cryptutil.Cipher) (*SessionState, error) {
|
||||||
s := &SessionState{}
|
s := &SessionState{}
|
||||||
err := c.Unmarshal(value, s)
|
err := c.Unmarshal(value, s)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
|
|
|
@ -5,12 +5,12 @@ import (
|
||||||
"testing"
|
"testing"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
"github.com/pomerium/pomerium/internal/aead"
|
"github.com/pomerium/pomerium/internal/cryptutil"
|
||||||
)
|
)
|
||||||
|
|
||||||
func TestSessionStateSerialization(t *testing.T) {
|
func TestSessionStateSerialization(t *testing.T) {
|
||||||
secret := aead.GenerateKey()
|
secret := cryptutil.GenerateKey()
|
||||||
c, err := aead.New([]byte(secret))
|
c, err := cryptutil.NewCipher([]byte(secret))
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatalf("expected to be able to create cipher: %v", err)
|
t.Fatalf("expected to be able to create cipher: %v", err)
|
||||||
}
|
}
|
||||||
|
|
|
@ -61,7 +61,6 @@ func NewAuthenticateClient(uri *url.URL, sharedKey string, sessionValid, session
|
||||||
return &AuthenticateClient{
|
return &AuthenticateClient{
|
||||||
AuthenticateServiceURL: uri,
|
AuthenticateServiceURL: uri,
|
||||||
|
|
||||||
// ClientID: clientID,
|
|
||||||
SharedKey: sharedKey,
|
SharedKey: sharedKey,
|
||||||
|
|
||||||
SignInURL: uri.ResolveReference(&url.URL{Path: "/sign_in"}),
|
SignInURL: uri.ResolveReference(&url.URL{Path: "/sign_in"}),
|
||||||
|
@ -258,7 +257,6 @@ func (p *AuthenticateClient) ValidateSessionState(s *sessions.SessionState) bool
|
||||||
// authentication, and is merely unavailable, we validate and continue
|
// authentication, and is merely unavailable, we validate and continue
|
||||||
// as normal during the "grace period"
|
// as normal during the "grace period"
|
||||||
if isProviderUnavailable(resp.StatusCode) && p.withinGracePeriod(s) {
|
if isProviderUnavailable(resp.StatusCode) && p.withinGracePeriod(s) {
|
||||||
//tags := []string{"action:validate_session", "error:validation_failed"}
|
|
||||||
s.ValidDeadline = extendDeadline(p.SessionValidTTL)
|
s.ValidDeadline = extendDeadline(p.SessionValidTTL)
|
||||||
return true
|
return true
|
||||||
}
|
}
|
||||||
|
|
|
@ -8,7 +8,7 @@ import (
|
||||||
"net/url"
|
"net/url"
|
||||||
"reflect"
|
"reflect"
|
||||||
|
|
||||||
"github.com/pomerium/pomerium/internal/aead"
|
"github.com/pomerium/pomerium/internal/cryptutil"
|
||||||
"github.com/pomerium/pomerium/internal/httputil"
|
"github.com/pomerium/pomerium/internal/httputil"
|
||||||
"github.com/pomerium/pomerium/internal/log"
|
"github.com/pomerium/pomerium/internal/log"
|
||||||
"github.com/pomerium/pomerium/internal/middleware"
|
"github.com/pomerium/pomerium/internal/middleware"
|
||||||
|
@ -167,7 +167,7 @@ func (p *Proxy) OAuthStart(rw http.ResponseWriter, req *http.Request) {
|
||||||
callbackURL := p.GetRedirectURL(req.Host)
|
callbackURL := p.GetRedirectURL(req.Host)
|
||||||
|
|
||||||
// generate nonce
|
// generate nonce
|
||||||
key := aead.GenerateKey()
|
key := cryptutil.GenerateKey()
|
||||||
|
|
||||||
// state prevents cross site forgery and maintain state across the client and server
|
// state prevents cross site forgery and maintain state across the client and server
|
||||||
state := &StateParameter{
|
state := &StateParameter{
|
||||||
|
|
|
@ -13,7 +13,7 @@ import (
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
"github.com/pomerium/envconfig"
|
"github.com/pomerium/envconfig"
|
||||||
"github.com/pomerium/pomerium/internal/aead"
|
"github.com/pomerium/pomerium/internal/cryptutil"
|
||||||
"github.com/pomerium/pomerium/internal/log"
|
"github.com/pomerium/pomerium/internal/log"
|
||||||
"github.com/pomerium/pomerium/internal/sessions"
|
"github.com/pomerium/pomerium/internal/sessions"
|
||||||
"github.com/pomerium/pomerium/internal/templates"
|
"github.com/pomerium/pomerium/internal/templates"
|
||||||
|
@ -117,7 +117,7 @@ type Proxy struct {
|
||||||
// services
|
// services
|
||||||
authenticateClient *authenticator.AuthenticateClient
|
authenticateClient *authenticator.AuthenticateClient
|
||||||
// session
|
// session
|
||||||
cipher aead.Cipher
|
cipher cryptutil.Cipher
|
||||||
csrfStore sessions.CSRFStore
|
csrfStore sessions.CSRFStore
|
||||||
sessionStore sessions.SessionStore
|
sessionStore sessions.SessionStore
|
||||||
|
|
||||||
|
@ -144,7 +144,7 @@ func NewProxy(opts *Options) (*Proxy, error) {
|
||||||
|
|
||||||
// error explicitly handled by validate
|
// error explicitly handled by validate
|
||||||
decodedSecret, _ := base64.StdEncoding.DecodeString(opts.CookieSecret)
|
decodedSecret, _ := base64.StdEncoding.DecodeString(opts.CookieSecret)
|
||||||
cipher, err := aead.New(decodedSecret)
|
cipher, err := cryptutil.NewCipher(decodedSecret)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, fmt.Errorf("cookie-secret error: %s", err.Error())
|
return nil, fmt.Errorf("cookie-secret error: %s", err.Error())
|
||||||
}
|
}
|
||||||
|
|
Loading…
Add table
Reference in a new issue