mirror of
https://github.com/pomerium/pomerium.git
synced 2025-04-29 10:26:29 +02:00
tcptunnel: handle invalid http response codes (#1727)
This commit is contained in:
parent
e56e7e4b9e
commit
5b18527fee
2 changed files with 32 additions and 8 deletions
|
@ -25,6 +25,7 @@ var (
|
|||
|
||||
// A JWTCache loads and stores JWTs.
|
||||
type JWTCache interface {
|
||||
DeleteJWT(key string) error
|
||||
LoadJWT(key string) (rawJWT string, err error)
|
||||
StoreJWT(key string, rawJWT string) error
|
||||
}
|
||||
|
@ -53,6 +54,16 @@ func NewLocalJWTCache() (*LocalJWTCache, error) {
|
|||
}, nil
|
||||
}
|
||||
|
||||
// DeleteJWT deletes a raw JWT from the local cache.
|
||||
func (cache *LocalJWTCache) DeleteJWT(key string) error {
|
||||
path := filepath.Join(cache.dir, cache.fileName(key))
|
||||
err := os.Remove(path)
|
||||
if os.IsNotExist(err) {
|
||||
err = nil
|
||||
}
|
||||
return err
|
||||
}
|
||||
|
||||
// LoadJWT loads a raw JWT from the local cache.
|
||||
func (cache *LocalJWTCache) LoadJWT(key string) (rawJWT string, err error) {
|
||||
path := filepath.Join(cache.dir, cache.fileName(key))
|
||||
|
@ -98,6 +109,15 @@ func NewMemoryJWTCache() *MemoryJWTCache {
|
|||
return &MemoryJWTCache{entries: make(map[string]string)}
|
||||
}
|
||||
|
||||
// DeleteJWT deletes a JWT from the in-memory map.
|
||||
func (cache *MemoryJWTCache) DeleteJWT(key string) error {
|
||||
cache.mu.Lock()
|
||||
defer cache.mu.Unlock()
|
||||
|
||||
delete(cache.entries, key)
|
||||
return nil
|
||||
}
|
||||
|
||||
// LoadJWT loads a JWT from the in-memory map.
|
||||
func (cache *MemoryJWTCache) LoadJWT(key string) (rawJWT string, err error) {
|
||||
cache.mu.Lock()
|
||||
|
|
|
@ -98,10 +98,10 @@ func (tun *Tunnel) Run(ctx context.Context, local io.ReadWriter) error {
|
|||
default:
|
||||
return fmt.Errorf("tcptunnel: failed to load JWT: %w", err)
|
||||
}
|
||||
return tun.run(ctx, local, rawJWT)
|
||||
return tun.run(ctx, local, rawJWT, 0)
|
||||
}
|
||||
|
||||
func (tun *Tunnel) run(ctx context.Context, local io.ReadWriter, rawJWT string) error {
|
||||
func (tun *Tunnel) run(ctx context.Context, local io.ReadWriter, rawJWT string, retryCount int) error {
|
||||
log.Info().
|
||||
Str("dst", tun.cfg.dstHost).
|
||||
Str("proxy", tun.cfg.proxyHost).
|
||||
|
@ -160,15 +160,18 @@ func (tun *Tunnel) run(ctx context.Context, local io.ReadWriter, rawJWT string)
|
|||
http.StatusFound,
|
||||
http.StatusTemporaryRedirect,
|
||||
http.StatusPermanentRedirect:
|
||||
if rawJWT == "" {
|
||||
if retryCount == 0 {
|
||||
_ = remote.Close()
|
||||
|
||||
authURL, err := url.Parse(res.Header.Get("Location"))
|
||||
if err != nil {
|
||||
return fmt.Errorf("tcptunnel: invalid redirect location for authentication: %w", err)
|
||||
serverURL := &url.URL{
|
||||
Scheme: "http",
|
||||
Host: tun.cfg.proxyHost,
|
||||
}
|
||||
if tun.cfg.tlsConfig != nil {
|
||||
serverURL.Scheme = "https"
|
||||
}
|
||||
|
||||
rawJWT, err = tun.auth.GetJWT(ctx, authURL)
|
||||
rawJWT, err = tun.auth.GetJWT(ctx, serverURL)
|
||||
if err != nil {
|
||||
return fmt.Errorf("tcptunnel: failed to get authentication JWT: %w", err)
|
||||
}
|
||||
|
@ -178,10 +181,11 @@ func (tun *Tunnel) run(ctx context.Context, local io.ReadWriter, rawJWT string)
|
|||
return fmt.Errorf("tcptunnel: failed to store JWT: %w", err)
|
||||
}
|
||||
|
||||
return tun.run(ctx, local, rawJWT)
|
||||
return tun.run(ctx, local, rawJWT, retryCount+1)
|
||||
}
|
||||
fallthrough
|
||||
default:
|
||||
_ = tun.cfg.jwtCache.DeleteJWT(tun.jwtCacheKey())
|
||||
return fmt.Errorf("tcptunnel: invalid http response code: %d", res.StatusCode)
|
||||
}
|
||||
|
||||
|
|
Loading…
Add table
Reference in a new issue