pomerium/internal/databroker/server.go
Cuong Manh Le 99785cbb5b
internal/databroker: store server version (#1121)
Storing server version when creating new server. After then, we can
retrieve the version from backend when server restart.

With storage backend which supports persistent, the server version
won't change after restarting.
2020-07-22 03:50:22 +07:00

287 lines
7.4 KiB
Go

// Package databroker contains a data broker implementation.
package databroker
import (
"context"
"reflect"
"sort"
"sync"
"time"
"github.com/golang/protobuf/ptypes"
"github.com/golang/protobuf/ptypes/empty"
"github.com/google/uuid"
"github.com/rs/zerolog"
"google.golang.org/grpc/codes"
"google.golang.org/grpc/status"
"google.golang.org/protobuf/types/known/emptypb"
"github.com/pomerium/pomerium/internal/log"
"github.com/pomerium/pomerium/internal/telemetry/trace"
"github.com/pomerium/pomerium/pkg/grpc/databroker"
"github.com/pomerium/pomerium/pkg/storage"
"github.com/pomerium/pomerium/pkg/storage/inmemory"
)
const (
recordTypeServerVersion = "server_version"
serverVersionKey = "version"
)
// Server implements the databroker service using an in memory database.
type Server struct {
version string
cfg *serverConfig
log zerolog.Logger
mu sync.RWMutex
byType map[string]storage.Backend
onchange *Signal
}
// New creates a new server.
func New(options ...ServerOption) *Server {
cfg := newServerConfig(options...)
srv := &Server{
version: uuid.New().String(),
cfg: cfg,
log: log.With().Str("service", "databroker").Logger(),
byType: make(map[string]storage.Backend),
onchange: NewSignal(),
}
srv.initVersion()
go func() {
ticker := time.NewTicker(cfg.deletePermanentlyAfter / 2)
defer ticker.Stop()
for range ticker.C {
var recordTypes []string
srv.mu.RLock()
for recordType := range srv.byType {
recordTypes = append(recordTypes, recordType)
}
srv.mu.RUnlock()
for _, recordType := range recordTypes {
srv.getDB(recordType).ClearDeleted(context.Background(), time.Now().Add(-cfg.deletePermanentlyAfter))
}
}
}()
return srv
}
func (srv *Server) initVersion() {
dbServerVersion := srv.getDB(recordTypeServerVersion)
if dbServerVersion == nil {
return
}
// Get version from storage first.
if r := dbServerVersion.Get(context.Background(), serverVersionKey); r != nil {
var sv databroker.ServerVersion
if err := ptypes.UnmarshalAny(r.GetData(), &sv); err != nil {
srv.log.Debug().Str("server_version", sv.Version).Msg("got db version from DB")
srv.version = sv.Version
}
return
}
data, _ := ptypes.MarshalAny(&databroker.ServerVersion{Version: srv.version})
if err := dbServerVersion.Put(context.Background(), serverVersionKey, data); err != nil {
srv.log.Warn().Err(err).Msg("failed to save server version.")
}
}
// Delete deletes a record from the in-memory list.
func (srv *Server) Delete(ctx context.Context, req *databroker.DeleteRequest) (*empty.Empty, error) {
_, span := trace.StartSpan(ctx, "databroker.grpc.Delete")
defer span.End()
srv.log.Info().
Str("type", req.GetType()).
Str("id", req.GetId()).
Msg("delete")
defer srv.onchange.Broadcast()
if err := srv.getDB(req.GetType()).Delete(ctx, req.GetId()); err != nil {
return nil, err
}
return new(empty.Empty), nil
}
// Get gets a record from the in-memory list.
func (srv *Server) Get(ctx context.Context, req *databroker.GetRequest) (*databroker.GetResponse, error) {
_, span := trace.StartSpan(ctx, "databroker.grpc.Get")
defer span.End()
srv.log.Info().
Str("type", req.GetType()).
Str("id", req.GetId()).
Msg("get")
record := srv.getDB(req.GetType()).Get(ctx, req.GetId())
if record == nil {
return nil, status.Error(codes.NotFound, "record not found")
}
return &databroker.GetResponse{Record: record}, nil
}
// GetAll gets all the records from the in-memory list.
func (srv *Server) GetAll(ctx context.Context, req *databroker.GetAllRequest) (*databroker.GetAllResponse, error) {
_, span := trace.StartSpan(ctx, "databroker.grpc.GetAll")
defer span.End()
srv.log.Info().
Str("type", req.GetType()).
Msg("get all")
records := srv.getDB(req.GetType()).GetAll(ctx)
var recordVersion string
for _, record := range records {
if record.GetVersion() > recordVersion {
recordVersion = record.GetVersion()
}
}
return &databroker.GetAllResponse{
ServerVersion: srv.version,
RecordVersion: recordVersion,
Records: records,
}, nil
}
// Set updates a record in the in-memory list, or adds a new one.
func (srv *Server) Set(ctx context.Context, req *databroker.SetRequest) (*databroker.SetResponse, error) {
_, span := trace.StartSpan(ctx, "databroker.grpc.Set")
defer span.End()
srv.log.Info().
Str("type", req.GetType()).
Str("id", req.GetId()).
Msg("set")
defer srv.onchange.Broadcast()
db := srv.getDB(req.GetType())
if err := db.Put(ctx, req.GetId(), req.GetData()); err != nil {
return nil, err
}
record := db.Get(ctx, req.GetId())
return &databroker.SetResponse{
Record: record,
ServerVersion: srv.version,
}, nil
}
// Sync streams updates for the given record type.
func (srv *Server) Sync(req *databroker.SyncRequest, stream databroker.DataBrokerService_SyncServer) error {
_, span := trace.StartSpan(stream.Context(), "databroker.grpc.Sync")
defer span.End()
srv.log.Info().
Str("type", req.GetType()).
Str("server_version", req.GetServerVersion()).
Str("record_version", req.GetRecordVersion()).
Msg("sync")
recordVersion := req.GetRecordVersion()
// reset record version if the server versions don't match
if req.GetServerVersion() != srv.version {
recordVersion = ""
}
db := srv.getDB(req.GetType())
ch := srv.onchange.Bind()
defer srv.onchange.Unbind(ch)
for {
updated := db.List(context.Background(), recordVersion)
if len(updated) > 0 {
sort.Slice(updated, func(i, j int) bool {
return updated[i].Version < updated[j].Version
})
recordVersion = updated[len(updated)-1].Version
err := stream.Send(&databroker.SyncResponse{
ServerVersion: srv.version,
Records: updated,
})
if err != nil {
return err
}
}
select {
case <-stream.Context().Done():
return stream.Context().Err()
case <-ch:
}
}
}
// GetTypes returns all the known record types.
func (srv *Server) GetTypes(ctx context.Context, _ *emptypb.Empty) (*databroker.GetTypesResponse, error) {
_, span := trace.StartSpan(ctx, "databroker.grpc.GetTypes")
defer span.End()
var recordTypes []string
srv.mu.RLock()
for recordType := range srv.byType {
recordTypes = append(recordTypes, recordType)
}
srv.mu.RUnlock()
sort.Strings(recordTypes)
return &databroker.GetTypesResponse{
Types: recordTypes,
}, nil
}
// SyncTypes synchronizes all the known record types.
func (srv *Server) SyncTypes(req *emptypb.Empty, stream databroker.DataBrokerService_SyncTypesServer) error {
_, span := trace.StartSpan(stream.Context(), "databroker.grpc.SyncTypes")
defer span.End()
srv.log.Info().
Msg("sync types")
ch := srv.onchange.Bind()
defer srv.onchange.Unbind(ch)
var prev []string
for {
res, err := srv.GetTypes(stream.Context(), req)
if err != nil {
return err
}
if prev == nil || !reflect.DeepEqual(prev, res.Types) {
err := stream.Send(res)
if err != nil {
return err
}
prev = res.Types
}
select {
case <-stream.Context().Done():
return stream.Context().Err()
case <-ch:
}
}
}
func (srv *Server) getDB(recordType string) storage.Backend {
// double-checked locking:
// first try the read lock, then re-try with the write lock, and finally create a new db if nil
srv.mu.RLock()
db := srv.byType[recordType]
srv.mu.RUnlock()
if db == nil {
srv.mu.Lock()
db = srv.byType[recordType]
if db == nil {
db = inmemory.NewDB(recordType, srv.cfg.btreeDegree)
srv.byType[recordType] = db
}
srv.mu.Unlock()
}
return db
}