remove context, add close

This commit is contained in:
Caleb Doxsey 2025-02-25 09:14:54 -07:00
parent ff127e61f9
commit 55fb69b1e7
10 changed files with 91 additions and 50 deletions

View file

@ -105,11 +105,12 @@ func (a *Authorize) loadSession(
}
// invalidate cache
for _, record := range records {
storage.GetQuerier(ctx).InvalidateCache(ctx, &databroker.QueryRequest{
q := &databroker.QueryRequest{
Type: record.GetType(),
Query: record.GetId(),
Limit: 1,
})
}
q.SetFilterByIDOrIndex(record.GetId())
storage.GetQuerier(ctx).InvalidateCache(ctx, q)
}
return nil
},

View file

@ -136,12 +136,13 @@ func NewFileOrEnvironmentSource(
watcher: fileutil.NewWatcher(),
config: cfg,
}
context.AfterFunc(ctx, func() { src.watcher.Close() })
if configFile != "" {
if cfg.Options.IsRuntimeFlagSet(RuntimeFlagConfigHotReload) {
src.watcher.Watch(ctx, []string{configFile})
src.watcher.Watch([]string{configFile})
} else {
log.Ctx(ctx).Info().Msg("hot reload disabled")
src.watcher.Watch(ctx, nil)
src.watcher.Watch(nil)
}
}
ch := src.watcher.Bind()
@ -215,6 +216,7 @@ func NewFileWatcherSource(ctx context.Context, underlying Source) *FileWatcherSo
watcher: fileutil.NewWatcher(),
cfg: cfg,
}
context.AfterFunc(ctx, func() { src.watcher.Close() })
ch := src.watcher.Bind()
go func() {
@ -241,9 +243,9 @@ func (src *FileWatcherSource) GetConfig() *Config {
func (src *FileWatcherSource) onConfigChange(ctx context.Context, cfg *Config) {
// update the file watcher with paths from the config
if cfg.Options.IsRuntimeFlagSet(RuntimeFlagConfigHotReload) {
src.watcher.Watch(ctx, getAllConfigFilePaths(cfg))
src.watcher.Watch(getAllConfigFilePaths(cfg))
} else {
src.watcher.Watch(ctx, nil)
src.watcher.Watch(nil)
}
src.mu.Lock()

View file

@ -301,6 +301,7 @@ func (c *incomingIDPTokenSessionCreator) newSessionFromIDPClaims(
if aud, ok := claims.GetAudience(); ok {
s.Audience = aud
}
s.RefreshDisabled = true
return s
}

View file

@ -371,10 +371,11 @@ func Test_newSessionFromIDPClaims(t *testing.T) {
"empty claims", "S1",
nil,
&session.Session{
Id: "S1",
AccessedAt: timestamppb.New(tm1),
ExpiresAt: timestamppb.New(tm1.Add(time.Hour * 14)),
IssuedAt: timestamppb.New(tm1),
Id: "S1",
AccessedAt: timestamppb.New(tm1),
ExpiresAt: timestamppb.New(tm1.Add(time.Hour * 14)),
IssuedAt: timestamppb.New(tm1),
RefreshDisabled: true,
},
},
{
@ -398,6 +399,7 @@ func Test_newSessionFromIDPClaims(t *testing.T) {
"iat": {tm2.Unix()},
"exp": {tm3.Unix()},
}.ToPB(),
RefreshDisabled: true,
},
},
} {
@ -490,6 +492,7 @@ func TestIncomingIDPTokenSessionCreator_CreateSession(t *testing.T) {
assert.NoError(t, err)
assert.Equal(t, "U1", s.GetUserId())
assert.Equal(t, "ACCESS_TOKEN", s.GetOauthToken().GetAccessToken())
assert.True(t, s.GetRefreshDisabled())
})
t.Run("identity_token", func(t *testing.T) {
t.Parallel()
@ -530,5 +533,6 @@ func TestIncomingIDPTokenSessionCreator_CreateSession(t *testing.T) {
s, err := c.CreateSession(ctx, cfg, route, req)
assert.NoError(t, err)
assert.Equal(t, "U1", s.GetUserId())
assert.True(t, s.GetRefreshDisabled())
})
}

View file

@ -27,12 +27,26 @@ func NewWatcher() *Watcher {
}
}
// Watch updates the watched file paths.
func (watcher *Watcher) Watch(ctx context.Context, filePaths []string) {
// Close closes the watcher.
func (watcher *Watcher) Close() error {
watcher.mu.Lock()
defer watcher.mu.Unlock()
watcher.initLocked(ctx)
var err error
if watcher.pollingWatcher != nil {
err = watcher.pollingWatcher.Close()
watcher.pollingWatcher = nil
}
return err
}
// Watch updates the watched file paths.
func (watcher *Watcher) Watch(filePaths []string) {
watcher.mu.Lock()
defer watcher.mu.Unlock()
watcher.initLocked()
var add []string
seen := map[string]struct{}{}
@ -56,7 +70,7 @@ func (watcher *Watcher) Watch(ctx context.Context, filePaths []string) {
if watcher.pollingWatcher != nil {
err := watcher.pollingWatcher.Add(filePath)
if err != nil {
log.Ctx(ctx).Error().Err(err).Str("file", filePath).Msg("fileutil/watcher: failed to add file to polling-based file watcher")
log.Error().Err(err).Str("file", filePath).Msg("fileutil/watcher: failed to add file to polling-based file watcher")
}
}
}
@ -67,22 +81,19 @@ func (watcher *Watcher) Watch(ctx context.Context, filePaths []string) {
if watcher.pollingWatcher != nil {
err := watcher.pollingWatcher.Remove(filePath)
if err != nil {
log.Ctx(ctx).Error().Err(err).Str("file", filePath).Msg("fileutil/watcher: failed to remove file from polling-based file watcher")
log.Error().Err(err).Str("file", filePath).Msg("fileutil/watcher: failed to remove file from polling-based file watcher")
}
}
}
}
func (watcher *Watcher) initLocked(ctx context.Context) {
func (watcher *Watcher) initLocked() {
if watcher.pollingWatcher != nil {
return
}
if watcher.pollingWatcher == nil {
watcher.pollingWatcher = filenotify.NewPollingWatcher(nil)
context.AfterFunc(ctx, func() {
watcher.pollingWatcher.Close()
})
}
errors := watcher.pollingWatcher.Errors()
@ -91,15 +102,15 @@ func (watcher *Watcher) initLocked(ctx context.Context) {
// log errors
go func() {
for err := range errors {
log.Ctx(ctx).Error().Err(err).Msg("fileutil/watcher: file notification error")
log.Error().Err(err).Msg("fileutil/watcher: file notification error")
}
}()
// handle events
go func() {
for evt := range events {
log.Ctx(ctx).Info().Str("name", evt.Name).Str("op", evt.Op.String()).Msg("fileutil/watcher: file notification event")
watcher.Broadcast(ctx)
log.Info().Str("name", evt.Name).Str("op", evt.Op.String()).Msg("fileutil/watcher: file notification event")
watcher.Broadcast(context.Background())
}
}()
}

View file

@ -18,7 +18,7 @@ func TestWatcher(t *testing.T) {
require.NoError(t, err)
w := NewWatcher()
w.Watch(context.Background(), []string{filepath.Join(tmpdir, "test1.txt")})
w.Watch([]string{filepath.Join(tmpdir, "test1.txt")})
ch := w.Bind()
t.Cleanup(func() { w.Unbind(ch) })
@ -41,7 +41,7 @@ func TestWatcherSymlink(t *testing.T) {
assert.NoError(t, os.Symlink(filepath.Join(tmpdir, "test1.txt"), filepath.Join(tmpdir, "symlink1.txt")))
w := NewWatcher()
w.Watch(context.Background(), []string{filepath.Join(tmpdir, "symlink1.txt")})
w.Watch([]string{filepath.Join(tmpdir, "symlink1.txt")})
ch := w.Bind()
t.Cleanup(func() { w.Unbind(ch) })
@ -63,7 +63,7 @@ func TestWatcher_FileRemoval(t *testing.T) {
require.NoError(t, err)
w := NewWatcher()
w.Watch(context.Background(), []string{filepath.Join(tmpdir, "test1.txt")})
w.Watch([]string{filepath.Join(tmpdir, "test1.txt")})
ch := w.Bind()
t.Cleanup(func() { w.Unbind(ch) })

View file

@ -189,6 +189,7 @@ type Session struct {
OauthToken *OAuthToken `protobuf:"bytes,7,opt,name=oauth_token,json=oauthToken,proto3" json:"oauth_token,omitempty"`
Claims map[string]*structpb.ListValue `protobuf:"bytes,9,rep,name=claims,proto3" json:"claims,omitempty" protobuf_key:"bytes,1,opt,name=key,proto3" protobuf_val:"bytes,2,opt,name=value,proto3"`
Audience []string `protobuf:"bytes,10,rep,name=audience,proto3" json:"audience,omitempty"`
RefreshDisabled bool `protobuf:"varint,19,opt,name=refresh_disabled,json=refreshDisabled,proto3" json:"refresh_disabled,omitempty"`
ImpersonateSessionId *string `protobuf:"bytes,15,opt,name=impersonate_session_id,json=impersonateSessionId,proto3,oneof" json:"impersonate_session_id,omitempty"`
}
@ -301,6 +302,13 @@ func (x *Session) GetAudience() []string {
return nil
}
func (x *Session) GetRefreshDisabled() bool {
if x != nil {
return x.RefreshDisabled
}
return false
}
func (x *Session) GetImpersonateSessionId() string {
if x != nil && x.ImpersonateSessionId != nil {
return *x.ImpersonateSessionId
@ -430,7 +438,7 @@ var file_session_proto_rawDesc = []byte{
0x54, 0x69, 0x6d, 0x65, 0x73, 0x74, 0x61, 0x6d, 0x70, 0x52, 0x09, 0x65, 0x78, 0x70, 0x69, 0x72,
0x65, 0x73, 0x41, 0x74, 0x12, 0x23, 0x0a, 0x0d, 0x72, 0x65, 0x66, 0x72, 0x65, 0x73, 0x68, 0x5f,
0x74, 0x6f, 0x6b, 0x65, 0x6e, 0x18, 0x04, 0x20, 0x01, 0x28, 0x09, 0x52, 0x0c, 0x72, 0x65, 0x66,
0x72, 0x65, 0x73, 0x68, 0x54, 0x6f, 0x6b, 0x65, 0x6e, 0x22, 0xbb, 0x06, 0x0a, 0x07, 0x53, 0x65,
0x72, 0x65, 0x73, 0x68, 0x54, 0x6f, 0x6b, 0x65, 0x6e, 0x22, 0xe6, 0x06, 0x0a, 0x07, 0x53, 0x65,
0x73, 0x73, 0x69, 0x6f, 0x6e, 0x12, 0x18, 0x0a, 0x07, 0x76, 0x65, 0x72, 0x73, 0x69, 0x6f, 0x6e,
0x18, 0x01, 0x20, 0x01, 0x28, 0x09, 0x52, 0x07, 0x76, 0x65, 0x72, 0x73, 0x69, 0x6f, 0x6e, 0x12,
0x0e, 0x0a, 0x02, 0x69, 0x64, 0x18, 0x02, 0x20, 0x01, 0x28, 0x09, 0x52, 0x02, 0x69, 0x64, 0x12,
@ -463,29 +471,32 @@ var file_session_proto_rawDesc = []byte{
0x2e, 0x43, 0x6c, 0x61, 0x69, 0x6d, 0x73, 0x45, 0x6e, 0x74, 0x72, 0x79, 0x52, 0x06, 0x63, 0x6c,
0x61, 0x69, 0x6d, 0x73, 0x12, 0x1a, 0x0a, 0x08, 0x61, 0x75, 0x64, 0x69, 0x65, 0x6e, 0x63, 0x65,
0x18, 0x0a, 0x20, 0x03, 0x28, 0x09, 0x52, 0x08, 0x61, 0x75, 0x64, 0x69, 0x65, 0x6e, 0x63, 0x65,
0x12, 0x39, 0x0a, 0x16, 0x69, 0x6d, 0x70, 0x65, 0x72, 0x73, 0x6f, 0x6e, 0x61, 0x74, 0x65, 0x5f,
0x73, 0x65, 0x73, 0x73, 0x69, 0x6f, 0x6e, 0x5f, 0x69, 0x64, 0x18, 0x0f, 0x20, 0x01, 0x28, 0x09,
0x48, 0x00, 0x52, 0x14, 0x69, 0x6d, 0x70, 0x65, 0x72, 0x73, 0x6f, 0x6e, 0x61, 0x74, 0x65, 0x53,
0x65, 0x73, 0x73, 0x69, 0x6f, 0x6e, 0x49, 0x64, 0x88, 0x01, 0x01, 0x1a, 0x87, 0x01, 0x0a, 0x10,
0x44, 0x65, 0x76, 0x69, 0x63, 0x65, 0x43, 0x72, 0x65, 0x64, 0x65, 0x6e, 0x74, 0x69, 0x61, 0x6c,
0x12, 0x17, 0x0a, 0x07, 0x74, 0x79, 0x70, 0x65, 0x5f, 0x69, 0x64, 0x18, 0x01, 0x20, 0x01, 0x28,
0x09, 0x52, 0x06, 0x74, 0x79, 0x70, 0x65, 0x49, 0x64, 0x12, 0x3a, 0x0a, 0x0b, 0x75, 0x6e, 0x61,
0x76, 0x61, 0x69, 0x6c, 0x61, 0x62, 0x6c, 0x65, 0x18, 0x02, 0x20, 0x01, 0x28, 0x0b, 0x32, 0x16,
0x2e, 0x67, 0x6f, 0x6f, 0x67, 0x6c, 0x65, 0x2e, 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x62, 0x75, 0x66,
0x2e, 0x45, 0x6d, 0x70, 0x74, 0x79, 0x48, 0x00, 0x52, 0x0b, 0x75, 0x6e, 0x61, 0x76, 0x61, 0x69,
0x6c, 0x61, 0x62, 0x6c, 0x65, 0x12, 0x10, 0x0a, 0x02, 0x69, 0x64, 0x18, 0x03, 0x20, 0x01, 0x28,
0x09, 0x48, 0x00, 0x52, 0x02, 0x69, 0x64, 0x42, 0x0c, 0x0a, 0x0a, 0x63, 0x72, 0x65, 0x64, 0x65,
0x6e, 0x74, 0x69, 0x61, 0x6c, 0x1a, 0x55, 0x0a, 0x0b, 0x43, 0x6c, 0x61, 0x69, 0x6d, 0x73, 0x45,
0x6e, 0x74, 0x72, 0x79, 0x12, 0x10, 0x0a, 0x03, 0x6b, 0x65, 0x79, 0x18, 0x01, 0x20, 0x01, 0x28,
0x09, 0x52, 0x03, 0x6b, 0x65, 0x79, 0x12, 0x30, 0x0a, 0x05, 0x76, 0x61, 0x6c, 0x75, 0x65, 0x18,
0x02, 0x20, 0x01, 0x28, 0x0b, 0x32, 0x1a, 0x2e, 0x67, 0x6f, 0x6f, 0x67, 0x6c, 0x65, 0x2e, 0x70,
0x72, 0x6f, 0x74, 0x6f, 0x62, 0x75, 0x66, 0x2e, 0x4c, 0x69, 0x73, 0x74, 0x56, 0x61, 0x6c, 0x75,
0x65, 0x52, 0x05, 0x76, 0x61, 0x6c, 0x75, 0x65, 0x3a, 0x02, 0x38, 0x01, 0x42, 0x19, 0x0a, 0x17,
0x5f, 0x69, 0x6d, 0x70, 0x65, 0x72, 0x73, 0x6f, 0x6e, 0x61, 0x74, 0x65, 0x5f, 0x73, 0x65, 0x73,
0x73, 0x69, 0x6f, 0x6e, 0x5f, 0x69, 0x64, 0x42, 0x2f, 0x5a, 0x2d, 0x67, 0x69, 0x74, 0x68, 0x75,
0x62, 0x2e, 0x63, 0x6f, 0x6d, 0x2f, 0x70, 0x6f, 0x6d, 0x65, 0x72, 0x69, 0x75, 0x6d, 0x2f, 0x70,
0x6f, 0x6d, 0x65, 0x72, 0x69, 0x75, 0x6d, 0x2f, 0x70, 0x6b, 0x67, 0x2f, 0x67, 0x72, 0x70, 0x63,
0x2f, 0x73, 0x65, 0x73, 0x73, 0x69, 0x6f, 0x6e, 0x62, 0x06, 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x33,
0x12, 0x29, 0x0a, 0x10, 0x72, 0x65, 0x66, 0x72, 0x65, 0x73, 0x68, 0x5f, 0x64, 0x69, 0x73, 0x61,
0x62, 0x6c, 0x65, 0x64, 0x18, 0x13, 0x20, 0x01, 0x28, 0x08, 0x52, 0x0f, 0x72, 0x65, 0x66, 0x72,
0x65, 0x73, 0x68, 0x44, 0x69, 0x73, 0x61, 0x62, 0x6c, 0x65, 0x64, 0x12, 0x39, 0x0a, 0x16, 0x69,
0x6d, 0x70, 0x65, 0x72, 0x73, 0x6f, 0x6e, 0x61, 0x74, 0x65, 0x5f, 0x73, 0x65, 0x73, 0x73, 0x69,
0x6f, 0x6e, 0x5f, 0x69, 0x64, 0x18, 0x0f, 0x20, 0x01, 0x28, 0x09, 0x48, 0x00, 0x52, 0x14, 0x69,
0x6d, 0x70, 0x65, 0x72, 0x73, 0x6f, 0x6e, 0x61, 0x74, 0x65, 0x53, 0x65, 0x73, 0x73, 0x69, 0x6f,
0x6e, 0x49, 0x64, 0x88, 0x01, 0x01, 0x1a, 0x87, 0x01, 0x0a, 0x10, 0x44, 0x65, 0x76, 0x69, 0x63,
0x65, 0x43, 0x72, 0x65, 0x64, 0x65, 0x6e, 0x74, 0x69, 0x61, 0x6c, 0x12, 0x17, 0x0a, 0x07, 0x74,
0x79, 0x70, 0x65, 0x5f, 0x69, 0x64, 0x18, 0x01, 0x20, 0x01, 0x28, 0x09, 0x52, 0x06, 0x74, 0x79,
0x70, 0x65, 0x49, 0x64, 0x12, 0x3a, 0x0a, 0x0b, 0x75, 0x6e, 0x61, 0x76, 0x61, 0x69, 0x6c, 0x61,
0x62, 0x6c, 0x65, 0x18, 0x02, 0x20, 0x01, 0x28, 0x0b, 0x32, 0x16, 0x2e, 0x67, 0x6f, 0x6f, 0x67,
0x6c, 0x65, 0x2e, 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x62, 0x75, 0x66, 0x2e, 0x45, 0x6d, 0x70, 0x74,
0x79, 0x48, 0x00, 0x52, 0x0b, 0x75, 0x6e, 0x61, 0x76, 0x61, 0x69, 0x6c, 0x61, 0x62, 0x6c, 0x65,
0x12, 0x10, 0x0a, 0x02, 0x69, 0x64, 0x18, 0x03, 0x20, 0x01, 0x28, 0x09, 0x48, 0x00, 0x52, 0x02,
0x69, 0x64, 0x42, 0x0c, 0x0a, 0x0a, 0x63, 0x72, 0x65, 0x64, 0x65, 0x6e, 0x74, 0x69, 0x61, 0x6c,
0x1a, 0x55, 0x0a, 0x0b, 0x43, 0x6c, 0x61, 0x69, 0x6d, 0x73, 0x45, 0x6e, 0x74, 0x72, 0x79, 0x12,
0x10, 0x0a, 0x03, 0x6b, 0x65, 0x79, 0x18, 0x01, 0x20, 0x01, 0x28, 0x09, 0x52, 0x03, 0x6b, 0x65,
0x79, 0x12, 0x30, 0x0a, 0x05, 0x76, 0x61, 0x6c, 0x75, 0x65, 0x18, 0x02, 0x20, 0x01, 0x28, 0x0b,
0x32, 0x1a, 0x2e, 0x67, 0x6f, 0x6f, 0x67, 0x6c, 0x65, 0x2e, 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x62,
0x75, 0x66, 0x2e, 0x4c, 0x69, 0x73, 0x74, 0x56, 0x61, 0x6c, 0x75, 0x65, 0x52, 0x05, 0x76, 0x61,
0x6c, 0x75, 0x65, 0x3a, 0x02, 0x38, 0x01, 0x42, 0x19, 0x0a, 0x17, 0x5f, 0x69, 0x6d, 0x70, 0x65,
0x72, 0x73, 0x6f, 0x6e, 0x61, 0x74, 0x65, 0x5f, 0x73, 0x65, 0x73, 0x73, 0x69, 0x6f, 0x6e, 0x5f,
0x69, 0x64, 0x42, 0x2f, 0x5a, 0x2d, 0x67, 0x69, 0x74, 0x68, 0x75, 0x62, 0x2e, 0x63, 0x6f, 0x6d,
0x2f, 0x70, 0x6f, 0x6d, 0x65, 0x72, 0x69, 0x75, 0x6d, 0x2f, 0x70, 0x6f, 0x6d, 0x65, 0x72, 0x69,
0x75, 0x6d, 0x2f, 0x70, 0x6b, 0x67, 0x2f, 0x67, 0x72, 0x70, 0x63, 0x2f, 0x73, 0x65, 0x73, 0x73,
0x69, 0x6f, 0x6e, 0x62, 0x06, 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x33,
}
var (

View file

@ -42,6 +42,7 @@ message Session {
OAuthToken oauth_token = 7;
map<string, google.protobuf.ListValue> claims = 9;
repeated string audience = 10;
bool refresh_disabled = 19;
optional string impersonate_session_id = 15;
}

View file

@ -252,6 +252,11 @@ func (mgr *Manager) refreshSessionInternal(
return false
}
if s.GetRefreshDisabled() {
// refresh was explicitly disabled
return false
}
if s.Session == nil || s.Session.OauthToken == nil {
log.Ctx(ctx).Info().
Str("user_id", userID).

View file

@ -225,6 +225,11 @@ func (mgr *Manager) refreshSession(ctx context.Context, sessionID string) {
return
}
if s.GetRefreshDisabled() {
// refresh was explicitly disabled
return
}
if s.GetOauthToken() == nil {
log.Ctx(ctx).Info().
Str("user_id", s.GetUserId()).