pomerium/internal/cliutil/jwtcache.go
2021-06-10 09:35:44 -06:00

164 lines
3.5 KiB
Go

package cliutil
import (
"encoding/json"
"errors"
"fmt"
"io/ioutil"
"os"
"path/filepath"
"sync"
"time"
"github.com/go-jose/go-jose/v3"
"github.com/martinlindhe/base36"
"github.com/pomerium/pomerium/pkg/cryptutil"
)
// Sentinel errors returned by the JWTCache implementations in this package.
var (
	// ErrExpired is returned when a cached JWT's "exp" claim is in the past.
	ErrExpired = errors.New("expired")

	// ErrInvalid is returned when a cached value cannot be parsed as a signed
	// JWT or its claims cannot be decoded.
	ErrInvalid = errors.New("invalid")

	// ErrNotFound is returned when nothing is cached under the requested key.
	ErrNotFound = errors.New("not found")
)
// A JWTCache loads and stores JWTs.
//
// Implementations associate a raw JWT string with a caller-chosen key.
type JWTCache interface {
	// LoadJWT returns the raw JWT stored under key.
	LoadJWT(key string) (rawJWT string, err error)
	// StoreJWT stores rawJWT under key.
	StoreJWT(key string, rawJWT string) error
	// DeleteJWT removes the JWT stored under key.
	DeleteJWT(key string) error
}
// A LocalJWTCache stores files in the user's cache directory.
//
// Each JWT lives in its own file inside dir, named after a hash of its key
// with a ".jwt" extension.
type LocalJWTCache struct {
	dir string // directory that holds the cached *.jwt files
}
// NewLocalJWTCache creates a new LocalJWTCache.
//
// Tokens are stored under <user cache dir>/pomerium-cli/jwts, where the user
// cache directory is the one reported by os.UserCacheDir. The directory is
// created if it does not already exist. Because it holds credentials, any
// directories created here get mode 0700 (before umask) so other users cannot
// list them; the permissions of already-existing directories are left as-is.
//
// It returns an error if the user cache directory cannot be determined or the
// cache directory cannot be created.
func NewLocalJWTCache() (*LocalJWTCache, error) {
	root, err := os.UserCacheDir()
	if err != nil {
		return nil, fmt.Errorf("error determining user cache directory: %w", err)
	}

	dir := filepath.Join(root, "pomerium-cli", "jwts")
	if err := os.MkdirAll(dir, 0o700); err != nil {
		return nil, fmt.Errorf("error creating user cache directory: %w", err)
	}

	return &LocalJWTCache{dir: dir}, nil
}
// DeleteJWT deletes a raw JWT from the local cache.
//
// Deleting a key that has no cached file is not an error; any other failure
// from os.Remove is returned unchanged.
func (cache *LocalJWTCache) DeleteJWT(key string) error {
	target := filepath.Join(cache.dir, cache.fileName(key))
	if err := os.Remove(target); err != nil && !os.IsNotExist(err) {
		return err
	}
	return nil
}
// LoadJWT loads a raw JWT from the local cache.
//
// It returns ErrNotFound when no file exists for key, and any other read
// error unchanged. Otherwise it returns the file contents together with the
// result of checkExpiry (nil, ErrInvalid or ErrExpired). The contents are
// returned even when that check fails.
func (cache *LocalJWTCache) LoadJWT(key string) (rawJWT string, err error) {
	contents, readErr := ioutil.ReadFile(filepath.Join(cache.dir, cache.fileName(key)))
	switch {
	case os.IsNotExist(readErr):
		return "", ErrNotFound
	case readErr != nil:
		return "", readErr
	}

	rawJWT = string(contents)
	return rawJWT, checkExpiry(rawJWT)
}
// StoreJWT stores a raw JWT in the local cache.
//
// The token is first written to a uniquely named temporary file in the cache
// directory and then renamed over the final path. A concurrent LoadJWT
// (possibly from another pomerium-cli process) therefore sees either the old
// token or the new one, never a truncated or half-written file. The stored
// file always has mode 0600.
func (cache *LocalJWTCache) StoreJWT(key string, rawJWT string) error {
	name := cache.fileName(key)
	path := filepath.Join(cache.dir, name)

	// ioutil.TempFile creates the file with mode 0600. The ".tmp" suffix keeps
	// it distinct from the "*.jwt" names that LoadJWT reads.
	tmp, err := ioutil.TempFile(cache.dir, name+".*.tmp")
	if err != nil {
		return err
	}
	tmpPath := tmp.Name()

	if _, err = tmp.WriteString(rawJWT); err != nil {
		_ = tmp.Close()        // best-effort; the write error is what matters
		_ = os.Remove(tmpPath) // don't leave a partial temp file behind
		return err
	}
	if err = tmp.Close(); err != nil {
		_ = os.Remove(tmpPath)
		return err
	}
	// Rename within the same directory atomically replaces any existing token.
	if err = os.Rename(tmpPath, path); err != nil {
		_ = os.Remove(tmpPath)
		return err
	}
	return nil
}
// hash maps str to the identifier used as its on-disk file name. It hashes
// str with cryptutil.Hash, passing "LocalJWTCache" as the first argument
// (presumably a domain-separation tag; confirm in cryptutil). It then encodes
// the digest in base36.
func (cache *LocalJWTCache) hash(str string) string {
	return base36.EncodeBytes(cryptutil.Hash("LocalJWTCache", []byte(str)))
}
// fileName returns the base name of the file holding the JWT for key: the
// hashed key followed by a ".jwt" extension.
func (cache *LocalJWTCache) fileName(key string) string {
	const ext = ".jwt"
	return cache.hash(key) + ext
}
// A MemoryJWTCache stores JWTs in an in-memory map.
//
// Its methods take mu before touching entries, so a *MemoryJWTCache may be
// shared between goroutines. Create one with NewMemoryJWTCache.
type MemoryJWTCache struct {
	mu      sync.Mutex        // guards entries
	entries map[string]string // raw JWTs indexed by cache key
}
// NewMemoryJWTCache creates a new in-memory JWT cache with an empty,
// already-allocated entry map.
func NewMemoryJWTCache() *MemoryJWTCache {
	cache := new(MemoryJWTCache)
	cache.entries = map[string]string{}
	return cache
}
// DeleteJWT deletes a JWT from the in-memory map.
//
// Removing a key that is not present is a no-op; the returned error is
// always nil.
func (cache *MemoryJWTCache) DeleteJWT(key string) error {
	cache.mu.Lock()
	delete(cache.entries, key)
	cache.mu.Unlock()
	return nil
}
// LoadJWT loads a JWT from the in-memory map.
//
// It returns ErrNotFound when key is absent. Otherwise it returns the stored
// token together with the result of checkExpiry (nil, ErrInvalid or
// ErrExpired). The token is returned even when that check fails.
func (cache *MemoryJWTCache) LoadJWT(key string) (rawJWT string, err error) {
	cache.mu.Lock()
	stored, found := cache.entries[key]
	cache.mu.Unlock()

	if !found {
		return "", ErrNotFound
	}
	return stored, checkExpiry(stored)
}
// StoreJWT stores a JWT in the in-memory map, replacing any existing entry
// for key. The returned error is always nil.
//
// The map is allocated on first use, so a zero-value MemoryJWTCache works as
// well as one created by NewMemoryJWTCache.
func (cache *MemoryJWTCache) StoreJWT(key string, rawJWT string) error {
	cache.mu.Lock()
	defer cache.mu.Unlock()

	// Assigning into a nil map panics; allocate lazily so the zero value is
	// usable (DeleteJWT and LoadJWT already tolerate a nil map).
	if cache.entries == nil {
		cache.entries = make(map[string]string)
	}
	cache.entries[key] = rawJWT
	return nil
}
// checkExpiry reports whether rawJWT is a parseable signed JWT whose "exp"
// claim is not in the past.
//
// The signature is NOT verified. The payload is only inspected to decide
// whether a cached token is still worth using.
//
// It returns:
//   - ErrInvalid if rawJWT cannot be parsed by jose.ParseSigned, or if its
//     payload cannot be decoded into a claims object with a numeric "exp";
//   - ErrExpired if "exp" is before the current time. A payload with no
//     "exp" claim decodes as 0 (the Unix epoch) and so is treated as expired;
//   - nil otherwise.
func checkExpiry(rawJWT string) error {
	tok, err := jose.ParseSigned(rawJWT)
	if err != nil {
		return ErrInvalid
	}

	// RFC 7519 defines "exp" as a NumericDate, which may have a fractional
	// part. Decoding into float64 accepts both integer and non-integer values;
	// an int64 field would reject the latter as invalid.
	var claims struct {
		Expiry float64 `json:"exp"`
	}
	if err := json.Unmarshal(tok.UnsafePayloadWithoutVerification(), &claims); err != nil {
		return ErrInvalid
	}

	// Sub-second precision is irrelevant for cache expiry, so truncate.
	expiresAt := time.Unix(int64(claims.Expiry), 0)
	if expiresAt.Before(time.Now()) {
		return ErrExpired
	}
	return nil
}