mirror of
https://github.com/pomerium/pomerium.git
synced 2025-08-03 08:50:42 +02:00
cache : add cache service (#457)
Signed-off-by: Bobby DeSimone <bobbydesimone@gmail.com>
This commit is contained in:
parent
8a9cb0f803
commit
dccc7cd2ff
46 changed files with 1837 additions and 587 deletions
122
internal/sessions/cache/cache_store.go
vendored
122
internal/sessions/cache/cache_store.go
vendored
|
@ -1,14 +1,12 @@
|
|||
package cache // import "github.com/pomerium/pomerium/internal/sessions/cache"
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"fmt"
|
||||
"net/http"
|
||||
|
||||
"github.com/golang/groupcache"
|
||||
|
||||
"github.com/pomerium/pomerium/internal/encoding"
|
||||
"github.com/pomerium/pomerium/internal/grpc/cache/client"
|
||||
"github.com/pomerium/pomerium/internal/log"
|
||||
"github.com/pomerium/pomerium/internal/sessions"
|
||||
)
|
||||
|
@ -16,83 +14,69 @@ import (
|
|||
var _ sessions.SessionStore = &Store{}
|
||||
var _ sessions.SessionLoader = &Store{}
|
||||
|
||||
const (
|
||||
defaultQueryParamKey = "ati"
|
||||
)
|
||||
|
||||
// Store implements the session store interface using a distributed cache.
|
||||
// Store implements the session store interface using a cache service.
|
||||
type Store struct {
|
||||
name string
|
||||
encoder encoding.Marshaler
|
||||
decoder encoding.Unmarshaler
|
||||
|
||||
cache *groupcache.Group
|
||||
cache client.Cacher
|
||||
encoder encoding.MarshalUnmarshaler
|
||||
queryParam string
|
||||
wrappedStore sessions.SessionStore
|
||||
}
|
||||
|
||||
// defaultCacheSize is ~10MB
|
||||
var defaultCacheSize int64 = 10 << 20
|
||||
|
||||
// NewStore creates a new session store built on the distributed caching library
|
||||
// groupcache. On a cache miss, the cache store attempts to fallback to another
|
||||
// SessionStore implementation.
|
||||
func NewStore(enc encoding.MarshalUnmarshaler, wrappedStore sessions.SessionStore, name string) *Store {
|
||||
store := &Store{
|
||||
name: name,
|
||||
encoder: enc,
|
||||
decoder: enc,
|
||||
wrappedStore: wrappedStore,
|
||||
}
|
||||
|
||||
store.cache = groupcache.NewGroup(name, defaultCacheSize, groupcache.GetterFunc(
|
||||
func(ctx context.Context, id string, dest groupcache.Sink) error {
|
||||
// fill the cache with session set as part of the request
|
||||
// context set previously as part of SaveSession.
|
||||
b := fromContext(ctx)
|
||||
if len(b) == 0 {
|
||||
return fmt.Errorf("sessions/cache: cannot fill key %s from ctx", id)
|
||||
}
|
||||
if err := dest.SetBytes(b); err != nil {
|
||||
return fmt.Errorf("sessions/cache: sink error %w", err)
|
||||
}
|
||||
return nil
|
||||
},
|
||||
))
|
||||
|
||||
return store
|
||||
// Options represent cache store's available configurations.
|
||||
type Options struct {
|
||||
Cache client.Cacher
|
||||
Encoder encoding.MarshalUnmarshaler
|
||||
QueryParam string
|
||||
WrappedStore sessions.SessionStore
|
||||
}
|
||||
|
||||
// LoadSession implements SessionLoaders's LoadSession method for cache store.
|
||||
var defaultOptions = &Options{
|
||||
QueryParam: "cache_store_key",
|
||||
}
|
||||
|
||||
// NewStore creates a new cache
|
||||
func NewStore(o *Options) *Store {
|
||||
if o.QueryParam == "" {
|
||||
o.QueryParam = defaultOptions.QueryParam
|
||||
}
|
||||
return &Store{
|
||||
cache: o.Cache,
|
||||
encoder: o.Encoder,
|
||||
queryParam: o.QueryParam,
|
||||
wrappedStore: o.WrappedStore,
|
||||
}
|
||||
}
|
||||
|
||||
// LoadSession looks for a preset query parameter in the request body
|
||||
// representing the key to lookup from the cache.
|
||||
func (s *Store) LoadSession(r *http.Request) (*sessions.State, error) {
|
||||
// look for our cache's key in the default query param
|
||||
sessionID := r.URL.Query().Get(defaultQueryParamKey)
|
||||
sessionID := r.URL.Query().Get(s.queryParam)
|
||||
if sessionID == "" {
|
||||
// if unset, fallback to default cache store
|
||||
log.FromRequest(r).Debug().Msg("sessions/cache: no query param, trying wrapped loader")
|
||||
return s.wrappedStore.LoadSession(r)
|
||||
return nil, sessions.ErrNoSessionFound
|
||||
}
|
||||
|
||||
var b []byte
|
||||
if err := s.cache.Get(r.Context(), sessionID, groupcache.AllocatingByteSliceSink(&b)); err != nil {
|
||||
log.FromRequest(r).Debug().Err(err).Msg("sessions/cache: miss, trying wrapped loader")
|
||||
return s.wrappedStore.LoadSession(r)
|
||||
exists, val, err := s.cache.Get(r.Context(), sessionID)
|
||||
if err != nil {
|
||||
log.FromRequest(r).Debug().Msg("sessions/cache: miss, trying wrapped loader")
|
||||
return nil, err
|
||||
}
|
||||
if !exists {
|
||||
return nil, sessions.ErrNoSessionFound
|
||||
}
|
||||
var session sessions.State
|
||||
if err := s.decoder.Unmarshal(b, &session); err != nil {
|
||||
if err := s.encoder.Unmarshal(val, &session); err != nil {
|
||||
log.FromRequest(r).Error().Err(err).Msg("sessions/cache: unmarshal")
|
||||
return nil, sessions.ErrMalformed
|
||||
}
|
||||
return &session, nil
|
||||
}
|
||||
|
||||
// ClearSession implements SessionStore's ClearSession for the cache store.
|
||||
// Since group cache has no explicit eviction, we just call the wrapped
|
||||
// store's ClearSession method here.
|
||||
// ClearSession clears the session from the wrapped store.
|
||||
func (s *Store) ClearSession(w http.ResponseWriter, r *http.Request) {
|
||||
s.wrappedStore.ClearSession(w, r)
|
||||
}
|
||||
|
||||
// SaveSession implements SessionStore's SaveSession method for cache store.
|
||||
// SaveSession saves the session to the cache, and wrapped store.
|
||||
func (s *Store) SaveSession(w http.ResponseWriter, r *http.Request, x interface{}) error {
|
||||
err := s.wrappedStore.SaveSession(w, r, x)
|
||||
if err != nil {
|
||||
|
@ -101,7 +85,7 @@ func (s *Store) SaveSession(w http.ResponseWriter, r *http.Request, x interface{
|
|||
|
||||
state, ok := x.(*sessions.State)
|
||||
if !ok {
|
||||
return errors.New("internal/sessions: cannot cache non state type")
|
||||
return errors.New("sessions/cache: cannot cache non state type")
|
||||
}
|
||||
|
||||
data, err := s.encoder.Marshal(&state)
|
||||
|
@ -109,23 +93,5 @@ func (s *Store) SaveSession(w http.ResponseWriter, r *http.Request, x interface{
|
|||
return fmt.Errorf("sessions/cache: marshal %w", err)
|
||||
}
|
||||
|
||||
ctx := newContext(r.Context(), data)
|
||||
var b []byte
|
||||
return s.cache.Get(ctx, state.AccessTokenID, groupcache.AllocatingByteSliceSink(&b))
|
||||
}
|
||||
|
||||
var sessionCtxKey = &contextKey{"PomeriumCachedSessionBytes"}
|
||||
|
||||
type contextKey struct {
|
||||
name string
|
||||
}
|
||||
|
||||
func newContext(ctx context.Context, b []byte) context.Context {
|
||||
ctx = context.WithValue(ctx, sessionCtxKey, b)
|
||||
return ctx
|
||||
}
|
||||
|
||||
func fromContext(ctx context.Context) []byte {
|
||||
b, _ := ctx.Value(sessionCtxKey).([]byte)
|
||||
return b
|
||||
return s.cache.Set(r.Context(), state.AccessTokenID, data)
|
||||
}
|
||||
|
|
220
internal/sessions/cache/cache_store_test.go
vendored
220
internal/sessions/cache/cache_store_test.go
vendored
|
@ -1,7 +1,8 @@
|
|||
package cache
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"context"
|
||||
"errors"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"testing"
|
||||
|
@ -9,125 +10,188 @@ import (
|
|||
|
||||
"github.com/google/go-cmp/cmp"
|
||||
"github.com/pomerium/pomerium/internal/cryptutil"
|
||||
"github.com/pomerium/pomerium/internal/encoding"
|
||||
"github.com/pomerium/pomerium/internal/encoding/ecjson"
|
||||
mock_encoder "github.com/pomerium/pomerium/internal/encoding/mock"
|
||||
"github.com/pomerium/pomerium/internal/grpc/cache/client"
|
||||
"github.com/pomerium/pomerium/internal/sessions"
|
||||
"github.com/pomerium/pomerium/internal/sessions/cookie"
|
||||
"github.com/pomerium/pomerium/internal/sessions/mock"
|
||||
"gopkg.in/square/go-jose.v2/jwt"
|
||||
)
|
||||
|
||||
func testAuthorizer(next http.Handler) http.Handler {
|
||||
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
_, err := sessions.FromContext(r.Context())
|
||||
if err != nil {
|
||||
http.Error(w, err.Error(), http.StatusUnauthorized)
|
||||
return
|
||||
}
|
||||
next.ServeHTTP(w, r)
|
||||
})
|
||||
type mockCache struct {
|
||||
Key string
|
||||
KeyExists bool
|
||||
Value []byte
|
||||
Err error
|
||||
}
|
||||
|
||||
func TestVerifier(t *testing.T) {
|
||||
fnh := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
w.Header().Set("Content-Type", "text/plain; charset=utf-8")
|
||||
fmt.Fprint(w, http.StatusText(http.StatusOK))
|
||||
w.WriteHeader(http.StatusOK)
|
||||
})
|
||||
func (mc *mockCache) Get(ctx context.Context, key string) (keyExists bool, value []byte, err error) {
|
||||
return mc.KeyExists, mc.Value, mc.Err
|
||||
}
|
||||
func (mc *mockCache) Set(ctx context.Context, key string, value []byte) error {
|
||||
return mc.Err
|
||||
}
|
||||
func (mc *mockCache) Close() error {
|
||||
return mc.Err
|
||||
}
|
||||
|
||||
func TestNewStore(t *testing.T) {
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
skipSave bool
|
||||
cacheSize int64
|
||||
state sessions.State
|
||||
name string
|
||||
Options *Options
|
||||
State *sessions.State
|
||||
|
||||
wantBody string
|
||||
wantStatus int
|
||||
wantErr bool
|
||||
wantLoadErr bool
|
||||
wantStatus int
|
||||
}{
|
||||
{"good", false, 1 << 10, sessions.State{AccessTokenID: cryptutil.NewBase64Key(), Email: "user@pomerium.io", Expiry: jwt.NewNumericDate(time.Now().Add(10 * time.Minute))}, http.StatusText(http.StatusOK), http.StatusOK},
|
||||
{"expired", false, 1 << 10, sessions.State{AccessTokenID: cryptutil.NewBase64Key(), Email: "user@pomerium.io", Expiry: jwt.NewNumericDate(time.Now().Add(-10 * time.Minute))}, "internal/sessions: validation failed, token is expired (exp)\n", http.StatusUnauthorized},
|
||||
{"empty", false, 1 << 10, sessions.State{AccessTokenID: "", Email: "user@pomerium.io", Expiry: jwt.NewNumericDate(time.Now().Add(-10 * time.Minute))}, "internal/sessions: session is not found\n", http.StatusUnauthorized},
|
||||
{"miss", true, 1 << 10, sessions.State{AccessTokenID: cryptutil.NewBase64Key(), Email: "user@pomerium.io", Expiry: jwt.NewNumericDate(time.Now().Add(10 * time.Minute))}, "internal/sessions: session is not found\n", http.StatusUnauthorized},
|
||||
{"cache eviction", false, 1, sessions.State{AccessTokenID: cryptutil.NewBase64Key(), Email: "user@pomerium.io", Expiry: jwt.NewNumericDate(time.Now().Add(10 * time.Minute))}, "internal/sessions: session is not found\n", http.StatusUnauthorized},
|
||||
{"simple good",
|
||||
&Options{
|
||||
Cache: &mockCache{},
|
||||
WrappedStore: &mock.Store{},
|
||||
Encoder: mock_encoder.Encoder{MarshalResponse: []byte("ok")},
|
||||
},
|
||||
&sessions.State{Email: "user@domain.com", User: "user"},
|
||||
false, false,
|
||||
http.StatusOK},
|
||||
}
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
defaultCacheSize = tt.cacheSize
|
||||
cipher, err := cryptutil.NewAEADCipherFromBase64(cryptutil.NewBase64Key())
|
||||
encoder := ecjson.New(cipher)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
cs, err := cookie.NewStore(&cookie.Options{Name: t.Name()}, encoder)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
cacheStore := NewStore(encoder, cs, t.Name())
|
||||
got := NewStore(tt.Options)
|
||||
|
||||
r := httptest.NewRequest(http.MethodGet, "/", nil)
|
||||
q := r.URL.Query()
|
||||
|
||||
q.Set(defaultQueryParamKey, tt.state.AccessTokenID)
|
||||
r.URL.RawQuery = q.Encode()
|
||||
r.Header.Set("Accept", "application/json")
|
||||
r := httptest.NewRequest("GET", "/", nil)
|
||||
w := httptest.NewRecorder()
|
||||
|
||||
got := sessions.RetrieveSession(cacheStore)(testAuthorizer((fnh)))
|
||||
|
||||
if !tt.skipSave {
|
||||
cacheStore.SaveSession(w, r, &tt.state)
|
||||
if err := got.SaveSession(w, r, tt.State); (err != nil) != tt.wantErr {
|
||||
t.Errorf("NewStore.SaveSession() error = %v, wantErr %v", err, tt.wantErr)
|
||||
}
|
||||
|
||||
for i := 1; i <= 10; i++ {
|
||||
s := tt.state
|
||||
s.AccessTokenID = cryptutil.NewBase64Key()
|
||||
cacheStore.SaveSession(w, r, s)
|
||||
}
|
||||
r = httptest.NewRequest("GET", "/", nil)
|
||||
w = httptest.NewRecorder()
|
||||
|
||||
got.ServeHTTP(w, r)
|
||||
|
||||
gotBody := w.Body.String()
|
||||
gotStatus := w.Result().StatusCode
|
||||
|
||||
if diff := cmp.Diff(gotBody, tt.wantBody); diff != "" {
|
||||
t.Errorf("RetrieveSession() = %v", diff)
|
||||
}
|
||||
if diff := cmp.Diff(gotStatus, tt.wantStatus); diff != "" {
|
||||
t.Errorf("RetrieveSession() = %v", diff)
|
||||
got.ClearSession(w, r)
|
||||
status := w.Result().StatusCode
|
||||
if diff := cmp.Diff(status, tt.wantStatus); diff != "" {
|
||||
t.Errorf("ClearSession() = %v", diff)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestStore_SaveSession(t *testing.T) {
|
||||
|
||||
cipher, err := cryptutil.NewAEADCipherFromBase64(cryptutil.NewBase64Key())
|
||||
encoder := ecjson.New(cipher)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
cs, err := cookie.NewStore(&cookie.Options{
|
||||
Name: "_pomerium",
|
||||
}, encoder)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
tests := []struct {
|
||||
name string
|
||||
Options *Options
|
||||
|
||||
x interface{}
|
||||
wantErr bool
|
||||
}{
|
||||
{"good", &sessions.State{AccessTokenID: cryptutil.NewBase64Key(), Email: "user@pomerium.io", Expiry: jwt.NewNumericDate(time.Now().Add(10 * time.Minute))}, false},
|
||||
{"bad type", "bad type!", true},
|
||||
{"good", &Options{Cache: &mockCache{}, WrappedStore: cs, Encoder: mock_encoder.Encoder{MarshalResponse: []byte("ok")}}, &sessions.State{AccessTokenID: cryptutil.NewBase64Key(), Email: "user@pomerium.io", Expiry: jwt.NewNumericDate(time.Now().Add(10 * time.Minute))}, false},
|
||||
{"encoder error", &Options{Cache: &mockCache{}, WrappedStore: cs, Encoder: mock_encoder.Encoder{MarshalError: errors.New("err")}}, &sessions.State{AccessTokenID: cryptutil.NewBase64Key(), Email: "user@pomerium.io", Expiry: jwt.NewNumericDate(time.Now().Add(10 * time.Minute))}, true},
|
||||
{"good", &Options{Cache: &mockCache{}, WrappedStore: &mock.Store{SaveError: errors.New("err")}}, &sessions.State{AccessTokenID: cryptutil.NewBase64Key(), Email: "user@pomerium.io", Expiry: jwt.NewNumericDate(time.Now().Add(10 * time.Minute))}, true},
|
||||
{"bad type", &Options{Cache: &mockCache{}, WrappedStore: cs, Encoder: mock_encoder.Encoder{MarshalError: errors.New("err")}}, "bad type!", true},
|
||||
}
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
cipher, err := cryptutil.NewAEADCipherFromBase64(cryptutil.NewBase64Key())
|
||||
encoder := ecjson.New(cipher)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
o := tt.Options
|
||||
if o.WrappedStore == nil {
|
||||
o.WrappedStore = cs
|
||||
|
||||
}
|
||||
cs, err := cookie.NewStore(&cookie.Options{
|
||||
Name: "_pomerium",
|
||||
}, encoder)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
cacheStore := NewStore(encoder, cs, t.Name())
|
||||
cacheStore := NewStore(tt.Options)
|
||||
r := httptest.NewRequest(http.MethodGet, "/", nil)
|
||||
r.Header.Set("Accept", "application/json")
|
||||
w := httptest.NewRecorder()
|
||||
|
||||
if err := cacheStore.SaveSession(w, r, tt.x); (err != nil) != tt.wantErr {
|
||||
t.Errorf("Store.SaveSession() error = %v, wantErr %v", err, tt.wantErr)
|
||||
}
|
||||
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestStore_LoadSession(t *testing.T) {
|
||||
key := cryptutil.NewBase64Key()
|
||||
tests := []struct {
|
||||
name string
|
||||
state *sessions.State
|
||||
cache client.Cacher
|
||||
encoder encoding.MarshalUnmarshaler
|
||||
queryParam string
|
||||
wrappedStore sessions.SessionStore
|
||||
wantErr bool
|
||||
}{
|
||||
{"good",
|
||||
&sessions.State{AccessTokenID: key, Email: "user@pomerium.io"},
|
||||
&mockCache{KeyExists: true},
|
||||
mock_encoder.Encoder{MarshalResponse: []byte("ok")},
|
||||
defaultOptions.QueryParam,
|
||||
&mock.Store{Session: &sessions.State{AccessTokenID: key, Email: "user@pomerium.io"}},
|
||||
false},
|
||||
{"missing param with key",
|
||||
&sessions.State{AccessTokenID: key, Email: "user@pomerium.io"},
|
||||
&mockCache{KeyExists: true},
|
||||
mock_encoder.Encoder{MarshalResponse: []byte("ok")},
|
||||
"bad_query",
|
||||
&mock.Store{Session: &sessions.State{AccessTokenID: key, Email: "user@pomerium.io"}},
|
||||
true},
|
||||
{"doesn't exist",
|
||||
&sessions.State{AccessTokenID: key, Email: "user@pomerium.io"},
|
||||
&mockCache{KeyExists: false},
|
||||
mock_encoder.Encoder{MarshalResponse: []byte("ok")},
|
||||
defaultOptions.QueryParam,
|
||||
&mock.Store{Session: &sessions.State{AccessTokenID: key, Email: "user@pomerium.io"}},
|
||||
true},
|
||||
{"retrieval error",
|
||||
&sessions.State{AccessTokenID: key, Email: "user@pomerium.io"},
|
||||
&mockCache{Err: errors.New("err")},
|
||||
mock_encoder.Encoder{MarshalResponse: []byte("ok")},
|
||||
defaultOptions.QueryParam,
|
||||
&mock.Store{Session: &sessions.State{AccessTokenID: key, Email: "user@pomerium.io"}},
|
||||
true},
|
||||
{"unmarshal failure",
|
||||
&sessions.State{AccessTokenID: key, Email: "user@pomerium.io"},
|
||||
&mockCache{KeyExists: true},
|
||||
mock_encoder.Encoder{UnmarshalError: errors.New("err")},
|
||||
defaultOptions.QueryParam,
|
||||
&mock.Store{Session: &sessions.State{AccessTokenID: key, Email: "user@pomerium.io"}},
|
||||
true},
|
||||
}
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
s := &Store{
|
||||
cache: tt.cache,
|
||||
encoder: tt.encoder,
|
||||
queryParam: tt.queryParam,
|
||||
wrappedStore: tt.wrappedStore,
|
||||
}
|
||||
|
||||
r := httptest.NewRequest(http.MethodGet, "/", nil)
|
||||
q := r.URL.Query()
|
||||
|
||||
q.Set(defaultOptions.QueryParam, tt.state.AccessTokenID)
|
||||
r.URL.RawQuery = q.Encode()
|
||||
r.Header.Set("Accept", "application/json")
|
||||
|
||||
_, err := s.LoadSession(r)
|
||||
if (err != nil) != tt.wantErr {
|
||||
t.Errorf("Store.LoadSession() error = %v, wantErr %v", err, tt.wantErr)
|
||||
return
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue