package token_test import ( "context" "fmt" "testing" "time" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" "github.com/pomerium/pomerium/internal/zero/token" ) func TestCache(t *testing.T) { t.Parallel() t.Run("token expired, fetch new", func(t *testing.T) { t.Parallel() var testToken *token.Token var testError error fetcher := func(_ context.Context, _ string) (*token.Token, error) { if testToken != nil { token := *testToken return &token, nil } return nil, testError } c := token.NewCache(fetcher, "test-refresh-token") now := time.Now() c.TimeNow = func() time.Time { return now } testToken = &token.Token{"bearer-1", now.Add(time.Hour)} bearer, err := c.GetToken(context.Background(), time.Minute) require.NoError(t, err) assert.Equal(t, "bearer-1", bearer) now = now.Add(time.Minute * 30) testToken.Bearer = "bearer-2" // token is still valid, so we should get the same one bearer, err = c.GetToken(context.Background(), time.Minute*20) require.NoError(t, err) assert.Equal(t, "bearer-1", bearer) now = now.Add(time.Minute * 30) testToken = &token.Token{"bearer-3", now.Add(time.Hour)} bearer, err = c.GetToken(context.Background(), time.Minute*30) require.NoError(t, err) assert.Equal(t, "bearer-3", bearer) }) t.Run("reset", func(t *testing.T) { t.Parallel() var calls int fetcher := func(_ context.Context, _ string) (*token.Token, error) { calls++ return &token.Token{fmt.Sprintf("bearer-%d", calls), time.Now().Add(time.Hour)}, nil } ctx := context.Background() c := token.NewCache(fetcher, "test-refresh-token") got, err := c.GetToken(ctx, time.Minute*2) assert.NoError(t, err) assert.Equal(t, "bearer-1", got) got, err = c.GetToken(ctx, time.Minute*2) assert.NoError(t, err) assert.Equal(t, "bearer-1", got) c.Reset() got, err = c.GetToken(ctx, time.Minute*2) assert.NoError(t, err) assert.Equal(t, "bearer-2", got) }) }