internal/cryputil: combines aead and cryptutil packages.

- Refactored encrypt / decrypt methods to use aead's NonceSize() interface method.
- Add explicit GenerateKey function.
- Remove mutex on XChaCha20.
This commit is contained in:
Bobby DeSimone 2019-01-18 11:55:04 -08:00
parent 131810ccfe
commit 24b11b0428
No known key found for this signature in database
GPG key ID: AEE4CF12FE86D07E
11 changed files with 44 additions and 89 deletions

View file

@ -12,7 +12,7 @@ import (
"github.com/pomerium/envconfig" "github.com/pomerium/envconfig"
"github.com/pomerium/pomerium/authenticate/providers" "github.com/pomerium/pomerium/authenticate/providers"
"github.com/pomerium/pomerium/internal/aead" "github.com/pomerium/pomerium/internal/cryptutil"
"github.com/pomerium/pomerium/internal/sessions" "github.com/pomerium/pomerium/internal/sessions"
"github.com/pomerium/pomerium/internal/templates" "github.com/pomerium/pomerium/internal/templates"
) )
@ -132,7 +132,7 @@ type Authenticator struct {
// sesion related // sesion related
csrfStore sessions.CSRFStore csrfStore sessions.CSRFStore
sessionStore sessions.SessionStore sessionStore sessions.SessionStore
cipher aead.Cipher cipher cryptutil.Cipher
provider providers.Provider provider providers.Provider
} }
@ -149,7 +149,7 @@ func NewAuthenticator(opts *Options, optionFuncs ...func(*Authenticator) error)
if err != nil { if err != nil {
return nil, err return nil, err
} }
cipher, err := aead.New([]byte(decodedAuthCodeSecret)) cipher, err := cryptutil.NewCipher([]byte(decodedAuthCodeSecret))
if err != nil { if err != nil {
return nil, err return nil, err
} }

View file

@ -9,7 +9,7 @@ import (
"strings" "strings"
"time" "time"
"github.com/pomerium/pomerium/internal/aead" "github.com/pomerium/pomerium/internal/cryptutil"
"github.com/pomerium/pomerium/internal/httputil" "github.com/pomerium/pomerium/internal/httputil"
"github.com/pomerium/pomerium/internal/log" "github.com/pomerium/pomerium/internal/log"
m "github.com/pomerium/pomerium/internal/middleware" m "github.com/pomerium/pomerium/internal/middleware"
@ -339,7 +339,7 @@ func (p *Authenticator) SignOutPage(rw http.ResponseWriter, req *http.Request, m
// `redirectURI`, allowing the provider to redirect back to the sso proxy after authentication. // `redirectURI`, allowing the provider to redirect back to the sso proxy after authentication.
func (p *Authenticator) OAuthStart(rw http.ResponseWriter, req *http.Request) { func (p *Authenticator) OAuthStart(rw http.ResponseWriter, req *http.Request) {
nonce := fmt.Sprintf("%x", aead.GenerateKey()) nonce := fmt.Sprintf("%x", cryptutil.GenerateKey())
p.csrfStore.SetCSRF(rw, req, nonce) p.csrfStore.SetCSRF(rw, req, nonce)
authRedirectURL, err := url.Parse(req.URL.Query().Get("redirect_uri")) authRedirectURL, err := url.Parse(req.URL.Query().Get("redirect_uri"))

View file

@ -1,34 +0,0 @@
package aead // import "github.com/pomerium/pomerium/internal/aead"
import (
"encoding/json"
)
// MockCipher is a mock of the cipher interface
type MockCipher struct {
MarshalError error
MarshalString string
UnmarshalError error
UnmarshalBytes []byte
}
// Encrypt returns an empty byte array and nil
func (mc *MockCipher) Encrypt([]byte) ([]byte, error) {
return []byte{}, nil
}
// Decrypt returns an empty byte array and nil
func (mc *MockCipher) Decrypt([]byte) ([]byte, error) {
return []byte{}, nil
}
// Marshal returns the marshal string and marsha error
func (mc *MockCipher) Marshal(interface{}) (string, error) {
return mc.MarshalString, mc.MarshalError
}
// Unmarshal unmarshals the unmarshal bytes to be set in s and returns the unmarshal error
func (mc *MockCipher) Unmarshal(b string, s interface{}) error {
json.Unmarshal(mc.UnmarshalBytes, s)
return mc.UnmarshalError
}

View file

@ -1,4 +1,4 @@
package aead // import "github.com/pomerium/pomerium/internal/aead" package cryptutil // import "github.com/pomerium/pomerium/internal/cryptutil"
import ( import (
"bytes" "bytes"
@ -9,11 +9,20 @@ import (
"encoding/json" "encoding/json"
"fmt" "fmt"
"io" "io"
"sync"
"golang.org/x/crypto/chacha20poly1305" "golang.org/x/crypto/chacha20poly1305"
) )
// GenerateKey generates a random 32-byte key.
// Panics if source of randomness fails.
func GenerateKey() []byte {
key := make([]byte, 32)
if _, err := rand.Read(key); err != nil {
panic(err)
}
return key
}
// Cipher provides methods to encrypt and decrypt values. // Cipher provides methods to encrypt and decrypt values.
type Cipher interface { type Cipher interface {
Encrypt([]byte) ([]byte, error) Encrypt([]byte) ([]byte, error)
@ -27,12 +36,10 @@ type Cipher interface {
// For a description of the methodology, see https://en.wikipedia.org/wiki/Authenticated_encryption // For a description of the methodology, see https://en.wikipedia.org/wiki/Authenticated_encryption
type XChaCha20Cipher struct { type XChaCha20Cipher struct {
aead cipher.AEAD aead cipher.AEAD
mu sync.Mutex
} }
// New returns a new AES Cipher for encrypting values // NewCipher returns a new XChacha20poly1305 cipher.
func New(secret []byte) (*XChaCha20Cipher, error) { func NewCipher(secret []byte) (*XChaCha20Cipher, error) {
aead, err := chacha20poly1305.NewX(secret) aead, err := chacha20poly1305.NewX(secret)
if err != nil { if err != nil {
return nil, err return nil, err
@ -42,20 +49,10 @@ func New(secret []byte) (*XChaCha20Cipher, error) {
}, nil }, nil
} }
// GenerateKey generates a random 32-byte encryption key. // GenerateNonce generates a random nonce.
// Panics if the key size is unsupported or source of randomness fails. // Panics if source of randomness fails.
func GenerateKey() []byte { func (c *XChaCha20Cipher) GenerateNonce() []byte {
nonce := make([]byte, chacha20poly1305.KeySize) nonce := make([]byte, c.aead.NonceSize())
if _, err := rand.Read(nonce); err != nil {
panic(err)
}
return nonce
}
// GenerateNonce generates a random 24-byte nonce for XChaCha20-Poly1305.
// Panics if the key size is unsupported or source of randomness fails.
func GenerateNonce() []byte {
nonce := make([]byte, chacha20poly1305.NonceSizeX)
if _, err := rand.Read(nonce); err != nil { if _, err := rand.Read(nonce); err != nil {
panic(err) panic(err)
} }
@ -64,15 +61,12 @@ func GenerateNonce() []byte {
// Encrypt a value using XChaCha20-Poly1305 // Encrypt a value using XChaCha20-Poly1305
func (c *XChaCha20Cipher) Encrypt(plaintext []byte) (joined []byte, err error) { func (c *XChaCha20Cipher) Encrypt(plaintext []byte) (joined []byte, err error) {
c.mu.Lock()
defer c.mu.Unlock()
defer func() { defer func() {
if r := recover(); r != nil { if r := recover(); r != nil {
err = fmt.Errorf("internal/aead: error encrypting bytes: %v", r) err = fmt.Errorf("internal/aead: error encrypting bytes: %v", r)
} }
}() }()
nonce := GenerateNonce() nonce := c.GenerateNonce()
ciphertext := c.aead.Seal(nil, nonce, plaintext, nil) ciphertext := c.aead.Seal(nil, nonce, plaintext, nil)
@ -83,14 +77,11 @@ func (c *XChaCha20Cipher) Encrypt(plaintext []byte) (joined []byte, err error) {
// Decrypt a value using XChaCha20-Poly1305 // Decrypt a value using XChaCha20-Poly1305
func (c *XChaCha20Cipher) Decrypt(joined []byte) ([]byte, error) { func (c *XChaCha20Cipher) Decrypt(joined []byte) ([]byte, error) {
c.mu.Lock() if len(joined) <= c.aead.NonceSize() {
defer c.mu.Unlock()
if len(joined) <= chacha20poly1305.NonceSizeX {
return nil, fmt.Errorf("internal/aead: invalid input size: %d", len(joined)) return nil, fmt.Errorf("internal/aead: invalid input size: %d", len(joined))
} }
// grab out the nonce // grab out the nonce
pivot := len(joined) - chacha20poly1305.NonceSizeX pivot := len(joined) - c.aead.NonceSize()
ciphertext := joined[:pivot] ciphertext := joined[:pivot]
nonce := joined[pivot:] nonce := joined[pivot:]

View file

@ -1,4 +1,4 @@
package aead // import "github.com/pomerium/pomerium/internal/aead" package cryptutil // import "github.com/pomerium/pomerium/internal/cryptutil"
import ( import (
"crypto/rand" "crypto/rand"
@ -13,7 +13,7 @@ func TestEncodeAndDecodeAccessToken(t *testing.T) {
plaintext := []byte("my plain text value") plaintext := []byte("my plain text value")
key := GenerateKey() key := GenerateKey()
c, err := New([]byte(key)) c, err := NewCipher([]byte(key))
if err != nil { if err != nil {
t.Fatalf("unexpected err: %v", err) t.Fatalf("unexpected err: %v", err)
} }
@ -42,7 +42,7 @@ func TestEncodeAndDecodeAccessToken(t *testing.T) {
func TestMarshalAndUnmarshalStruct(t *testing.T) { func TestMarshalAndUnmarshalStruct(t *testing.T) {
key := GenerateKey() key := GenerateKey()
c, err := New([]byte(key)) c, err := NewCipher([]byte(key))
if err != nil { if err != nil {
t.Fatalf("unexpected err: %v", err) t.Fatalf("unexpected err: %v", err)
} }
@ -95,7 +95,7 @@ func TestMarshalAndUnmarshalStruct(t *testing.T) {
} }
func TestCipherDataRace(t *testing.T) { func TestCipherDataRace(t *testing.T) {
miscreantCipher, err := New(GenerateKey()) cipher, err := NewCipher(GenerateKey())
if err != nil { if err != nil {
t.Fatalf("unexpected generating cipher err: %v", err) t.Fatalf("unexpected generating cipher err: %v", err)
} }
@ -158,7 +158,7 @@ func TestCipherDataRace(t *testing.T) {
t.Fatalf("expected structs to be equal") t.Fatalf("expected structs to be equal")
} }
}(miscreantCipher, wg) }(cipher, wg)
} }
wg.Wait() wg.Wait()
} }

View file

@ -7,7 +7,7 @@ import (
"net/http" "net/http"
"time" "time"
"github.com/pomerium/pomerium/internal/aead" "github.com/pomerium/pomerium/internal/cryptutil"
) )
// ErrInvalidSession is an error for invalid sessions. // ErrInvalidSession is an error for invalid sessions.
@ -36,14 +36,14 @@ type CookieStore struct {
CookieSecure bool CookieSecure bool
CookieHTTPOnly bool CookieHTTPOnly bool
CookieDomain string CookieDomain string
CookieCipher aead.Cipher CookieCipher cryptutil.Cipher
SessionLifetimeTTL time.Duration SessionLifetimeTTL time.Duration
} }
// CreateMiscreantCookieCipher creates a new miscreant cipher with the cookie secret // CreateMiscreantCookieCipher creates a new miscreant cipher with the cookie secret
func CreateMiscreantCookieCipher(cookieSecret []byte) func(s *CookieStore) error { func CreateMiscreantCookieCipher(cookieSecret []byte) func(s *CookieStore) error {
return func(s *CookieStore) error { return func(s *CookieStore) error {
cipher, err := aead.New(cookieSecret) cipher, err := cryptutil.NewCipher(cookieSecret)
if err != nil { if err != nil {
return fmt.Errorf("miscreant cookie-secret error: %s", err.Error()) return fmt.Errorf("miscreant cookie-secret error: %s", err.Error())
} }

View file

@ -4,7 +4,7 @@ import (
"errors" "errors"
"time" "time"
"github.com/pomerium/pomerium/internal/aead" "github.com/pomerium/pomerium/internal/cryptutil"
) )
var ( var (
@ -48,13 +48,13 @@ func isExpired(t time.Time) bool {
// MarshalSession marshals the session state as JSON, encrypts the JSON using the // MarshalSession marshals the session state as JSON, encrypts the JSON using the
// given cipher, and base64-encodes the result // given cipher, and base64-encodes the result
func MarshalSession(s *SessionState, c aead.Cipher) (string, error) { func MarshalSession(s *SessionState, c cryptutil.Cipher) (string, error) {
return c.Marshal(s) return c.Marshal(s)
} }
// UnmarshalSession takes the marshaled string, base64-decodes into a byte slice, decrypts the // UnmarshalSession takes the marshaled string, base64-decodes into a byte slice, decrypts the
// byte slice using the passed cipher, and unmarshals the resulting JSON into a session state struct // byte slice using the passed cipher, and unmarshals the resulting JSON into a session state struct
func UnmarshalSession(value string, c aead.Cipher) (*SessionState, error) { func UnmarshalSession(value string, c cryptutil.Cipher) (*SessionState, error) {
s := &SessionState{} s := &SessionState{}
err := c.Unmarshal(value, s) err := c.Unmarshal(value, s)
if err != nil { if err != nil {

View file

@ -5,12 +5,12 @@ import (
"testing" "testing"
"time" "time"
"github.com/pomerium/pomerium/internal/aead" "github.com/pomerium/pomerium/internal/cryptutil"
) )
func TestSessionStateSerialization(t *testing.T) { func TestSessionStateSerialization(t *testing.T) {
secret := aead.GenerateKey() secret := cryptutil.GenerateKey()
c, err := aead.New([]byte(secret)) c, err := cryptutil.NewCipher([]byte(secret))
if err != nil { if err != nil {
t.Fatalf("expected to be able to create cipher: %v", err) t.Fatalf("expected to be able to create cipher: %v", err)
} }

View file

@ -61,7 +61,6 @@ func NewAuthenticateClient(uri *url.URL, sharedKey string, sessionValid, session
return &AuthenticateClient{ return &AuthenticateClient{
AuthenticateServiceURL: uri, AuthenticateServiceURL: uri,
// ClientID: clientID,
SharedKey: sharedKey, SharedKey: sharedKey,
SignInURL: uri.ResolveReference(&url.URL{Path: "/sign_in"}), SignInURL: uri.ResolveReference(&url.URL{Path: "/sign_in"}),
@ -258,7 +257,6 @@ func (p *AuthenticateClient) ValidateSessionState(s *sessions.SessionState) bool
// authentication, and is merely unavailable, we validate and continue // authentication, and is merely unavailable, we validate and continue
// as normal during the "grace period" // as normal during the "grace period"
if isProviderUnavailable(resp.StatusCode) && p.withinGracePeriod(s) { if isProviderUnavailable(resp.StatusCode) && p.withinGracePeriod(s) {
//tags := []string{"action:validate_session", "error:validation_failed"}
s.ValidDeadline = extendDeadline(p.SessionValidTTL) s.ValidDeadline = extendDeadline(p.SessionValidTTL)
return true return true
} }

View file

@ -8,7 +8,7 @@ import (
"net/url" "net/url"
"reflect" "reflect"
"github.com/pomerium/pomerium/internal/aead" "github.com/pomerium/pomerium/internal/cryptutil"
"github.com/pomerium/pomerium/internal/httputil" "github.com/pomerium/pomerium/internal/httputil"
"github.com/pomerium/pomerium/internal/log" "github.com/pomerium/pomerium/internal/log"
"github.com/pomerium/pomerium/internal/middleware" "github.com/pomerium/pomerium/internal/middleware"
@ -167,7 +167,7 @@ func (p *Proxy) OAuthStart(rw http.ResponseWriter, req *http.Request) {
callbackURL := p.GetRedirectURL(req.Host) callbackURL := p.GetRedirectURL(req.Host)
// generate nonce // generate nonce
key := aead.GenerateKey() key := cryptutil.GenerateKey()
// state prevents cross site forgery and maintain state across the client and server // state prevents cross site forgery and maintain state across the client and server
state := &StateParameter{ state := &StateParameter{

View file

@ -13,7 +13,7 @@ import (
"time" "time"
"github.com/pomerium/envconfig" "github.com/pomerium/envconfig"
"github.com/pomerium/pomerium/internal/aead" "github.com/pomerium/pomerium/internal/cryptutil"
"github.com/pomerium/pomerium/internal/log" "github.com/pomerium/pomerium/internal/log"
"github.com/pomerium/pomerium/internal/sessions" "github.com/pomerium/pomerium/internal/sessions"
"github.com/pomerium/pomerium/internal/templates" "github.com/pomerium/pomerium/internal/templates"
@ -117,7 +117,7 @@ type Proxy struct {
// services // services
authenticateClient *authenticator.AuthenticateClient authenticateClient *authenticator.AuthenticateClient
// session // session
cipher aead.Cipher cipher cryptutil.Cipher
csrfStore sessions.CSRFStore csrfStore sessions.CSRFStore
sessionStore sessions.SessionStore sessionStore sessions.SessionStore
@ -144,7 +144,7 @@ func NewProxy(opts *Options) (*Proxy, error) {
// error explicitly handled by validate // error explicitly handled by validate
decodedSecret, _ := base64.StdEncoding.DecodeString(opts.CookieSecret) decodedSecret, _ := base64.StdEncoding.DecodeString(opts.CookieSecret)
cipher, err := aead.New(decodedSecret) cipher, err := cryptutil.NewCipher(decodedSecret)
if err != nil { if err != nil {
return nil, fmt.Errorf("cookie-secret error: %s", err.Error()) return nil, fmt.Errorf("cookie-secret error: %s", err.Error())
} }