cache : add cache service (#457)

Signed-off-by: Bobby DeSimone <bobbydesimone@gmail.com>
This commit is contained in:
Bobby DeSimone 2020-01-20 18:25:34 -08:00 committed by GitHub
parent 8a9cb0f803
commit dccc7cd2ff
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
46 changed files with 1837 additions and 587 deletions

View file

@ -1,14 +1,12 @@
package cache // import "github.com/pomerium/pomerium/internal/sessions/cache"
import (
"context"
"errors"
"fmt"
"net/http"
"github.com/golang/groupcache"
"github.com/pomerium/pomerium/internal/encoding"
"github.com/pomerium/pomerium/internal/grpc/cache/client"
"github.com/pomerium/pomerium/internal/log"
"github.com/pomerium/pomerium/internal/sessions"
)
@ -16,83 +14,69 @@ import (
var _ sessions.SessionStore = &Store{}
var _ sessions.SessionLoader = &Store{}
const (
defaultQueryParamKey = "ati"
)
// Store implements the session store interface using a distributed cache.
// Store implements the session store interface using a cache service.
type Store struct {
name string
encoder encoding.Marshaler
decoder encoding.Unmarshaler
cache *groupcache.Group
cache client.Cacher
encoder encoding.MarshalUnmarshaler
queryParam string
wrappedStore sessions.SessionStore
}
// defaultCacheSize is ~10MB
var defaultCacheSize int64 = 10 << 20
// NewStore creates a new session store built on the distributed caching library
// groupcache. On a cache miss, the cache store attempts to fallback to another
// SessionStore implementation.
func NewStore(enc encoding.MarshalUnmarshaler, wrappedStore sessions.SessionStore, name string) *Store {
store := &Store{
name: name,
encoder: enc,
decoder: enc,
wrappedStore: wrappedStore,
}
store.cache = groupcache.NewGroup(name, defaultCacheSize, groupcache.GetterFunc(
func(ctx context.Context, id string, dest groupcache.Sink) error {
// fill the cache with session set as part of the request
// context set previously as part of SaveSession.
b := fromContext(ctx)
if len(b) == 0 {
return fmt.Errorf("sessions/cache: cannot fill key %s from ctx", id)
}
if err := dest.SetBytes(b); err != nil {
return fmt.Errorf("sessions/cache: sink error %w", err)
}
return nil
},
))
return store
// Options represent cache store's available configurations.
type Options struct {
Cache client.Cacher
Encoder encoding.MarshalUnmarshaler
QueryParam string
WrappedStore sessions.SessionStore
}
// LoadSession implements SessionLoaders's LoadSession method for cache store.
var defaultOptions = &Options{
QueryParam: "cache_store_key",
}
// NewStore creates a new cache
func NewStore(o *Options) *Store {
if o.QueryParam == "" {
o.QueryParam = defaultOptions.QueryParam
}
return &Store{
cache: o.Cache,
encoder: o.Encoder,
queryParam: o.QueryParam,
wrappedStore: o.WrappedStore,
}
}
// LoadSession looks for a preset query parameter in the request body
// representing the key to lookup from the cache.
func (s *Store) LoadSession(r *http.Request) (*sessions.State, error) {
// look for our cache's key in the default query param
sessionID := r.URL.Query().Get(defaultQueryParamKey)
sessionID := r.URL.Query().Get(s.queryParam)
if sessionID == "" {
// if unset, fallback to default cache store
log.FromRequest(r).Debug().Msg("sessions/cache: no query param, trying wrapped loader")
return s.wrappedStore.LoadSession(r)
return nil, sessions.ErrNoSessionFound
}
var b []byte
if err := s.cache.Get(r.Context(), sessionID, groupcache.AllocatingByteSliceSink(&b)); err != nil {
log.FromRequest(r).Debug().Err(err).Msg("sessions/cache: miss, trying wrapped loader")
return s.wrappedStore.LoadSession(r)
exists, val, err := s.cache.Get(r.Context(), sessionID)
if err != nil {
log.FromRequest(r).Debug().Msg("sessions/cache: miss, trying wrapped loader")
return nil, err
}
if !exists {
return nil, sessions.ErrNoSessionFound
}
var session sessions.State
if err := s.decoder.Unmarshal(b, &session); err != nil {
if err := s.encoder.Unmarshal(val, &session); err != nil {
log.FromRequest(r).Error().Err(err).Msg("sessions/cache: unmarshal")
return nil, sessions.ErrMalformed
}
return &session, nil
}
// ClearSession implements SessionStore's ClearSession for the cache store.
// Since group cache has no explicit eviction, we just call the wrapped
// store's ClearSession method here.
// ClearSession clears the session from the wrapped store.
func (s *Store) ClearSession(w http.ResponseWriter, r *http.Request) {
s.wrappedStore.ClearSession(w, r)
}
// SaveSession implements SessionStore's SaveSession method for cache store.
// SaveSession saves the session to the cache, and wrapped store.
func (s *Store) SaveSession(w http.ResponseWriter, r *http.Request, x interface{}) error {
err := s.wrappedStore.SaveSession(w, r, x)
if err != nil {
@ -101,7 +85,7 @@ func (s *Store) SaveSession(w http.ResponseWriter, r *http.Request, x interface{
state, ok := x.(*sessions.State)
if !ok {
return errors.New("internal/sessions: cannot cache non state type")
return errors.New("sessions/cache: cannot cache non state type")
}
data, err := s.encoder.Marshal(&state)
@ -109,23 +93,5 @@ func (s *Store) SaveSession(w http.ResponseWriter, r *http.Request, x interface{
return fmt.Errorf("sessions/cache: marshal %w", err)
}
ctx := newContext(r.Context(), data)
var b []byte
return s.cache.Get(ctx, state.AccessTokenID, groupcache.AllocatingByteSliceSink(&b))
}
var sessionCtxKey = &contextKey{"PomeriumCachedSessionBytes"}
type contextKey struct {
name string
}
func newContext(ctx context.Context, b []byte) context.Context {
ctx = context.WithValue(ctx, sessionCtxKey, b)
return ctx
}
func fromContext(ctx context.Context) []byte {
b, _ := ctx.Value(sessionCtxKey).([]byte)
return b
return s.cache.Set(r.Context(), state.AccessTokenID, data)
}

View file

@ -1,7 +1,8 @@
package cache
import (
"fmt"
"context"
"errors"
"net/http"
"net/http/httptest"
"testing"
@ -9,125 +10,188 @@ import (
"github.com/google/go-cmp/cmp"
"github.com/pomerium/pomerium/internal/cryptutil"
"github.com/pomerium/pomerium/internal/encoding"
"github.com/pomerium/pomerium/internal/encoding/ecjson"
mock_encoder "github.com/pomerium/pomerium/internal/encoding/mock"
"github.com/pomerium/pomerium/internal/grpc/cache/client"
"github.com/pomerium/pomerium/internal/sessions"
"github.com/pomerium/pomerium/internal/sessions/cookie"
"github.com/pomerium/pomerium/internal/sessions/mock"
"gopkg.in/square/go-jose.v2/jwt"
)
func testAuthorizer(next http.Handler) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
_, err := sessions.FromContext(r.Context())
if err != nil {
http.Error(w, err.Error(), http.StatusUnauthorized)
return
}
next.ServeHTTP(w, r)
})
type mockCache struct {
Key string
KeyExists bool
Value []byte
Err error
}
func TestVerifier(t *testing.T) {
fnh := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.Header().Set("Content-Type", "text/plain; charset=utf-8")
fmt.Fprint(w, http.StatusText(http.StatusOK))
w.WriteHeader(http.StatusOK)
})
func (mc *mockCache) Get(ctx context.Context, key string) (keyExists bool, value []byte, err error) {
return mc.KeyExists, mc.Value, mc.Err
}
func (mc *mockCache) Set(ctx context.Context, key string, value []byte) error {
return mc.Err
}
func (mc *mockCache) Close() error {
return mc.Err
}
func TestNewStore(t *testing.T) {
tests := []struct {
name string
skipSave bool
cacheSize int64
state sessions.State
name string
Options *Options
State *sessions.State
wantBody string
wantStatus int
wantErr bool
wantLoadErr bool
wantStatus int
}{
{"good", false, 1 << 10, sessions.State{AccessTokenID: cryptutil.NewBase64Key(), Email: "user@pomerium.io", Expiry: jwt.NewNumericDate(time.Now().Add(10 * time.Minute))}, http.StatusText(http.StatusOK), http.StatusOK},
{"expired", false, 1 << 10, sessions.State{AccessTokenID: cryptutil.NewBase64Key(), Email: "user@pomerium.io", Expiry: jwt.NewNumericDate(time.Now().Add(-10 * time.Minute))}, "internal/sessions: validation failed, token is expired (exp)\n", http.StatusUnauthorized},
{"empty", false, 1 << 10, sessions.State{AccessTokenID: "", Email: "user@pomerium.io", Expiry: jwt.NewNumericDate(time.Now().Add(-10 * time.Minute))}, "internal/sessions: session is not found\n", http.StatusUnauthorized},
{"miss", true, 1 << 10, sessions.State{AccessTokenID: cryptutil.NewBase64Key(), Email: "user@pomerium.io", Expiry: jwt.NewNumericDate(time.Now().Add(10 * time.Minute))}, "internal/sessions: session is not found\n", http.StatusUnauthorized},
{"cache eviction", false, 1, sessions.State{AccessTokenID: cryptutil.NewBase64Key(), Email: "user@pomerium.io", Expiry: jwt.NewNumericDate(time.Now().Add(10 * time.Minute))}, "internal/sessions: session is not found\n", http.StatusUnauthorized},
{"simple good",
&Options{
Cache: &mockCache{},
WrappedStore: &mock.Store{},
Encoder: mock_encoder.Encoder{MarshalResponse: []byte("ok")},
},
&sessions.State{Email: "user@domain.com", User: "user"},
false, false,
http.StatusOK},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
defaultCacheSize = tt.cacheSize
cipher, err := cryptutil.NewAEADCipherFromBase64(cryptutil.NewBase64Key())
encoder := ecjson.New(cipher)
if err != nil {
t.Fatal(err)
}
cs, err := cookie.NewStore(&cookie.Options{Name: t.Name()}, encoder)
if err != nil {
t.Fatal(err)
}
cacheStore := NewStore(encoder, cs, t.Name())
got := NewStore(tt.Options)
r := httptest.NewRequest(http.MethodGet, "/", nil)
q := r.URL.Query()
q.Set(defaultQueryParamKey, tt.state.AccessTokenID)
r.URL.RawQuery = q.Encode()
r.Header.Set("Accept", "application/json")
r := httptest.NewRequest("GET", "/", nil)
w := httptest.NewRecorder()
got := sessions.RetrieveSession(cacheStore)(testAuthorizer((fnh)))
if !tt.skipSave {
cacheStore.SaveSession(w, r, &tt.state)
if err := got.SaveSession(w, r, tt.State); (err != nil) != tt.wantErr {
t.Errorf("NewStore.SaveSession() error = %v, wantErr %v", err, tt.wantErr)
}
for i := 1; i <= 10; i++ {
s := tt.state
s.AccessTokenID = cryptutil.NewBase64Key()
cacheStore.SaveSession(w, r, s)
}
r = httptest.NewRequest("GET", "/", nil)
w = httptest.NewRecorder()
got.ServeHTTP(w, r)
gotBody := w.Body.String()
gotStatus := w.Result().StatusCode
if diff := cmp.Diff(gotBody, tt.wantBody); diff != "" {
t.Errorf("RetrieveSession() = %v", diff)
}
if diff := cmp.Diff(gotStatus, tt.wantStatus); diff != "" {
t.Errorf("RetrieveSession() = %v", diff)
got.ClearSession(w, r)
status := w.Result().StatusCode
if diff := cmp.Diff(status, tt.wantStatus); diff != "" {
t.Errorf("ClearSession() = %v", diff)
}
})
}
}
func TestStore_SaveSession(t *testing.T) {
cipher, err := cryptutil.NewAEADCipherFromBase64(cryptutil.NewBase64Key())
encoder := ecjson.New(cipher)
if err != nil {
t.Fatal(err)
}
cs, err := cookie.NewStore(&cookie.Options{
Name: "_pomerium",
}, encoder)
if err != nil {
t.Fatal(err)
}
tests := []struct {
name string
Options *Options
x interface{}
wantErr bool
}{
{"good", &sessions.State{AccessTokenID: cryptutil.NewBase64Key(), Email: "user@pomerium.io", Expiry: jwt.NewNumericDate(time.Now().Add(10 * time.Minute))}, false},
{"bad type", "bad type!", true},
{"good", &Options{Cache: &mockCache{}, WrappedStore: cs, Encoder: mock_encoder.Encoder{MarshalResponse: []byte("ok")}}, &sessions.State{AccessTokenID: cryptutil.NewBase64Key(), Email: "user@pomerium.io", Expiry: jwt.NewNumericDate(time.Now().Add(10 * time.Minute))}, false},
{"encoder error", &Options{Cache: &mockCache{}, WrappedStore: cs, Encoder: mock_encoder.Encoder{MarshalError: errors.New("err")}}, &sessions.State{AccessTokenID: cryptutil.NewBase64Key(), Email: "user@pomerium.io", Expiry: jwt.NewNumericDate(time.Now().Add(10 * time.Minute))}, true},
{"good", &Options{Cache: &mockCache{}, WrappedStore: &mock.Store{SaveError: errors.New("err")}}, &sessions.State{AccessTokenID: cryptutil.NewBase64Key(), Email: "user@pomerium.io", Expiry: jwt.NewNumericDate(time.Now().Add(10 * time.Minute))}, true},
{"bad type", &Options{Cache: &mockCache{}, WrappedStore: cs, Encoder: mock_encoder.Encoder{MarshalError: errors.New("err")}}, "bad type!", true},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
cipher, err := cryptutil.NewAEADCipherFromBase64(cryptutil.NewBase64Key())
encoder := ecjson.New(cipher)
if err != nil {
t.Fatal(err)
o := tt.Options
if o.WrappedStore == nil {
o.WrappedStore = cs
}
cs, err := cookie.NewStore(&cookie.Options{
Name: "_pomerium",
}, encoder)
if err != nil {
t.Fatal(err)
}
cacheStore := NewStore(encoder, cs, t.Name())
cacheStore := NewStore(tt.Options)
r := httptest.NewRequest(http.MethodGet, "/", nil)
r.Header.Set("Accept", "application/json")
w := httptest.NewRecorder()
if err := cacheStore.SaveSession(w, r, tt.x); (err != nil) != tt.wantErr {
t.Errorf("Store.SaveSession() error = %v, wantErr %v", err, tt.wantErr)
}
})
}
}
func TestStore_LoadSession(t *testing.T) {
key := cryptutil.NewBase64Key()
tests := []struct {
name string
state *sessions.State
cache client.Cacher
encoder encoding.MarshalUnmarshaler
queryParam string
wrappedStore sessions.SessionStore
wantErr bool
}{
{"good",
&sessions.State{AccessTokenID: key, Email: "user@pomerium.io"},
&mockCache{KeyExists: true},
mock_encoder.Encoder{MarshalResponse: []byte("ok")},
defaultOptions.QueryParam,
&mock.Store{Session: &sessions.State{AccessTokenID: key, Email: "user@pomerium.io"}},
false},
{"missing param with key",
&sessions.State{AccessTokenID: key, Email: "user@pomerium.io"},
&mockCache{KeyExists: true},
mock_encoder.Encoder{MarshalResponse: []byte("ok")},
"bad_query",
&mock.Store{Session: &sessions.State{AccessTokenID: key, Email: "user@pomerium.io"}},
true},
{"doesn't exist",
&sessions.State{AccessTokenID: key, Email: "user@pomerium.io"},
&mockCache{KeyExists: false},
mock_encoder.Encoder{MarshalResponse: []byte("ok")},
defaultOptions.QueryParam,
&mock.Store{Session: &sessions.State{AccessTokenID: key, Email: "user@pomerium.io"}},
true},
{"retrieval error",
&sessions.State{AccessTokenID: key, Email: "user@pomerium.io"},
&mockCache{Err: errors.New("err")},
mock_encoder.Encoder{MarshalResponse: []byte("ok")},
defaultOptions.QueryParam,
&mock.Store{Session: &sessions.State{AccessTokenID: key, Email: "user@pomerium.io"}},
true},
{"unmarshal failure",
&sessions.State{AccessTokenID: key, Email: "user@pomerium.io"},
&mockCache{KeyExists: true},
mock_encoder.Encoder{UnmarshalError: errors.New("err")},
defaultOptions.QueryParam,
&mock.Store{Session: &sessions.State{AccessTokenID: key, Email: "user@pomerium.io"}},
true},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
s := &Store{
cache: tt.cache,
encoder: tt.encoder,
queryParam: tt.queryParam,
wrappedStore: tt.wrappedStore,
}
r := httptest.NewRequest(http.MethodGet, "/", nil)
q := r.URL.Query()
q.Set(defaultOptions.QueryParam, tt.state.AccessTokenID)
r.URL.RawQuery = q.Encode()
r.Header.Set("Accept", "application/json")
_, err := s.LoadSession(r)
if (err != nil) != tt.wantErr {
t.Errorf("Store.LoadSession() error = %v, wantErr %v", err, tt.wantErr)
return
}
})
}
}