pomerium/pkg/storage/cache.go

155 lines
3.1 KiB
Go

package storage
import (
"context"
"encoding/binary"
"sync"
"time"
"github.com/VictoriaMetrics/fastcache"
"golang.org/x/sync/singleflight"
)
// A Cache will return cached data when available or call update when not.
//
// Keys and values are opaque byte slices.
type Cache interface {
// GetOrUpdate returns the data cached under key. When nothing usable is
// cached it calls update with ctx to produce the data and returns that
// result; an error from update is returned to the caller.
GetOrUpdate(
ctx context.Context,
key []byte,
update func(ctx context.Context) ([]byte, error),
) ([]byte, error)
// Invalidate removes any data cached under key.
Invalidate(key []byte)
}
// localCache is a Cache that keeps its entries in a map. The methods in
// this file apply no expiry: an entry is only removed by Invalidate.
type localCache struct {
// singleflight is used so concurrent misses for the same key share a
// single update call.
singleflight singleflight.Group
// mu guards m.
mu sync.RWMutex
// m holds the cached data, keyed by the string form of the key bytes.
m map[string][]byte
}
// NewLocalCache returns a Cache whose entries live in an in-memory map.
// The map starts out empty.
func NewLocalCache() Cache {
	c := new(localCache)
	c.m = map[string][]byte{}
	return c
}
// GetOrUpdate returns the data stored under key. On a miss it calls update
// (through singleflight, so concurrent misses for the same key share one
// call), stores the result in the map and returns it. An error from update
// is returned as-is and nothing is stored.
func (cache *localCache) GetOrUpdate(
	ctx context.Context,
	key []byte,
	update func(ctx context.Context) ([]byte, error),
) ([]byte, error) {
	k := string(key)

	// lookup reads the map while holding the read lock.
	lookup := func() ([]byte, bool) {
		cache.mu.RLock()
		defer cache.mu.RUnlock()
		b, found := cache.m[k]
		return b, found
	}

	// Fast path: already cached.
	if b, found := lookup(); found {
		return b, nil
	}

	res, err, _ := cache.singleflight.Do(k, func() (interface{}, error) {
		// Look again: the entry may have been stored between the first
		// lookup and the start of this flight.
		if b, found := lookup(); found {
			return b, nil
		}

		fresh, err := update(ctx)
		if err != nil {
			return nil, err
		}

		cache.mu.Lock()
		cache.m[k] = fresh
		cache.mu.Unlock()

		return fresh, nil
	})
	if err != nil {
		return nil, err
	}
	return res.([]byte), nil
}
// Invalidate deletes the entry stored under key, if any, while holding the
// write lock.
func (cache *localCache) Invalidate(key []byte) {
	k := string(key)

	cache.mu.Lock()
	defer cache.mu.Unlock()
	delete(cache.m, k)
}
// globalCache is a Cache backed by fastcache whose entries are treated as
// expired once ttl has elapsed since they were stored.
type globalCache struct {
// ttl is added to the current time to compute an entry's expiry when it
// is stored.
ttl time.Duration
// singleflight is used so concurrent misses for the same key share a
// single update call.
singleflight singleflight.Group
// mu is held around every fastcache call made by this type's methods.
mu sync.RWMutex
// fastcache stores each entry as an 8-byte little-endian expiry
// (Unix milliseconds) followed by the data.
fastcache *fastcache.Cache
}
// NewGlobalCache returns a Cache backed by fastcache. Stored entries are
// considered valid for ttl after they are written.
func NewGlobalCache(ttl time.Duration) Cache {
	// Size passed to fastcache.New: up to 256MB of RAM.
	const maxBytes = 256 << 20

	c := &globalCache{ttl: ttl}
	c.fastcache = fastcache.New(maxBytes)
	return c
}
// GetOrUpdate returns the data stored under key when it exists and has not
// expired. Otherwise it calls update (through singleflight, so concurrent
// misses for the same key share one call), stores the result with a new
// expiry and returns it. An error from update is returned as-is and nothing
// is stored.
func (cache *globalCache) GetOrUpdate(
	ctx context.Context,
	key []byte,
	update func(ctx context.Context) ([]byte, error),
) ([]byte, error) {
	// Expiry is always judged against the time this call started.
	start := time.Now()

	// unexpired returns the stored data only when an entry exists and its
	// expiry is still after start.
	unexpired := func() ([]byte, bool) {
		payload, deadline, found := cache.get(key)
		if !found || !start.Before(deadline) {
			return nil, false
		}
		return payload, true
	}

	// Fast path: a live entry is already cached.
	if payload, hit := unexpired(); hit {
		return payload, nil
	}

	res, err, _ := cache.singleflight.Do(string(key), func() (interface{}, error) {
		// Look again: a live entry may have been stored between the first
		// check and the start of this flight.
		if payload, hit := unexpired(); hit {
			return payload, nil
		}

		produced, err := update(ctx)
		if err != nil {
			return nil, err
		}

		cache.set(key, produced)
		return produced, nil
	})
	if err != nil {
		return nil, err
	}
	return res.([]byte), nil
}
// Invalidate deletes the entry stored under key from fastcache while
// holding the write lock.
func (cache *globalCache) Invalidate(key []byte) {
	cache.mu.Lock()
	defer cache.mu.Unlock()
	cache.fastcache.Del(key)
}
// get reads the raw entry stored under k and splits it into its parts.
//
// A stored entry is an 8-byte little-endian expiry in Unix milliseconds
// followed by the data. ok is false, with zero values for data and expiry,
// when the raw entry is shorter than that 8-byte prefix. get does not
// compare expiry against the current time; that is left to the caller.
func (cache *globalCache) get(k []byte) (data []byte, expiry time.Time, ok bool) {
	const prefixLen = 8

	cache.mu.RLock()
	raw := cache.fastcache.Get(nil, k)
	cache.mu.RUnlock()

	if len(raw) < prefixLen {
		return nil, time.Time{}, false
	}

	millis := int64(binary.LittleEndian.Uint64(raw[:prefixLen]))
	return raw[prefixLen:], time.UnixMilli(millis), true
}
// set stores v under k together with its expiry.
//
// The stored entry is an 8-byte little-endian expiry (now + ttl, in Unix
// milliseconds) followed by a copy of v.
func (cache *globalCache) set(k, v []byte) {
	const prefixLen = 8

	deadline := time.Now().Add(cache.ttl)

	// Reserve room for the whole entry up front so the append below does
	// not reallocate.
	entry := make([]byte, prefixLen, prefixLen+len(v))
	binary.LittleEndian.PutUint64(entry, uint64(deadline.UnixMilli()))
	entry = append(entry, v...)

	cache.mu.Lock()
	cache.fastcache.Set(k, entry)
	cache.mu.Unlock()
}