// Source: pomerium/pkg/storage/redis/stream.go (Go, 128 lines, 2.6 KiB)

package redis
import (
"context"
"fmt"
"sync"
"time"
"github.com/go-redis/redis/v8"
"github.com/golang/protobuf/proto"
"github.com/pomerium/pomerium/internal/log"
"github.com/pomerium/pomerium/pkg/grpc/databroker"
"github.com/pomerium/pomerium/pkg/storage"
)
// recordStream iterates over change records held by a redis Backend. Next
// advances the stream, Record and Err expose the outcome of the last advance,
// and Close releases the stream's change subscription.
type recordStream struct {
	// ctx is handed to every backend and redis call made by Next, and its
	// cancellation ends a blocking Next.
	ctx context.Context
	// backend is the Backend whose change set is being read.
	backend *Backend
	// changed is the channel obtained from backend.onChange.Bind; a receive on
	// it makes a blocking Next re-check for new entries.
	changed chan context.Context

	// serverVersion is the server version the stream was opened against. Next
	// fails with storage.ErrInvalidServerVersion if the backend reports a
	// different one.
	serverVersion uint64
	// recordVersion is the exclusive lower bound used when querying for the
	// next entry; Next increments it for each entry it consumes.
	recordVersion uint64

	// record is the most recently decoded record, returned by Record.
	record *databroker.Record
	// err is the error that stopped the stream, returned by Err.
	err error

	// closeOnce ensures the teardown in Close runs a single time.
	closeOnce sync.Once
	// closed is closed by Close to signal that the stream has been shut down.
	closed chan struct{}
}
// newRecordStream builds a recordStream over backend, starting after
// recordVersion and pinned to serverVersion. It binds a change channel on
// backend.onChange and starts a goroutine that closes the stream when the
// backend is closed; that goroutine exits once the stream itself is closed.
func newRecordStream(ctx context.Context, backend *Backend, serverVersion, recordVersion uint64) *recordStream {
	done := make(chan struct{})
	s := &recordStream{
		ctx:           ctx,
		backend:       backend,
		changed:       backend.onChange.Bind(),
		serverVersion: serverVersion,
		recordVersion: recordVersion,
		closed:        done,
	}

	// Tie the stream's lifetime to the backend's: whichever of the two is
	// closed first lets this goroutine finish.
	go func() {
		select {
		case <-backend.closed:
			_ = s.Close() // Close always returns nil
		case <-done:
			// stream was closed directly; nothing left to do
		}
	}()

	return s
}
// Close unbinds the stream's change channel from the backend and closes the
// stream's closed channel. Only the first call does any work; later calls are
// no-ops. It always returns nil.
func (stream *recordStream) Close() error {
	teardown := func() {
		stream.backend.onChange.Unbind(stream.changed)
		close(stream.closed)
	}
	stream.closeOnce.Do(teardown)
	return nil
}
// Next advances the stream to the next change entry beyond the last one
// consumed and reports whether a record is available from Record.
//
// When block is false, Next returns false as soon as no newer entry is found.
// When block is true, Next waits for a change notification or for the poll
// interval to elapse and then checks again. It stops waiting and returns
// false when the stream's context is done (Err then reports the context
// error) or when the stream has been closed (Err is left unchanged).
//
// Next returns false and records the failure for Err when the server version
// cannot be obtained, when it differs from the version the stream was opened
// with (storage.ErrInvalidServerVersion), or when the redis query fails. Once
// an error has been recorded, every later call returns false immediately.
//
// Entries that cannot be decoded are logged and skipped, so a true result
// always means Record returns a record decoded by this call rather than a
// stale one left over from an earlier call.
func (stream *recordStream) Next(block bool) bool {
	if stream.err != nil {
		return false
	}

	ticker := time.NewTicker(watchPollInterval)
	defer ticker.Stop()

	// changeCtx is only used for logging; it is replaced by the context
	// delivered with a change notification when one is received.
	changeCtx := context.Background()

	for {
		serverVersion, err := stream.backend.getOrCreateServerVersion(stream.ctx)
		if err != nil {
			stream.err = err
			return false
		}
		if stream.serverVersion != serverVersion {
			stream.err = storage.ErrInvalidServerVersion
			return false
		}

		// Ask for at most one entry whose score is strictly greater than
		// recordVersion (the leading "(" makes the lower bound exclusive).
		results, err := stream.backend.client.ZRangeByScore(stream.ctx, changesSetKey, &redis.ZRangeBy{
			Min:    fmt.Sprintf("(%d", stream.recordVersion),
			Max:    "+inf",
			Offset: 0,
			Count:  1,
		}).Result()
		if err != nil {
			stream.err = err
			return false
		}

		if len(results) > 0 {
			var record databroker.Record
			err = proto.Unmarshal([]byte(results[0]), &record)

			// The entry is consumed whether or not it decoded, so the next
			// query moves past it.
			stream.recordVersion++

			if err != nil {
				// Do not report success for an undecodable entry: that would
				// hand the caller the previous record (or nil) a second time.
				// Log it and look for the next entry instead.
				log.Warn(changeCtx).Err(err).Msg("redis: invalid record detected")
				continue
			}

			stream.record = &record
			return true
		}

		if !block {
			return false
		}

		select {
		case <-stream.ctx.Done():
			stream.err = stream.ctx.Err()
			return false
		case <-stream.closed:
			return false
		case <-ticker.C: // poll interval elapsed; check again
		case changeCtx = <-stream.changed: // change notification; check again
		}
	}
}
// Record returns the record most recently stored by Next. It is nil until
// Next has decoded a record.
func (stream *recordStream) Record() *databroker.Record {
	return stream.record
}
// Err returns the error that stopped the stream, or nil if Next has not
// recorded one.
func (stream *recordStream) Err() error {
	return stream.err
}