mirror of
https://github.com/pomerium/pomerium.git
synced 2025-04-29 18:36:30 +02:00
tcptunnel: handle invalid http response codes (#1727)
This commit is contained in:
parent
e56e7e4b9e
commit
5b18527fee
2 changed files with 32 additions and 8 deletions
|
@ -25,6 +25,7 @@ var (
|
||||||
|
|
||||||
// A JWTCache loads and stores JWTs.
|
// A JWTCache loads and stores JWTs.
|
||||||
type JWTCache interface {
|
type JWTCache interface {
|
||||||
|
DeleteJWT(key string) error
|
||||||
LoadJWT(key string) (rawJWT string, err error)
|
LoadJWT(key string) (rawJWT string, err error)
|
||||||
StoreJWT(key string, rawJWT string) error
|
StoreJWT(key string, rawJWT string) error
|
||||||
}
|
}
|
||||||
|
@ -53,6 +54,16 @@ func NewLocalJWTCache() (*LocalJWTCache, error) {
|
||||||
}, nil
|
}, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// DeleteJWT deletes a raw JWT from the local cache.
|
||||||
|
func (cache *LocalJWTCache) DeleteJWT(key string) error {
|
||||||
|
path := filepath.Join(cache.dir, cache.fileName(key))
|
||||||
|
err := os.Remove(path)
|
||||||
|
if os.IsNotExist(err) {
|
||||||
|
err = nil
|
||||||
|
}
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
// LoadJWT loads a raw JWT from the local cache.
|
// LoadJWT loads a raw JWT from the local cache.
|
||||||
func (cache *LocalJWTCache) LoadJWT(key string) (rawJWT string, err error) {
|
func (cache *LocalJWTCache) LoadJWT(key string) (rawJWT string, err error) {
|
||||||
path := filepath.Join(cache.dir, cache.fileName(key))
|
path := filepath.Join(cache.dir, cache.fileName(key))
|
||||||
|
@ -98,6 +109,15 @@ func NewMemoryJWTCache() *MemoryJWTCache {
|
||||||
return &MemoryJWTCache{entries: make(map[string]string)}
|
return &MemoryJWTCache{entries: make(map[string]string)}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// DeleteJWT deletes a JWT from the in-memory map.
|
||||||
|
func (cache *MemoryJWTCache) DeleteJWT(key string) error {
|
||||||
|
cache.mu.Lock()
|
||||||
|
defer cache.mu.Unlock()
|
||||||
|
|
||||||
|
delete(cache.entries, key)
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
// LoadJWT loads a JWT from the in-memory map.
|
// LoadJWT loads a JWT from the in-memory map.
|
||||||
func (cache *MemoryJWTCache) LoadJWT(key string) (rawJWT string, err error) {
|
func (cache *MemoryJWTCache) LoadJWT(key string) (rawJWT string, err error) {
|
||||||
cache.mu.Lock()
|
cache.mu.Lock()
|
||||||
|
|
|
@ -98,10 +98,10 @@ func (tun *Tunnel) Run(ctx context.Context, local io.ReadWriter) error {
|
||||||
default:
|
default:
|
||||||
return fmt.Errorf("tcptunnel: failed to load JWT: %w", err)
|
return fmt.Errorf("tcptunnel: failed to load JWT: %w", err)
|
||||||
}
|
}
|
||||||
return tun.run(ctx, local, rawJWT)
|
return tun.run(ctx, local, rawJWT, 0)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (tun *Tunnel) run(ctx context.Context, local io.ReadWriter, rawJWT string) error {
|
func (tun *Tunnel) run(ctx context.Context, local io.ReadWriter, rawJWT string, retryCount int) error {
|
||||||
log.Info().
|
log.Info().
|
||||||
Str("dst", tun.cfg.dstHost).
|
Str("dst", tun.cfg.dstHost).
|
||||||
Str("proxy", tun.cfg.proxyHost).
|
Str("proxy", tun.cfg.proxyHost).
|
||||||
|
@ -160,15 +160,18 @@ func (tun *Tunnel) run(ctx context.Context, local io.ReadWriter, rawJWT string)
|
||||||
http.StatusFound,
|
http.StatusFound,
|
||||||
http.StatusTemporaryRedirect,
|
http.StatusTemporaryRedirect,
|
||||||
http.StatusPermanentRedirect:
|
http.StatusPermanentRedirect:
|
||||||
if rawJWT == "" {
|
if retryCount == 0 {
|
||||||
_ = remote.Close()
|
_ = remote.Close()
|
||||||
|
|
||||||
authURL, err := url.Parse(res.Header.Get("Location"))
|
serverURL := &url.URL{
|
||||||
if err != nil {
|
Scheme: "http",
|
||||||
return fmt.Errorf("tcptunnel: invalid redirect location for authentication: %w", err)
|
Host: tun.cfg.proxyHost,
|
||||||
|
}
|
||||||
|
if tun.cfg.tlsConfig != nil {
|
||||||
|
serverURL.Scheme = "https"
|
||||||
}
|
}
|
||||||
|
|
||||||
rawJWT, err = tun.auth.GetJWT(ctx, authURL)
|
rawJWT, err = tun.auth.GetJWT(ctx, serverURL)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return fmt.Errorf("tcptunnel: failed to get authentication JWT: %w", err)
|
return fmt.Errorf("tcptunnel: failed to get authentication JWT: %w", err)
|
||||||
}
|
}
|
||||||
|
@ -178,10 +181,11 @@ func (tun *Tunnel) run(ctx context.Context, local io.ReadWriter, rawJWT string)
|
||||||
return fmt.Errorf("tcptunnel: failed to store JWT: %w", err)
|
return fmt.Errorf("tcptunnel: failed to store JWT: %w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
return tun.run(ctx, local, rawJWT)
|
return tun.run(ctx, local, rawJWT, retryCount+1)
|
||||||
}
|
}
|
||||||
fallthrough
|
fallthrough
|
||||||
default:
|
default:
|
||||||
|
_ = tun.cfg.jwtCache.DeleteJWT(tun.jwtCacheKey())
|
||||||
return fmt.Errorf("tcptunnel: invalid http response code: %d", res.StatusCode)
|
return fmt.Errorf("tcptunnel: invalid http response code: %d", res.StatusCode)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
Loading…
Add table
Reference in a new issue