mirror of
https://github.com/pomerium/pomerium.git
synced 2025-04-29 18:36:30 +02:00
128 lines
2.6 KiB
Go
128 lines
2.6 KiB
Go
package redis
|
|
|
|
import (
|
|
"context"
|
|
"fmt"
|
|
"sync"
|
|
"time"
|
|
|
|
"github.com/go-redis/redis/v8"
|
|
"github.com/golang/protobuf/proto"
|
|
|
|
"github.com/pomerium/pomerium/internal/log"
|
|
"github.com/pomerium/pomerium/pkg/grpc/databroker"
|
|
"github.com/pomerium/pomerium/pkg/storage"
|
|
)
|
|
|
|
type recordStream struct {
|
|
ctx context.Context
|
|
backend *Backend
|
|
|
|
changed chan context.Context
|
|
serverVersion uint64
|
|
recordVersion uint64
|
|
record *databroker.Record
|
|
err error
|
|
|
|
closeOnce sync.Once
|
|
closed chan struct{}
|
|
}
|
|
|
|
func newRecordStream(ctx context.Context, backend *Backend, serverVersion, recordVersion uint64) *recordStream {
|
|
stream := &recordStream{
|
|
ctx: ctx,
|
|
backend: backend,
|
|
|
|
changed: backend.onChange.Bind(),
|
|
serverVersion: serverVersion,
|
|
recordVersion: recordVersion,
|
|
|
|
closed: make(chan struct{}),
|
|
}
|
|
// if the backend is closed, close the stream
|
|
go func() {
|
|
select {
|
|
case <-stream.closed:
|
|
case <-backend.closed:
|
|
_ = stream.Close()
|
|
}
|
|
}()
|
|
return stream
|
|
}
|
|
|
|
func (stream *recordStream) Close() error {
|
|
stream.closeOnce.Do(func() {
|
|
stream.backend.onChange.Unbind(stream.changed)
|
|
close(stream.closed)
|
|
})
|
|
return nil
|
|
}
|
|
|
|
func (stream *recordStream) Next(block bool) bool {
|
|
if stream.err != nil {
|
|
return false
|
|
}
|
|
|
|
ticker := time.NewTicker(watchPollInterval)
|
|
defer ticker.Stop()
|
|
|
|
changeCtx := context.Background()
|
|
for {
|
|
serverVersion, err := stream.backend.getOrCreateServerVersion(stream.ctx)
|
|
if err != nil {
|
|
stream.err = err
|
|
return false
|
|
}
|
|
if stream.serverVersion != serverVersion {
|
|
stream.err = storage.ErrInvalidServerVersion
|
|
return false
|
|
}
|
|
|
|
cmd := stream.backend.client.ZRangeByScore(stream.ctx, changesSetKey, &redis.ZRangeBy{
|
|
Min: fmt.Sprintf("(%d", stream.recordVersion),
|
|
Max: "+inf",
|
|
Offset: 0,
|
|
Count: 1,
|
|
})
|
|
results, err := cmd.Result()
|
|
if err != nil {
|
|
stream.err = err
|
|
return false
|
|
}
|
|
|
|
if len(results) > 0 {
|
|
result := results[0]
|
|
var record databroker.Record
|
|
err = proto.Unmarshal([]byte(result), &record)
|
|
if err != nil {
|
|
log.Warn(changeCtx).Err(err).Msg("redis: invalid record detected")
|
|
} else {
|
|
stream.record = &record
|
|
}
|
|
stream.recordVersion++
|
|
return true
|
|
}
|
|
|
|
if block {
|
|
select {
|
|
case <-stream.ctx.Done():
|
|
stream.err = stream.ctx.Err()
|
|
return false
|
|
case <-stream.closed:
|
|
return false
|
|
case <-ticker.C: // check again
|
|
case changeCtx = <-stream.changed: // check again
|
|
}
|
|
} else {
|
|
return false
|
|
}
|
|
}
|
|
}
|
|
|
|
func (stream *recordStream) Record() *databroker.Record {
|
|
return stream.record
|
|
}
|
|
|
|
func (stream *recordStream) Err() error {
|
|
return stream.err
|
|
}
|