mirror of
https://github.com/pomerium/pomerium.git
synced 2025-06-15 17:22:56 +02:00
urlutil: add time validation functions (#3776)
This commit is contained in:
parent
457fca08dc
commit
090601873f
3 changed files with 111 additions and 1 deletions
|
@ -8,12 +8,16 @@ const (
|
||||||
QueryDeviceCredentialID = "pomerium_device_credential_id"
|
QueryDeviceCredentialID = "pomerium_device_credential_id"
|
||||||
QueryDeviceType = "pomerium_device_type"
|
QueryDeviceType = "pomerium_device_type"
|
||||||
QueryEnrollmentToken = "pomerium_enrollment_token" //nolint
|
QueryEnrollmentToken = "pomerium_enrollment_token" //nolint
|
||||||
|
QueryExpiry = "pomerium_expiry"
|
||||||
|
QueryIdentityProfile = "pomerium_identity_profile"
|
||||||
QueryIdentityProviderID = "pomerium_idp_id"
|
QueryIdentityProviderID = "pomerium_idp_id"
|
||||||
QueryIsProgrammatic = "pomerium_programmatic"
|
QueryIsProgrammatic = "pomerium_programmatic"
|
||||||
|
QueryIssued = "pomerium_issued"
|
||||||
QueryPomeriumJWT = "pomerium_jwt"
|
QueryPomeriumJWT = "pomerium_jwt"
|
||||||
|
QueryRedirectURI = "pomerium_redirect_uri"
|
||||||
QuerySession = "pomerium_session"
|
QuerySession = "pomerium_session"
|
||||||
QuerySessionEncrypted = "pomerium_session_encrypted"
|
QuerySessionEncrypted = "pomerium_session_encrypted"
|
||||||
QueryRedirectURI = "pomerium_redirect_uri"
|
QuerySessionState = "pomerium_session_state"
|
||||||
)
|
)
|
||||||
|
|
||||||
// URL signature based query params used for verifying the authenticity of a URL.
|
// URL signature based query params used for verifying the authenticity of a URL.
|
||||||
|
|
43
internal/urlutil/time.go
Normal file
43
internal/urlutil/time.go
Normal file
|
@ -0,0 +1,43 @@
|
||||||
|
package urlutil
|
||||||
|
|
||||||
|
import (
|
||||||
|
"fmt"
|
||||||
|
"net/url"
|
||||||
|
"strconv"
|
||||||
|
"time"
|
||||||
|
)
|
||||||
|
|
||||||
|
// BuildTimeParameters adds the issued and expiry timestamps to the query parameters.
|
||||||
|
func BuildTimeParameters(params url.Values, expiry time.Duration) {
|
||||||
|
now := time.Now()
|
||||||
|
|
||||||
|
params.Set(QueryIssued, fmt.Sprint(now.UnixMilli()))
|
||||||
|
params.Set(QueryExpiry, fmt.Sprint(now.Add(expiry).UnixMilli()))
|
||||||
|
}
|
||||||
|
|
||||||
|
// ValidateTimeParameters validates that the issued and expiry timestamps in the query parameters are valid.
|
||||||
|
func ValidateTimeParameters(params url.Values) error {
|
||||||
|
now := time.Now()
|
||||||
|
|
||||||
|
issuedMS, err := strconv.ParseInt(params.Get(QueryIssued), 10, 64)
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("invalid issued timestamp: %w", err)
|
||||||
|
}
|
||||||
|
issued := time.UnixMilli(issuedMS)
|
||||||
|
|
||||||
|
if now.Add(DefaultLeeway).Before(issued) {
|
||||||
|
return ErrIssuedInTheFuture
|
||||||
|
}
|
||||||
|
|
||||||
|
expiryMS, err := strconv.ParseInt(params.Get(QueryExpiry), 10, 64)
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("invalid expiry timestamp: %w", err)
|
||||||
|
}
|
||||||
|
expiry := time.UnixMilli(expiryMS)
|
||||||
|
|
||||||
|
if now.Add(-DefaultLeeway).After(expiry) {
|
||||||
|
return ErrExpired
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
63
internal/urlutil/time_test.go
Normal file
63
internal/urlutil/time_test.go
Normal file
|
@ -0,0 +1,63 @@
|
||||||
|
package urlutil
|
||||||
|
|
||||||
|
import (
|
||||||
|
"fmt"
|
||||||
|
"net/url"
|
||||||
|
"strconv"
|
||||||
|
"testing"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"github.com/stretchr/testify/assert"
|
||||||
|
)
|
||||||
|
|
||||||
|
func TestBuildTimeParameters(t *testing.T) {
|
||||||
|
t.Parallel()
|
||||||
|
|
||||||
|
params := make(url.Values)
|
||||||
|
BuildTimeParameters(params, time.Minute)
|
||||||
|
assert.True(t, params.Has(QueryIssued))
|
||||||
|
assert.True(t, params.Has(QueryExpiry))
|
||||||
|
|
||||||
|
ms1, _ := strconv.Atoi(params.Get(QueryIssued))
|
||||||
|
ms2, _ := strconv.Atoi(params.Get(QueryExpiry))
|
||||||
|
assert.Equal(t, 60000, ms2-ms1)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestValidateTimeParameters(t *testing.T) {
|
||||||
|
t.Parallel()
|
||||||
|
|
||||||
|
msNow := time.Now().UnixMilli()
|
||||||
|
for _, tc := range []struct {
|
||||||
|
name string
|
||||||
|
params url.Values
|
||||||
|
err string
|
||||||
|
}{
|
||||||
|
{"empty", url.Values{}, "invalid issued timestamp"},
|
||||||
|
{"missing issued", url.Values{QueryExpiry: {fmt.Sprint(msNow + 10000)}}, "invalid issued timestamp"},
|
||||||
|
{"missing expiry", url.Values{QueryIssued: {fmt.Sprint(msNow + 10000)}}, "invalid expiry timestamp"},
|
||||||
|
{"invalid issued", url.Values{
|
||||||
|
QueryIssued: {fmt.Sprint(msNow + 120000)},
|
||||||
|
QueryExpiry: {fmt.Sprint(msNow + 240000)},
|
||||||
|
}, "issued in the future"},
|
||||||
|
{"invalid expiry", url.Values{
|
||||||
|
QueryIssued: {fmt.Sprint(msNow - 120000)},
|
||||||
|
QueryExpiry: {fmt.Sprint(msNow - 240000)},
|
||||||
|
}, "expired"},
|
||||||
|
{"valid", url.Values{
|
||||||
|
QueryIssued: {fmt.Sprint(msNow)},
|
||||||
|
QueryExpiry: {fmt.Sprint(msNow)},
|
||||||
|
}, ""},
|
||||||
|
} {
|
||||||
|
tc := tc
|
||||||
|
t.Run(tc.name, func(t *testing.T) {
|
||||||
|
t.Parallel()
|
||||||
|
|
||||||
|
err := ValidateTimeParameters(tc.params)
|
||||||
|
if tc.err == "" {
|
||||||
|
assert.NoError(t, err)
|
||||||
|
} else {
|
||||||
|
assert.ErrorContains(t, err, tc.err)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
Loading…
Add table
Add a link
Reference in a new issue