mirror of
https://github.com/pomerium/pomerium.git
synced 2025-04-29 10:26:29 +02:00
73 lines
1.8 KiB
Go
73 lines
1.8 KiB
Go
package storage
|
|
|
|
import (
|
|
"context"
|
|
"fmt"
|
|
"testing"
|
|
"time"
|
|
|
|
"github.com/stretchr/testify/assert"
|
|
)
|
|
|
|
func TestLocalCache(t *testing.T) {
|
|
ctx, clearTimeout := context.WithTimeout(context.Background(), time.Second*10)
|
|
defer clearTimeout()
|
|
|
|
callCount := 0
|
|
update := func(ctx context.Context) ([]byte, error) {
|
|
callCount++
|
|
return []byte("v1"), nil
|
|
}
|
|
c := NewLocalCache()
|
|
v, err := c.GetOrUpdate(ctx, []byte("k1"), update)
|
|
assert.NoError(t, err)
|
|
assert.Equal(t, []byte("v1"), v)
|
|
assert.Equal(t, 1, callCount)
|
|
|
|
v, err = c.GetOrUpdate(ctx, []byte("k1"), update)
|
|
assert.NoError(t, err)
|
|
assert.Equal(t, []byte("v1"), v)
|
|
assert.Equal(t, 1, callCount)
|
|
|
|
c.Invalidate([]byte("k1"))
|
|
|
|
v, err = c.GetOrUpdate(ctx, []byte("k1"), update)
|
|
assert.NoError(t, err)
|
|
assert.Equal(t, []byte("v1"), v)
|
|
assert.Equal(t, 2, callCount)
|
|
}
|
|
|
|
func TestGlobalCache(t *testing.T) {
|
|
ctx, clearTimeout := context.WithTimeout(context.Background(), time.Second*10)
|
|
defer clearTimeout()
|
|
|
|
callCount := 0
|
|
update := func(ctx context.Context) ([]byte, error) {
|
|
callCount++
|
|
return []byte("v1"), nil
|
|
}
|
|
c := NewGlobalCache(time.Millisecond * 100)
|
|
v, err := c.GetOrUpdate(ctx, []byte("k1"), update)
|
|
assert.NoError(t, err)
|
|
assert.Equal(t, []byte("v1"), v)
|
|
assert.Equal(t, 1, callCount)
|
|
|
|
v, err = c.GetOrUpdate(ctx, []byte("k1"), update)
|
|
assert.NoError(t, err)
|
|
assert.Equal(t, []byte("v1"), v)
|
|
assert.Equal(t, 1, callCount)
|
|
|
|
c.Invalidate([]byte("k1"))
|
|
|
|
v, err = c.GetOrUpdate(ctx, []byte("k1"), update)
|
|
assert.NoError(t, err)
|
|
assert.Equal(t, []byte("v1"), v)
|
|
assert.Equal(t, 2, callCount)
|
|
|
|
assert.Eventually(t, func() bool {
|
|
_, err := c.GetOrUpdate(ctx, []byte("k1"), func(ctx context.Context) ([]byte, error) {
|
|
return nil, fmt.Errorf("ERROR")
|
|
})
|
|
return err != nil
|
|
}, time.Second, time.Millisecond*10, "should honor TTL")
|
|
}
|