mirror of
https://github.com/pomerium/pomerium.git
synced 2025-08-06 10:21:05 +02:00
internal/databroker: handle new db error (#1129)
Since when we now support other storage, not only memory storage, we need to handle the error when we can't connect to storage.
This commit is contained in:
parent
1640151bc1
commit
1867feb5b9
2 changed files with 50 additions and 19 deletions
|
@ -3,6 +3,7 @@ package databroker
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"context"
|
"context"
|
||||||
|
"fmt"
|
||||||
"reflect"
|
"reflect"
|
||||||
"sort"
|
"sort"
|
||||||
"sync"
|
"sync"
|
||||||
|
@ -66,7 +67,11 @@ func New(options ...ServerOption) *Server {
|
||||||
srv.mu.RUnlock()
|
srv.mu.RUnlock()
|
||||||
|
|
||||||
for _, recordType := range recordTypes {
|
for _, recordType := range recordTypes {
|
||||||
srv.getDB(recordType).ClearDeleted(context.Background(), time.Now().Add(-cfg.deletePermanentlyAfter))
|
db, err := srv.getDB(recordType)
|
||||||
|
if err != nil {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
db.ClearDeleted(context.Background(), time.Now().Add(-cfg.deletePermanentlyAfter))
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}()
|
}()
|
||||||
|
@ -74,8 +79,9 @@ func New(options ...ServerOption) *Server {
|
||||||
}
|
}
|
||||||
|
|
||||||
func (srv *Server) initVersion() {
|
func (srv *Server) initVersion() {
|
||||||
dbServerVersion := srv.getDB(recordTypeServerVersion)
|
dbServerVersion, err := srv.getDB(recordTypeServerVersion)
|
||||||
if dbServerVersion == nil {
|
if err != nil {
|
||||||
|
log.Error().Err(err).Msg("failed to init server version")
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -106,7 +112,12 @@ func (srv *Server) Delete(ctx context.Context, req *databroker.DeleteRequest) (*
|
||||||
|
|
||||||
defer srv.onchange.Broadcast()
|
defer srv.onchange.Broadcast()
|
||||||
|
|
||||||
if err := srv.getDB(req.GetType()).Delete(ctx, req.GetId()); err != nil {
|
db, err := srv.getDB(req.GetType())
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := db.Delete(ctx, req.GetId()); err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -122,7 +133,11 @@ func (srv *Server) Get(ctx context.Context, req *databroker.GetRequest) (*databr
|
||||||
Str("id", req.GetId()).
|
Str("id", req.GetId()).
|
||||||
Msg("get")
|
Msg("get")
|
||||||
|
|
||||||
record := srv.getDB(req.GetType()).Get(ctx, req.GetId())
|
db, err := srv.getDB(req.GetType())
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
record := db.Get(ctx, req.GetId())
|
||||||
if record == nil {
|
if record == nil {
|
||||||
return nil, status.Error(codes.NotFound, "record not found")
|
return nil, status.Error(codes.NotFound, "record not found")
|
||||||
}
|
}
|
||||||
|
@ -137,7 +152,11 @@ func (srv *Server) GetAll(ctx context.Context, req *databroker.GetAllRequest) (*
|
||||||
Str("type", req.GetType()).
|
Str("type", req.GetType()).
|
||||||
Msg("get all")
|
Msg("get all")
|
||||||
|
|
||||||
records := srv.getDB(req.GetType()).GetAll(ctx)
|
db, err := srv.getDB(req.GetType())
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
records := db.GetAll(ctx)
|
||||||
var recordVersion string
|
var recordVersion string
|
||||||
for _, record := range records {
|
for _, record := range records {
|
||||||
if record.GetVersion() > recordVersion {
|
if record.GetVersion() > recordVersion {
|
||||||
|
@ -162,7 +181,10 @@ func (srv *Server) Set(ctx context.Context, req *databroker.SetRequest) (*databr
|
||||||
|
|
||||||
defer srv.onchange.Broadcast()
|
defer srv.onchange.Broadcast()
|
||||||
|
|
||||||
db := srv.getDB(req.GetType())
|
db, err := srv.getDB(req.GetType())
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
if err := db.Put(ctx, req.GetId(), req.GetData()); err != nil {
|
if err := db.Put(ctx, req.GetId(), req.GetData()); err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
@ -190,7 +212,10 @@ func (srv *Server) Sync(req *databroker.SyncRequest, stream databroker.DataBroke
|
||||||
recordVersion = ""
|
recordVersion = ""
|
||||||
}
|
}
|
||||||
|
|
||||||
db := srv.getDB(req.GetType())
|
db, err := srv.getDB(req.GetType())
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
ch := srv.onchange.Bind()
|
ch := srv.onchange.Bind()
|
||||||
defer srv.onchange.Unbind(ch)
|
defer srv.onchange.Unbind(ch)
|
||||||
|
@ -269,7 +294,7 @@ func (srv *Server) SyncTypes(req *emptypb.Empty, stream databroker.DataBrokerSer
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func (srv *Server) getDB(recordType string) storage.Backend {
|
func (srv *Server) getDB(recordType string) (storage.Backend, error) {
|
||||||
// double-checked locking:
|
// double-checked locking:
|
||||||
// first try the read lock, then re-try with the write lock, and finally create a new db if nil
|
// first try the read lock, then re-try with the write lock, and finally create a new db if nil
|
||||||
srv.mu.RLock()
|
srv.mu.RLock()
|
||||||
|
@ -278,27 +303,30 @@ func (srv *Server) getDB(recordType string) storage.Backend {
|
||||||
if db == nil {
|
if db == nil {
|
||||||
srv.mu.Lock()
|
srv.mu.Lock()
|
||||||
db = srv.byType[recordType]
|
db = srv.byType[recordType]
|
||||||
|
var err error
|
||||||
if db == nil {
|
if db == nil {
|
||||||
db = srv.newDB(recordType)
|
db, err = srv.newDB(recordType)
|
||||||
srv.byType[recordType] = db
|
srv.byType[recordType] = db
|
||||||
}
|
}
|
||||||
srv.mu.Unlock()
|
srv.mu.Unlock()
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
}
|
}
|
||||||
return db
|
}
|
||||||
|
return db, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (srv *Server) newDB(recordType string) storage.Backend {
|
func (srv *Server) newDB(recordType string) (storage.Backend, error) {
|
||||||
switch srv.cfg.storageType {
|
switch srv.cfg.storageType {
|
||||||
case inmemory.Name:
|
case inmemory.Name:
|
||||||
return inmemory.NewDB(recordType, srv.cfg.btreeDegree)
|
return inmemory.NewDB(recordType, srv.cfg.btreeDegree), nil
|
||||||
case redis.Name:
|
case redis.Name:
|
||||||
db, err := redis.New(srv.cfg.storageConnectionString, recordType, int64(srv.cfg.deletePermanentlyAfter.Seconds()))
|
db, err := redis.New(srv.cfg.storageConnectionString, recordType, int64(srv.cfg.deletePermanentlyAfter.Seconds()))
|
||||||
if err != nil {
|
if err != nil {
|
||||||
srv.log.Error().Err(err).Msg("failed to create new redis storage")
|
return nil, fmt.Errorf("failed to create new redis storage: %w", err)
|
||||||
return nil
|
|
||||||
}
|
}
|
||||||
return db
|
return db, nil
|
||||||
default:
|
default:
|
||||||
return nil
|
return nil, fmt.Errorf("unsupported storage type: %s", srv.cfg.storageType)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
|
@ -7,6 +7,7 @@ import (
|
||||||
"github.com/golang/protobuf/ptypes"
|
"github.com/golang/protobuf/ptypes"
|
||||||
"github.com/google/uuid"
|
"github.com/google/uuid"
|
||||||
"github.com/stretchr/testify/assert"
|
"github.com/stretchr/testify/assert"
|
||||||
|
"github.com/stretchr/testify/require"
|
||||||
|
|
||||||
"github.com/pomerium/pomerium/internal/log"
|
"github.com/pomerium/pomerium/internal/log"
|
||||||
"github.com/pomerium/pomerium/pkg/grpc/databroker"
|
"github.com/pomerium/pomerium/pkg/grpc/databroker"
|
||||||
|
@ -37,7 +38,8 @@ func TestServer_initVersion(t *testing.T) {
|
||||||
t.Run("new server with random version", func(t *testing.T) {
|
t.Run("new server with random version", func(t *testing.T) {
|
||||||
srv := newServer(cfg)
|
srv := newServer(cfg)
|
||||||
ctx := context.Background()
|
ctx := context.Background()
|
||||||
db := srv.getDB(recordTypeServerVersion)
|
db, err := srv.getDB(recordTypeServerVersion)
|
||||||
|
require.NoError(t, err)
|
||||||
r := db.Get(ctx, serverVersionKey)
|
r := db.Get(ctx, serverVersionKey)
|
||||||
assert.Nil(t, r)
|
assert.Nil(t, r)
|
||||||
srvVersion := uuid.New().String()
|
srvVersion := uuid.New().String()
|
||||||
|
@ -53,7 +55,8 @@ func TestServer_initVersion(t *testing.T) {
|
||||||
t.Run("init version twice should get the same version", func(t *testing.T) {
|
t.Run("init version twice should get the same version", func(t *testing.T) {
|
||||||
srv := newServer(cfg)
|
srv := newServer(cfg)
|
||||||
ctx := context.Background()
|
ctx := context.Background()
|
||||||
db := srv.getDB(recordTypeServerVersion)
|
db, err := srv.getDB(recordTypeServerVersion)
|
||||||
|
require.NoError(t, err)
|
||||||
r := db.Get(ctx, serverVersionKey)
|
r := db.Get(ctx, serverVersionKey)
|
||||||
assert.Nil(t, r)
|
assert.Nil(t, r)
|
||||||
|
|
||||||
|
|
Loading…
Add table
Add a link
Reference in a new issue