mirror of
https://github.com/pomerium/pomerium.git
synced 2025-04-29 18:36:30 +02:00
Since a storage backend such as Redis can fail in many ways, the interface should return an error for the caller to handle.
336 lines
8.4 KiB
Go
336 lines
8.4 KiB
Go
// Package databroker contains a data broker implementation.
|
|
package databroker
|
|
|
|
import (
|
|
"context"
|
|
"fmt"
|
|
"reflect"
|
|
"sort"
|
|
"sync"
|
|
"time"
|
|
|
|
"github.com/golang/protobuf/ptypes"
|
|
"github.com/golang/protobuf/ptypes/empty"
|
|
"github.com/google/uuid"
|
|
"github.com/rs/zerolog"
|
|
"google.golang.org/grpc/codes"
|
|
"google.golang.org/grpc/status"
|
|
"google.golang.org/protobuf/types/known/emptypb"
|
|
|
|
"github.com/pomerium/pomerium/internal/log"
|
|
"github.com/pomerium/pomerium/internal/telemetry/trace"
|
|
"github.com/pomerium/pomerium/pkg/grpc/databroker"
|
|
"github.com/pomerium/pomerium/pkg/storage"
|
|
"github.com/pomerium/pomerium/pkg/storage/inmemory"
|
|
"github.com/pomerium/pomerium/pkg/storage/redis"
|
|
)
|
|
|
|
// Identifiers under which the server's own version is persisted.
const (
	// recordTypeServerVersion is the record type whose backend holds the
	// persisted server version.
	recordTypeServerVersion = "server_version"

	// serverVersionKey is the record id of the persisted server version
	// inside the recordTypeServerVersion backend.
	serverVersionKey = "version"
)
|
|
|
|
// Server implements the databroker service using an in memory database.
|
|
type Server struct {
|
|
version string
|
|
cfg *serverConfig
|
|
log zerolog.Logger
|
|
|
|
mu sync.RWMutex
|
|
byType map[string]storage.Backend
|
|
onchange *Signal
|
|
}
|
|
|
|
// New creates a new server.
|
|
func New(options ...ServerOption) *Server {
|
|
cfg := newServerConfig(options...)
|
|
srv := &Server{
|
|
version: uuid.New().String(),
|
|
cfg: cfg,
|
|
log: log.With().Str("service", "databroker").Logger(),
|
|
|
|
byType: make(map[string]storage.Backend),
|
|
onchange: NewSignal(),
|
|
}
|
|
srv.initVersion()
|
|
|
|
go func() {
|
|
ticker := time.NewTicker(cfg.deletePermanentlyAfter / 2)
|
|
defer ticker.Stop()
|
|
|
|
for range ticker.C {
|
|
var recordTypes []string
|
|
srv.mu.RLock()
|
|
for recordType := range srv.byType {
|
|
recordTypes = append(recordTypes, recordType)
|
|
}
|
|
srv.mu.RUnlock()
|
|
|
|
for _, recordType := range recordTypes {
|
|
db, err := srv.getDB(recordType)
|
|
if err != nil {
|
|
continue
|
|
}
|
|
db.ClearDeleted(context.Background(), time.Now().Add(-cfg.deletePermanentlyAfter))
|
|
}
|
|
}
|
|
}()
|
|
return srv
|
|
}
|
|
|
|
func (srv *Server) initVersion() {
|
|
dbServerVersion, err := srv.getDB(recordTypeServerVersion)
|
|
if err != nil {
|
|
log.Error().Err(err).Msg("failed to init server version")
|
|
return
|
|
}
|
|
|
|
// Get version from storage first.
|
|
if r, _ := dbServerVersion.Get(context.Background(), serverVersionKey); r != nil {
|
|
var sv databroker.ServerVersion
|
|
if err := ptypes.UnmarshalAny(r.GetData(), &sv); err == nil {
|
|
srv.log.Debug().Str("server_version", sv.Version).Msg("got db version from DB")
|
|
srv.version = sv.Version
|
|
}
|
|
return
|
|
}
|
|
|
|
data, _ := ptypes.MarshalAny(&databroker.ServerVersion{Version: srv.version})
|
|
if err := dbServerVersion.Put(context.Background(), serverVersionKey, data); err != nil {
|
|
srv.log.Warn().Err(err).Msg("failed to save server version.")
|
|
}
|
|
}
|
|
|
|
// Delete deletes a record from the in-memory list.
|
|
func (srv *Server) Delete(ctx context.Context, req *databroker.DeleteRequest) (*empty.Empty, error) {
|
|
_, span := trace.StartSpan(ctx, "databroker.grpc.Delete")
|
|
defer span.End()
|
|
srv.log.Info().
|
|
Str("type", req.GetType()).
|
|
Str("id", req.GetId()).
|
|
Msg("delete")
|
|
|
|
defer srv.onchange.Broadcast()
|
|
|
|
db, err := srv.getDB(req.GetType())
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
if err := db.Delete(ctx, req.GetId()); err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
return new(empty.Empty), nil
|
|
}
|
|
|
|
// Get gets a record from the in-memory list.
|
|
func (srv *Server) Get(ctx context.Context, req *databroker.GetRequest) (*databroker.GetResponse, error) {
|
|
_, span := trace.StartSpan(ctx, "databroker.grpc.Get")
|
|
defer span.End()
|
|
srv.log.Info().
|
|
Str("type", req.GetType()).
|
|
Str("id", req.GetId()).
|
|
Msg("get")
|
|
|
|
db, err := srv.getDB(req.GetType())
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
record, err := db.Get(ctx, req.GetId())
|
|
if err != nil {
|
|
return nil, status.Error(codes.NotFound, "record not found")
|
|
}
|
|
return &databroker.GetResponse{Record: record}, nil
|
|
}
|
|
|
|
// GetAll gets all the records from the in-memory list.
|
|
func (srv *Server) GetAll(ctx context.Context, req *databroker.GetAllRequest) (*databroker.GetAllResponse, error) {
|
|
_, span := trace.StartSpan(ctx, "databroker.grpc.GetAll")
|
|
defer span.End()
|
|
srv.log.Info().
|
|
Str("type", req.GetType()).
|
|
Msg("get all")
|
|
|
|
db, err := srv.getDB(req.GetType())
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
records, err := db.GetAll(ctx)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
var recordVersion string
|
|
for _, record := range records {
|
|
if record.GetVersion() > recordVersion {
|
|
recordVersion = record.GetVersion()
|
|
}
|
|
}
|
|
return &databroker.GetAllResponse{
|
|
ServerVersion: srv.version,
|
|
RecordVersion: recordVersion,
|
|
Records: records,
|
|
}, nil
|
|
}
|
|
|
|
// Set updates a record in the in-memory list, or adds a new one.
|
|
func (srv *Server) Set(ctx context.Context, req *databroker.SetRequest) (*databroker.SetResponse, error) {
|
|
_, span := trace.StartSpan(ctx, "databroker.grpc.Set")
|
|
defer span.End()
|
|
srv.log.Info().
|
|
Str("type", req.GetType()).
|
|
Str("id", req.GetId()).
|
|
Msg("set")
|
|
|
|
defer srv.onchange.Broadcast()
|
|
|
|
db, err := srv.getDB(req.GetType())
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
if err := db.Put(ctx, req.GetId(), req.GetData()); err != nil {
|
|
return nil, err
|
|
}
|
|
record, err := db.Get(ctx, req.GetId())
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
return &databroker.SetResponse{
|
|
Record: record,
|
|
ServerVersion: srv.version,
|
|
}, nil
|
|
}
|
|
|
|
// Sync streams updates for the given record type.
|
|
func (srv *Server) Sync(req *databroker.SyncRequest, stream databroker.DataBrokerService_SyncServer) error {
|
|
_, span := trace.StartSpan(stream.Context(), "databroker.grpc.Sync")
|
|
defer span.End()
|
|
srv.log.Info().
|
|
Str("type", req.GetType()).
|
|
Str("server_version", req.GetServerVersion()).
|
|
Str("record_version", req.GetRecordVersion()).
|
|
Msg("sync")
|
|
|
|
recordVersion := req.GetRecordVersion()
|
|
// reset record version if the server versions don't match
|
|
if req.GetServerVersion() != srv.version {
|
|
recordVersion = ""
|
|
}
|
|
|
|
db, err := srv.getDB(req.GetType())
|
|
if err != nil {
|
|
return err
|
|
}
|
|
|
|
ch := srv.onchange.Bind()
|
|
defer srv.onchange.Unbind(ch)
|
|
for {
|
|
updated, _ := db.List(context.Background(), recordVersion)
|
|
if len(updated) > 0 {
|
|
sort.Slice(updated, func(i, j int) bool {
|
|
return updated[i].Version < updated[j].Version
|
|
})
|
|
recordVersion = updated[len(updated)-1].Version
|
|
err := stream.Send(&databroker.SyncResponse{
|
|
ServerVersion: srv.version,
|
|
Records: updated,
|
|
})
|
|
if err != nil {
|
|
return err
|
|
}
|
|
}
|
|
|
|
select {
|
|
case <-stream.Context().Done():
|
|
return stream.Context().Err()
|
|
case <-ch:
|
|
}
|
|
}
|
|
}
|
|
|
|
// GetTypes returns all the known record types.
|
|
func (srv *Server) GetTypes(ctx context.Context, _ *emptypb.Empty) (*databroker.GetTypesResponse, error) {
|
|
_, span := trace.StartSpan(ctx, "databroker.grpc.GetTypes")
|
|
defer span.End()
|
|
var recordTypes []string
|
|
srv.mu.RLock()
|
|
for recordType := range srv.byType {
|
|
recordTypes = append(recordTypes, recordType)
|
|
}
|
|
srv.mu.RUnlock()
|
|
|
|
sort.Strings(recordTypes)
|
|
return &databroker.GetTypesResponse{
|
|
Types: recordTypes,
|
|
}, nil
|
|
}
|
|
|
|
// SyncTypes synchronizes all the known record types.
|
|
func (srv *Server) SyncTypes(req *emptypb.Empty, stream databroker.DataBrokerService_SyncTypesServer) error {
|
|
_, span := trace.StartSpan(stream.Context(), "databroker.grpc.SyncTypes")
|
|
defer span.End()
|
|
srv.log.Info().
|
|
Msg("sync types")
|
|
|
|
ch := srv.onchange.Bind()
|
|
defer srv.onchange.Unbind(ch)
|
|
|
|
var prev []string
|
|
for {
|
|
res, err := srv.GetTypes(stream.Context(), req)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
|
|
if prev == nil || !reflect.DeepEqual(prev, res.Types) {
|
|
err := stream.Send(res)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
prev = res.Types
|
|
}
|
|
|
|
select {
|
|
case <-stream.Context().Done():
|
|
return stream.Context().Err()
|
|
case <-ch:
|
|
}
|
|
}
|
|
}
|
|
|
|
func (srv *Server) getDB(recordType string) (storage.Backend, error) {
|
|
// double-checked locking:
|
|
// first try the read lock, then re-try with the write lock, and finally create a new db if nil
|
|
srv.mu.RLock()
|
|
db := srv.byType[recordType]
|
|
srv.mu.RUnlock()
|
|
if db == nil {
|
|
srv.mu.Lock()
|
|
db = srv.byType[recordType]
|
|
var err error
|
|
if db == nil {
|
|
db, err = srv.newDB(recordType)
|
|
srv.byType[recordType] = db
|
|
}
|
|
srv.mu.Unlock()
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
}
|
|
return db, nil
|
|
}
|
|
|
|
func (srv *Server) newDB(recordType string) (storage.Backend, error) {
|
|
switch srv.cfg.storageType {
|
|
case inmemory.Name:
|
|
return inmemory.NewDB(recordType, srv.cfg.btreeDegree), nil
|
|
case redis.Name:
|
|
db, err := redis.New(srv.cfg.storageConnectionString, recordType, int64(srv.cfg.deletePermanentlyAfter.Seconds()))
|
|
if err != nil {
|
|
return nil, fmt.Errorf("failed to create new redis storage: %w", err)
|
|
}
|
|
return db, nil
|
|
default:
|
|
return nil, fmt.Errorf("unsupported storage type: %s", srv.cfg.storageType)
|
|
}
|
|
}
|