From 5b18527fee384f53ab85f68fdaf28f1043489da7 Mon Sep 17 00:00:00 2001 From: Caleb Doxsey Date: Wed, 30 Dec 2020 08:00:39 -0700 Subject: [PATCH] tcptunnel: handle invalid http response codes (#1727) --- internal/cliutil/jwtcache.go | 20 ++++++++++++++++++++ internal/tcptunnel/tcptunnel.go | 20 ++++++++++++-------- 2 files changed, 32 insertions(+), 8 deletions(-) diff --git a/internal/cliutil/jwtcache.go b/internal/cliutil/jwtcache.go index 89b886220..71afd21a2 100644 --- a/internal/cliutil/jwtcache.go +++ b/internal/cliutil/jwtcache.go @@ -25,6 +25,7 @@ var ( // A JWTCache loads and stores JWTs. type JWTCache interface { + DeleteJWT(key string) error LoadJWT(key string) (rawJWT string, err error) StoreJWT(key string, rawJWT string) error } @@ -53,6 +54,16 @@ func NewLocalJWTCache() (*LocalJWTCache, error) { }, nil } +// DeleteJWT deletes a raw JWT from the local cache. +func (cache *LocalJWTCache) DeleteJWT(key string) error { + path := filepath.Join(cache.dir, cache.fileName(key)) + err := os.Remove(path) + if os.IsNotExist(err) { + err = nil + } + return err +} + // LoadJWT loads a raw JWT from the local cache. func (cache *LocalJWTCache) LoadJWT(key string) (rawJWT string, err error) { path := filepath.Join(cache.dir, cache.fileName(key)) @@ -98,6 +109,15 @@ func NewMemoryJWTCache() *MemoryJWTCache { return &MemoryJWTCache{entries: make(map[string]string)} } +// DeleteJWT deletes a JWT from the in-memory map. +func (cache *MemoryJWTCache) DeleteJWT(key string) error { + cache.mu.Lock() + defer cache.mu.Unlock() + + delete(cache.entries, key) + return nil +} + // LoadJWT loads a JWT from the in-memory map. func (cache *MemoryJWTCache) LoadJWT(key string) (rawJWT string, err error) { cache.mu.Lock() diff --git a/internal/tcptunnel/tcptunnel.go b/internal/tcptunnel/tcptunnel.go index 7bbe9dfd7..32a226a0c 100644 --- a/internal/tcptunnel/tcptunnel.go +++ b/internal/tcptunnel/tcptunnel.go @@ -98,10 +98,10 @@ func (tun *Tunnel) Run(ctx context.Context, local io.ReadWriter) error { default: return fmt.Errorf("tcptunnel: failed to load JWT: %w", err) } - return tun.run(ctx, local, rawJWT) + return tun.run(ctx, local, rawJWT, 0) } -func (tun *Tunnel) run(ctx context.Context, local io.ReadWriter, rawJWT string) error { +func (tun *Tunnel) run(ctx context.Context, local io.ReadWriter, rawJWT string, retryCount int) error { log.Info(). Str("dst", tun.cfg.dstHost). Str("proxy", tun.cfg.proxyHost). @@ -160,15 +160,18 @@ func (tun *Tunnel) run(ctx context.Context, local io.ReadWriter, rawJWT string) http.StatusFound, http.StatusTemporaryRedirect, http.StatusPermanentRedirect: - if rawJWT == "" { + if retryCount == 0 { _ = remote.Close() - authURL, err := url.Parse(res.Header.Get("Location")) - if err != nil { - return fmt.Errorf("tcptunnel: invalid redirect location for authentication: %w", err) + serverURL := &url.URL{ + Scheme: "http", + Host: tun.cfg.proxyHost, + } + if tun.cfg.tlsConfig != nil { + serverURL.Scheme = "https" } - rawJWT, err = tun.auth.GetJWT(ctx, authURL) + rawJWT, err = tun.auth.GetJWT(ctx, serverURL) if err != nil { return fmt.Errorf("tcptunnel: failed to get authentication JWT: %w", err) } @@ -178,10 +181,11 @@ func (tun *Tunnel) run(ctx context.Context, local io.ReadWriter, rawJWT string) return fmt.Errorf("tcptunnel: failed to store JWT: %w", err) } - return tun.run(ctx, local, rawJWT) + return tun.run(ctx, local, rawJWT, retryCount+1) } fallthrough default: + _ = tun.cfg.jwtCache.DeleteJWT(tun.jwtCacheKey()) return fmt.Errorf("tcptunnel: invalid http response code: %d", res.StatusCode) }