pomerium/internal/zero/token/cache_test.go
Denis Mishin ce12e51cf5
zero/api: reset token and url cache if 401 is received (#5256)
zero/api: reset token cache if 401 is received
2024-09-03 15:40:28 -04:00

79 lines
1.9 KiB
Go

package token_test
import (
"context"
"fmt"
"testing"
"time"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"github.com/pomerium/pomerium/internal/zero/token"
)
// TestCache exercises token.Cache: a cached bearer token is returned while it
// is still considered valid for the requested duration, a fresh token is
// fetched once it is not, and Reset discards the cached token so the next
// GetToken call fetches again.
func TestCache(t *testing.T) {
	t.Parallel()

	t.Run("token expired, fetch new", func(t *testing.T) {
		t.Parallel()

		// testToken is what the fake fetcher hands out; the test swaps or
		// mutates it between GetToken calls to observe caching.
		var testToken *token.Token
		fetcher := func(_ context.Context, _ string) (*token.Token, error) {
			if testToken == nil {
				// Never return (nil, nil): with nothing to hand out, report
				// an explicit error like a real fetcher would.
				return nil, fmt.Errorf("no test token configured")
			}
			// Return a copy so later mutations of testToken cannot reach the
			// token already held by the cache. Named tok to avoid shadowing
			// the imported token package.
			tok := *testToken
			return &tok, nil
		}

		c := token.NewCache(fetcher, "test-refresh-token")
		now := time.Now()
		// Drive the cache's clock manually so expiry is deterministic.
		c.TimeNow = func() time.Time { return now }

		// NOTE(review): go vet's composites check flags unkeyed token.Token
		// literals; switch to keyed fields once the expiry field name is
		// confirmed against token.go.
		testToken = &token.Token{"bearer-1", now.Add(time.Hour)}
		bearer, err := c.GetToken(context.Background(), time.Minute)
		require.NoError(t, err)
		assert.Equal(t, "bearer-1", bearer)

		// 30 minutes in, the cached token has 30 minutes left. The fetcher
		// would now return "bearer-2", but the token is still valid, so the
		// cached one must be served.
		now = now.Add(time.Minute * 30)
		testToken.Bearer = "bearer-2"
		bearer, err = c.GetToken(context.Background(), time.Minute*20)
		require.NoError(t, err)
		assert.Equal(t, "bearer-1", bearer)

		// 60 minutes in, the cached token has reached its expiry, so a new
		// one must be fetched.
		now = now.Add(time.Minute * 30)
		testToken = &token.Token{"bearer-3", now.Add(time.Hour)}
		bearer, err = c.GetToken(context.Background(), time.Minute*30)
		require.NoError(t, err)
		assert.Equal(t, "bearer-3", bearer)
	})

	t.Run("reset", func(t *testing.T) {
		t.Parallel()

		// calls counts fetches; each fetch yields a distinct bearer so the
		// test can tell a cached token from a freshly fetched one.
		var calls int
		fetcher := func(_ context.Context, _ string) (*token.Token, error) {
			calls++
			return &token.Token{fmt.Sprintf("bearer-%d", calls), time.Now().Add(time.Hour)}, nil
		}

		ctx := context.Background()
		c := token.NewCache(fetcher, "test-refresh-token")

		got, err := c.GetToken(ctx, time.Minute*2)
		require.NoError(t, err)
		assert.Equal(t, "bearer-1", got)

		// A second call is served from the cache, so no new fetch happens.
		got, err = c.GetToken(ctx, time.Minute*2)
		require.NoError(t, err)
		assert.Equal(t, "bearer-1", got)

		// After Reset the cached token is gone and the next call fetches.
		c.Reset()
		got, err = c.GetToken(ctx, time.Minute*2)
		require.NoError(t, err)
		assert.Equal(t, "bearer-2", got)
	})
}