tcptunnel: handle invalid http response codes (#1727)

This commit is contained in:
Caleb Doxsey 2020-12-30 08:00:39 -07:00 committed by GitHub
parent e56e7e4b9e
commit 5b18527fee
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
2 changed files with 32 additions and 8 deletions

View file

@ -25,6 +25,7 @@ var (
// A JWTCache loads and stores JWTs. // A JWTCache loads and stores JWTs.
type JWTCache interface { type JWTCache interface {
DeleteJWT(key string) error
LoadJWT(key string) (rawJWT string, err error) LoadJWT(key string) (rawJWT string, err error)
StoreJWT(key string, rawJWT string) error StoreJWT(key string, rawJWT string) error
} }
@ -53,6 +54,16 @@ func NewLocalJWTCache() (*LocalJWTCache, error) {
}, nil }, nil
} }
// DeleteJWT deletes a raw JWT from the local cache.
func (cache *LocalJWTCache) DeleteJWT(key string) error {
path := filepath.Join(cache.dir, cache.fileName(key))
err := os.Remove(path)
if os.IsNotExist(err) {
err = nil
}
return err
}
// LoadJWT loads a raw JWT from the local cache. // LoadJWT loads a raw JWT from the local cache.
func (cache *LocalJWTCache) LoadJWT(key string) (rawJWT string, err error) { func (cache *LocalJWTCache) LoadJWT(key string) (rawJWT string, err error) {
path := filepath.Join(cache.dir, cache.fileName(key)) path := filepath.Join(cache.dir, cache.fileName(key))
@ -98,6 +109,15 @@ func NewMemoryJWTCache() *MemoryJWTCache {
return &MemoryJWTCache{entries: make(map[string]string)} return &MemoryJWTCache{entries: make(map[string]string)}
} }
// DeleteJWT deletes a JWT from the in-memory map.
func (cache *MemoryJWTCache) DeleteJWT(key string) error {
cache.mu.Lock()
defer cache.mu.Unlock()
delete(cache.entries, key)
return nil
}
// LoadJWT loads a JWT from the in-memory map. // LoadJWT loads a JWT from the in-memory map.
func (cache *MemoryJWTCache) LoadJWT(key string) (rawJWT string, err error) { func (cache *MemoryJWTCache) LoadJWT(key string) (rawJWT string, err error) {
cache.mu.Lock() cache.mu.Lock()

View file

@ -98,10 +98,10 @@ func (tun *Tunnel) Run(ctx context.Context, local io.ReadWriter) error {
default: default:
return fmt.Errorf("tcptunnel: failed to load JWT: %w", err) return fmt.Errorf("tcptunnel: failed to load JWT: %w", err)
} }
return tun.run(ctx, local, rawJWT) return tun.run(ctx, local, rawJWT, 0)
} }
func (tun *Tunnel) run(ctx context.Context, local io.ReadWriter, rawJWT string) error { func (tun *Tunnel) run(ctx context.Context, local io.ReadWriter, rawJWT string, retryCount int) error {
log.Info(). log.Info().
Str("dst", tun.cfg.dstHost). Str("dst", tun.cfg.dstHost).
Str("proxy", tun.cfg.proxyHost). Str("proxy", tun.cfg.proxyHost).
@ -160,15 +160,18 @@ func (tun *Tunnel) run(ctx context.Context, local io.ReadWriter, rawJWT string)
http.StatusFound, http.StatusFound,
http.StatusTemporaryRedirect, http.StatusTemporaryRedirect,
http.StatusPermanentRedirect: http.StatusPermanentRedirect:
if rawJWT == "" { if retryCount == 0 {
_ = remote.Close() _ = remote.Close()
authURL, err := url.Parse(res.Header.Get("Location")) serverURL := &url.URL{
if err != nil { Scheme: "http",
return fmt.Errorf("tcptunnel: invalid redirect location for authentication: %w", err) Host: tun.cfg.proxyHost,
}
if tun.cfg.tlsConfig != nil {
serverURL.Scheme = "https"
} }
rawJWT, err = tun.auth.GetJWT(ctx, authURL) rawJWT, err = tun.auth.GetJWT(ctx, serverURL)
if err != nil { if err != nil {
return fmt.Errorf("tcptunnel: failed to get authentication JWT: %w", err) return fmt.Errorf("tcptunnel: failed to get authentication JWT: %w", err)
} }
@ -178,10 +181,11 @@ func (tun *Tunnel) run(ctx context.Context, local io.ReadWriter, rawJWT string)
return fmt.Errorf("tcptunnel: failed to store JWT: %w", err) return fmt.Errorf("tcptunnel: failed to store JWT: %w", err)
} }
return tun.run(ctx, local, rawJWT) return tun.run(ctx, local, rawJWT, retryCount+1)
} }
fallthrough fallthrough
default: default:
_ = tun.cfg.jwtCache.DeleteJWT(tun.jwtCacheKey())
return fmt.Errorf("tcptunnel: invalid http response code: %d", res.StatusCode) return fmt.Errorf("tcptunnel: invalid http response code: %d", res.StatusCode)
} }