mirror of
https://github.com/pomerium/pomerium.git
synced 2025-04-29 18:36:30 +02:00
cryptutil: add SecureToken (#2681)
* cryptutil: add SecureToken * add parse
This commit is contained in:
parent
4e4a161521
commit
0f0a5dc7f0
2 changed files with 147 additions and 0 deletions
|
@ -1,6 +1,12 @@
|
|||
package cryptutil
|
||||
|
||||
import (
|
||||
"crypto/hmac"
|
||||
"crypto/sha256"
|
||||
"encoding/binary"
|
||||
"errors"
|
||||
"time"
|
||||
|
||||
"github.com/btcsuite/btcutil/base58"
|
||||
"github.com/google/uuid"
|
||||
)
|
||||
|
@ -66,3 +72,96 @@ func (tok SecretToken) String() string {
|
|||
copy(bs[TokenLength:], tok.Secret[:])
|
||||
return base58.Encode(bs)
|
||||
}
|
||||
|
||||
// errors related to the SecureToken
|
||||
var (
|
||||
ErrExpired = errors.New("expired")
|
||||
ErrInvalid = errors.New("invalid")
|
||||
)
|
||||
|
||||
const (
|
||||
// SecureTokenTimeLength is the length of the time part of the SecureToken.
|
||||
SecureTokenTimeLength = 8
|
||||
// SecureTokenHMACLength is the length of the HMAC part of the SecureToken.
|
||||
SecureTokenHMACLength = 32
|
||||
// SecureTokenLength is the byte length of a SecureToken.
|
||||
SecureTokenLength = TokenLength + SecureTokenTimeLength + SecureTokenHMACLength
|
||||
)
|
||||
|
||||
// A SecureToken is an HMAC'd Token with an expiration time.
|
||||
type SecureToken [SecureTokenLength]byte
|
||||
|
||||
// GenerateSecureToken generates a SecureToken from the given key, expiry and token.
|
||||
func GenerateSecureToken(key []byte, expiry time.Time, token Token) SecureToken {
|
||||
var secureToken SecureToken
|
||||
copy(secureToken[:], token[:])
|
||||
binary.BigEndian.PutUint64(secureToken[TokenLength:], uint64(expiry.UnixMilli()))
|
||||
h := secureToken.computeHMAC(key)
|
||||
copy(secureToken[TokenLength+SecureTokenTimeLength:], h[:])
|
||||
return secureToken
|
||||
}
|
||||
|
||||
// SecureTokenFromString parses a base58-encoded string into a SecureToken.
|
||||
func SecureTokenFromString(rawstr string) (secureToken SecureToken, ok bool) {
|
||||
result := base58.Decode(rawstr)
|
||||
if len(result) != SecureTokenLength {
|
||||
return secureToken, false
|
||||
}
|
||||
copy(secureToken[:], result[:SecureTokenLength])
|
||||
return secureToken, true
|
||||
}
|
||||
|
||||
// Bytes returns the secret token as bytes.
|
||||
func (secureToken SecureToken) Bytes() []byte {
|
||||
return secureToken[:]
|
||||
}
|
||||
|
||||
// Expiry returns the SecureToken expiration time.
|
||||
func (secureToken SecureToken) Expiry() time.Time {
|
||||
return time.UnixMilli(int64(binary.BigEndian.Uint64(secureToken[TokenLength:])))
|
||||
}
|
||||
|
||||
// HMAC returns the HMAC part of the SecureToken.
|
||||
func (secureToken SecureToken) HMAC() [SecureTokenHMACLength]byte {
|
||||
var result [SecureTokenHMACLength]byte
|
||||
copy(result[:], secureToken[TokenLength+SecureTokenTimeLength:])
|
||||
return result
|
||||
}
|
||||
|
||||
// String returns the SecureToken as a string.
|
||||
func (secureToken SecureToken) String() string {
|
||||
return base58.Encode(secureToken[:])
|
||||
}
|
||||
|
||||
// Token returns the Token part of the SecureToken.
|
||||
func (secureToken SecureToken) Token() Token {
|
||||
var result Token
|
||||
copy(result[:], secureToken[:])
|
||||
return result
|
||||
}
|
||||
|
||||
// Verify verifies that the SecureToken has a valid HMAC and hasn't expired.
|
||||
func (secureToken SecureToken) Verify(key []byte, now time.Time) error {
|
||||
if !secureToken.checkHMAC(key) {
|
||||
return ErrInvalid
|
||||
}
|
||||
|
||||
if secureToken.Expiry().Before(now) {
|
||||
return ErrExpired
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (secureToken SecureToken) checkHMAC(key []byte) bool {
|
||||
expectedHMAC := secureToken.computeHMAC(key)
|
||||
actualHMAC := secureToken.HMAC()
|
||||
return hmac.Equal(actualHMAC[:], expectedHMAC[:])
|
||||
}
|
||||
|
||||
func (secureToken SecureToken) computeHMAC(key []byte) (result [SecureTokenHMACLength]byte) {
|
||||
h := hmac.New(sha256.New, key)
|
||||
h.Write(secureToken[:TokenLength+SecureTokenTimeLength])
|
||||
copy(result[:], h.Sum(nil))
|
||||
return result
|
||||
}
|
||||
|
|
|
@ -3,6 +3,7 @@ package cryptutil
|
|||
import (
|
||||
"strings"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/google/uuid"
|
||||
"github.com/stretchr/testify/assert"
|
||||
|
@ -62,3 +63,50 @@ func TestSecretTokenFromString(t *testing.T) {
|
|||
assert.Equal(t, "fb297629-e61f-4f1d-bb7e-ece3ed702098", tok.ID.UUID().String())
|
||||
assert.Equal(t, "047fb3ad-b1c7-463b-b16c-e41836811cc2", tok.Secret.UUID().String())
|
||||
}
|
||||
|
||||
func TestSecureToken(t *testing.T) {
|
||||
key := []byte{1, 2, 3, 4, 5}
|
||||
expiry := time.Date(2021, 10, 14, 12, 27, 0, 0, time.UTC)
|
||||
token := Token(uuid.MustParse("38ad02ee-5db4-4246-9d4c-44e4a0077408"))
|
||||
secureToken := GenerateSecureToken(key, expiry, token)
|
||||
assert.Equal(t, "2Y2GNugUpcunes9epx9ehkdHwJvejtnBzNJ5iniiRYv3rMoE7LMN3tZmf7ZGNidJKMSvTtCYEqtE5", secureToken.String())
|
||||
assert.Equal(t, []byte{
|
||||
0x38, 0xad, 0x02, 0xee, 0x5d, 0xb4, 0x42, 0x46,
|
||||
0x9d, 0x4c, 0x44, 0xe4, 0xa0, 0x07, 0x74, 0x08,
|
||||
0x00, 0x00, 0x01, 0x7c, 0x7e, 0xc5, 0x1e, 0x20,
|
||||
0x39, 0xc5, 0xca, 0x5a, 0x77, 0xc4, 0xbc, 0x65,
|
||||
0x56, 0x22, 0x0b, 0x17, 0x7a, 0xae, 0x97, 0x4c,
|
||||
0xa9, 0x6a, 0x99, 0x69, 0x9e, 0xce, 0x20, 0xbd,
|
||||
0xd6, 0xba, 0xb9, 0x3c, 0x16, 0x30, 0x6d, 0x12,
|
||||
}, secureToken.Bytes())
|
||||
assert.Equal(t, [SecureTokenHMACLength]byte{
|
||||
0x39, 0xc5, 0xca, 0x5a, 0x77, 0xc4, 0xbc, 0x65,
|
||||
0x56, 0x22, 0x0b, 0x17, 0x7a, 0xae, 0x97, 0x4c,
|
||||
0xa9, 0x6a, 0x99, 0x69, 0x9e, 0xce, 0x20, 0xbd,
|
||||
0xd6, 0xba, 0xb9, 0x3c, 0x16, 0x30, 0x6d, 0x12,
|
||||
}, secureToken.HMAC())
|
||||
assert.Equal(t, Token{
|
||||
0x38, 0xad, 0x02, 0xee, 0x5d, 0xb4, 0x42, 0x46,
|
||||
0x9d, 0x4c, 0x44, 0xe4, 0xa0, 0x07, 0x74, 0x08,
|
||||
}, secureToken.Token())
|
||||
assert.Equal(t, expiry, secureToken.Expiry().UTC())
|
||||
|
||||
t.Run("parse", func(t *testing.T) {
|
||||
parsed, ok := SecureTokenFromString("2Y2GNugUpcunes9epx9ehkdHwJvejtnBzNJ5iniiRYv3rMoE7LMN3tZmf7ZGNidJKMSvTtCYEqtE5")
|
||||
assert.True(t, ok)
|
||||
assert.Equal(t, secureToken, parsed)
|
||||
})
|
||||
|
||||
t.Run("valid", func(t *testing.T) {
|
||||
err := secureToken.Verify(key, expiry.Add(-time.Second))
|
||||
assert.NoError(t, err)
|
||||
})
|
||||
t.Run("invalid", func(t *testing.T) {
|
||||
err := secureToken.Verify([]byte{6, 7, 8, 9, 0}, expiry.Add(time.Second))
|
||||
assert.Error(t, err)
|
||||
})
|
||||
t.Run("expired", func(t *testing.T) {
|
||||
err := secureToken.Verify(key, expiry.Add(time.Second))
|
||||
assert.Error(t, err)
|
||||
})
|
||||
}
|
||||
|
|
Loading…
Add table
Reference in a new issue