mirror of
https://github.com/pomerium/pomerium.git
synced 2025-08-04 01:09:36 +02:00
cryptutil: use bytes for hmac (#2067)
This commit is contained in:
parent
a935c1ba30
commit
a51c7140ea
12 changed files with 28 additions and 28 deletions
|
@ -55,7 +55,7 @@ func (h *Handler) GetPolicyIDFromHeaders(headers http.Header) (uint64, bool) {
|
|||
h.mu.RLock()
|
||||
defer h.mu.RUnlock()
|
||||
|
||||
return policyID, cryptutil.CheckHMAC([]byte(policyStr), hmac, string(h.key))
|
||||
return policyID, cryptutil.CheckHMAC([]byte(policyStr), hmac, h.key)
|
||||
}
|
||||
|
||||
// GetPolicyIDHeaders returns http headers for the given policy id.
|
||||
|
@ -64,7 +64,7 @@ func (h *Handler) GetPolicyIDHeaders(policyID uint64) [][2]string {
|
|||
defer h.mu.RUnlock()
|
||||
|
||||
s := strconv.FormatUint(policyID, 10)
|
||||
hmac := base64.StdEncoding.EncodeToString(cryptutil.GenerateHMAC([]byte(s), string(h.key)))
|
||||
hmac := base64.StdEncoding.EncodeToString(cryptutil.GenerateHMAC([]byte(s), h.key))
|
||||
return [][2]string{
|
||||
{httputil.HeaderPomeriumReproxyPolicy, s},
|
||||
{httputil.HeaderPomeriumReproxyPolicyHMAC, hmac},
|
||||
|
|
|
@ -28,7 +28,7 @@ func SetHeaders(headers map[string]string) func(next http.Handler) http.Handler
|
|||
|
||||
// ValidateSignature ensures the request is valid and has been signed with
|
||||
// the correspdoning client secret key
|
||||
func ValidateSignature(sharedSecret string) func(next http.Handler) http.Handler {
|
||||
func ValidateSignature(sharedSecret []byte) func(next http.Handler) http.Handler {
|
||||
return func(next http.Handler) http.Handler {
|
||||
return httputil.HandlerFunc(func(w http.ResponseWriter, r *http.Request) error {
|
||||
ctx, span := trace.StartSpan(r.Context(), "middleware.ValidateSignature")
|
||||
|
@ -44,7 +44,7 @@ func ValidateSignature(sharedSecret string) func(next http.Handler) http.Handler
|
|||
|
||||
// ValidateRequestURL validates the current absolute request URL was signed
|
||||
// by a given shared key.
|
||||
func ValidateRequestURL(r *http.Request, key string) error {
|
||||
func ValidateRequestURL(r *http.Request, key []byte) error {
|
||||
return urlutil.NewSignedURL(key, urlutil.GetAbsoluteURL(r)).Validate()
|
||||
}
|
||||
|
||||
|
|
|
@ -125,13 +125,13 @@ func TestValidateSignature(t *testing.T) {
|
|||
|
||||
tests := []struct {
|
||||
name string
|
||||
secretA string
|
||||
secretB string
|
||||
secretA []byte
|
||||
secretB []byte
|
||||
wantStatus int
|
||||
wantBody string
|
||||
}{
|
||||
{"good", "secret", "secret", http.StatusOK, http.StatusText(http.StatusOK)},
|
||||
{"secret mistmatch", "secret", "hunter42", http.StatusBadRequest, "{\"Status\":400,\"Error\":\"Bad Request: internal/urlutil: hmac failed\"}\n"},
|
||||
{"good", []byte("secret"), []byte("secret"), http.StatusOK, http.StatusText(http.StatusOK)},
|
||||
{"secret mistmatch", []byte("secret"), []byte("hunter42"), http.StatusBadRequest, "{\"Status\":400,\"Error\":\"Bad Request: internal/urlutil: hmac failed\"}\n"},
|
||||
}
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
|
|
|
@ -13,7 +13,7 @@ import (
|
|||
// SignedURL is a shared-key HMAC wrapped URL.
|
||||
type SignedURL struct {
|
||||
uri url.URL
|
||||
key string
|
||||
key []byte
|
||||
signed bool
|
||||
|
||||
// mockable time for testing
|
||||
|
@ -24,7 +24,7 @@ type SignedURL struct {
|
|||
//
|
||||
// N.B. It is the user's responsibility to make sure the key is 256 bits and
|
||||
// the url is not nil.
|
||||
func NewSignedURL(key string, uri *url.URL) *SignedURL {
|
||||
func NewSignedURL(key []byte, uri *url.URL) *SignedURL {
|
||||
return &SignedURL{uri: *uri, key: key, timeNow: time.Now} // uri is copied
|
||||
}
|
||||
|
||||
|
@ -93,7 +93,7 @@ func (su *SignedURL) Validate() error {
|
|||
|
||||
// hmacURL takes a redirect url string and timestamp and returns the base64
|
||||
// encoded HMAC result.
|
||||
func hmacURL(key string, data ...interface{}) string {
|
||||
func hmacURL(key []byte, data ...interface{}) string {
|
||||
h := cryptutil.GenerateHMAC([]byte(fmt.Sprint(data...)), key)
|
||||
return base64.URLEncoding.EncodeToString(h)
|
||||
}
|
||||
|
|
|
@ -31,14 +31,14 @@ func TestSignedURL(t *testing.T) {
|
|||
}
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
signedURL := NewSignedURL(tt.key, &tt.uri)
|
||||
signedURL := NewSignedURL([]byte(tt.key), &tt.uri)
|
||||
signedURL.timeNow = tt.origTime
|
||||
|
||||
if diff := cmp.Diff(signedURL.String(), tt.wantStr); diff != "" {
|
||||
t.Errorf("signedURL() = %v", diff)
|
||||
}
|
||||
|
||||
signedURL = NewSignedURL(tt.key, &tt.uri)
|
||||
signedURL = NewSignedURL([]byte(tt.key), &tt.uri)
|
||||
signedURL.timeNow = tt.origTime
|
||||
got := signedURL.Sign()
|
||||
|
||||
|
@ -89,7 +89,7 @@ func TestSignedURL_Validate(t *testing.T) {
|
|||
}
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
out := NewSignedURL(tt.key, &tt.uri)
|
||||
out := NewSignedURL([]byte(tt.key), &tt.uri)
|
||||
out.timeNow = tt.timeNow
|
||||
|
||||
if err := out.Validate(); (err != nil) != tt.wantErr {
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue