mirror of
https://github.com/pomerium/pomerium.git
synced 2025-05-02 03:46:29 +02:00
172 lines
4.1 KiB
Go
172 lines
4.1 KiB
Go
package redis
|
|
|
|
import (
|
|
"context"
|
|
"errors"
|
|
"fmt"
|
|
"time"
|
|
|
|
"github.com/go-redis/redis/v8"
|
|
"google.golang.org/protobuf/proto"
|
|
|
|
"github.com/pomerium/pomerium/internal/log"
|
|
"github.com/pomerium/pomerium/pkg/grpc/databroker"
|
|
"github.com/pomerium/pomerium/pkg/storage"
|
|
)
|
|
|
|
func newSyncRecordStream(
|
|
ctx context.Context,
|
|
backend *Backend,
|
|
recordType string,
|
|
serverVersion uint64,
|
|
recordVersion uint64,
|
|
) storage.RecordStream {
|
|
changed := backend.onChange.Bind()
|
|
return storage.NewRecordStream(ctx, backend.closed, []storage.RecordStreamGenerator{
|
|
// 1. stream all record changes
|
|
func(ctx context.Context, block bool) (*databroker.Record, error) {
|
|
ticker := time.NewTicker(watchPollInterval)
|
|
defer ticker.Stop()
|
|
|
|
for {
|
|
currentServerVersion, err := backend.getOrCreateServerVersion(ctx)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
if serverVersion != currentServerVersion {
|
|
return nil, storage.ErrInvalidServerVersion
|
|
}
|
|
|
|
record, err := nextChangedRecord(ctx, backend, recordType, &recordVersion)
|
|
if err == nil {
|
|
return record, nil
|
|
} else if !errors.Is(err, storage.ErrStreamDone) {
|
|
return nil, err
|
|
}
|
|
|
|
if !block {
|
|
return nil, storage.ErrStreamDone
|
|
}
|
|
|
|
select {
|
|
case <-ctx.Done():
|
|
return nil, ctx.Err()
|
|
case <-ticker.C:
|
|
case <-changed:
|
|
}
|
|
}
|
|
},
|
|
}, func() {
|
|
backend.onChange.Unbind(changed)
|
|
})
|
|
}
|
|
|
|
func newSyncLatestRecordStream(
|
|
ctx context.Context,
|
|
backend *Backend,
|
|
recordType string,
|
|
expr storage.FilterExpression,
|
|
) (storage.RecordStream, error) {
|
|
filter, err := storage.RecordStreamFilterFromFilterExpression(expr)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
if recordType != "" {
|
|
filter = filter.And(func(record *databroker.Record) (keep bool) {
|
|
return record.GetType() == recordType
|
|
})
|
|
}
|
|
|
|
var cursor uint64
|
|
scannedOnce := false
|
|
var scannedRecords []*databroker.Record
|
|
generator := storage.FilteredRecordStreamGenerator(
|
|
func(ctx context.Context, block bool) (*databroker.Record, error) {
|
|
for {
|
|
if len(scannedRecords) > 0 {
|
|
record := scannedRecords[0]
|
|
scannedRecords = scannedRecords[1:]
|
|
return record, nil
|
|
}
|
|
|
|
// the cursor is reset to 0 after iteration is complete
|
|
if scannedOnce && cursor == 0 {
|
|
return nil, storage.ErrStreamDone
|
|
}
|
|
|
|
var err error
|
|
scannedRecords, err = nextScannedRecords(ctx, backend, &cursor)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
scannedOnce = true
|
|
}
|
|
},
|
|
filter,
|
|
)
|
|
|
|
return storage.NewRecordStream(ctx, backend.closed, []storage.RecordStreamGenerator{
|
|
generator,
|
|
}, nil), nil
|
|
}
|
|
|
|
func nextScannedRecords(ctx context.Context, backend *Backend, cursor *uint64) ([]*databroker.Record, error) {
|
|
var values []string
|
|
var err error
|
|
values, *cursor, err = backend.client.HScan(ctx, recordHashKey, *cursor, "", 0).Result()
|
|
if errors.Is(err, redis.Nil) {
|
|
return nil, storage.ErrStreamDone
|
|
} else if err != nil {
|
|
return nil, err
|
|
} else if len(values) == 0 {
|
|
return nil, storage.ErrStreamDone
|
|
}
|
|
|
|
var records []*databroker.Record
|
|
for i := 1; i < len(values); i += 2 {
|
|
var record databroker.Record
|
|
err := proto.Unmarshal([]byte(values[i]), &record)
|
|
if err != nil {
|
|
log.Warn(ctx).Err(err).Msg("redis: invalid record detected")
|
|
continue
|
|
}
|
|
records = append(records, &record)
|
|
}
|
|
return records, nil
|
|
}
|
|
|
|
func nextChangedRecord(ctx context.Context, backend *Backend, recordType string, recordVersion *uint64) (*databroker.Record, error) {
|
|
for {
|
|
cmd := backend.client.ZRangeByScore(ctx, changesSetKey, &redis.ZRangeBy{
|
|
Min: fmt.Sprintf("(%d", *recordVersion),
|
|
Max: "+inf",
|
|
Offset: 0,
|
|
Count: 1,
|
|
})
|
|
results, err := cmd.Result()
|
|
if errors.Is(err, redis.Nil) {
|
|
return nil, storage.ErrStreamDone
|
|
} else if err != nil {
|
|
return nil, err
|
|
} else if len(results) == 0 {
|
|
return nil, storage.ErrStreamDone
|
|
}
|
|
|
|
result := results[0]
|
|
var record databroker.Record
|
|
err = proto.Unmarshal([]byte(result), &record)
|
|
if err != nil {
|
|
log.Warn(ctx).Err(err).Msg("redis: invalid record detected")
|
|
*recordVersion++
|
|
continue
|
|
}
|
|
|
|
*recordVersion = record.GetVersion()
|
|
if recordType != "" && record.GetType() != recordType {
|
|
continue
|
|
}
|
|
|
|
return &record, nil
|
|
}
|
|
}
|