// Package tcptunnel contains an implementation of a TCP tunnel via HTTP Connect. package tcptunnel import ( "bufio" "context" "crypto/tls" "errors" "fmt" "io" "net" "net/http" "net/url" "time" "github.com/pomerium/pomerium/internal/authclient" "github.com/pomerium/pomerium/internal/cliutil" "github.com/pomerium/pomerium/internal/log" backoff "github.com/cenkalti/backoff/v4" ) // A Tunnel represents a TCP tunnel over HTTP Connect. type Tunnel struct { cfg *config auth *authclient.AuthClient } // New creates a new Tunnel. func New(options ...Option) *Tunnel { cfg := getConfig(options...) return &Tunnel{ cfg: cfg, auth: authclient.New( authclient.WithBrowserCommand(cfg.browserConfig), authclient.WithTLSConfig(cfg.tlsConfig)), } } // RunListener runs a network listener on the given address. For each // incoming connection a new TCP tunnel is established via Run. func (tun *Tunnel) RunListener(ctx context.Context, listenerAddress string) error { li, err := net.Listen("tcp", listenerAddress) if err != nil { return err } defer func() { _ = li.Close() }() log.Info(ctx).Msg("tcptunnel: listening on " + li.Addr().String()) go func() { <-ctx.Done() _ = li.Close() }() bo := backoff.NewExponentialBackOff() bo.MaxElapsedTime = 0 for { conn, err := li.Accept() if err != nil { // canceled, so ignore the error and return if ctx.Err() != nil { return nil } if nerr, ok := err.(net.Error); ok && nerr.Temporary() { log.Warn(ctx).Err(err).Msg("tcptunnel: temporarily failed to accept local connection") select { case <-time.After(bo.NextBackOff()): case <-ctx.Done(): return ctx.Err() } continue } return err } bo.Reset() go func() { defer func() { _ = conn.Close() }() err := tun.Run(ctx, conn) if err != nil { log.Error(ctx).Err(err).Msg("tcptunnel: error serving local connection") } }() } } // Run establishes a TCP tunnel via HTTP Connect and forwards all traffic from/to local. func (tun *Tunnel) Run(ctx context.Context, local io.ReadWriter) error { rawJWT, err := tun.cfg.jwtCache.LoadJWT(tun.jwtCacheKey()) switch { // if there is no error, or it is one of the pre-defined cliutil errors, // then ignore and use an empty JWT case err == nil, errors.Is(err, cliutil.ErrExpired), errors.Is(err, cliutil.ErrInvalid), errors.Is(err, cliutil.ErrNotFound): default: return fmt.Errorf("tcptunnel: failed to load JWT: %w", err) } return tun.run(ctx, local, rawJWT, 0) } func (tun *Tunnel) run(ctx context.Context, local io.ReadWriter, rawJWT string, retryCount int) error { log.Info(ctx). Str("dst", tun.cfg.dstHost). Str("proxy", tun.cfg.proxyHost). Bool("secure", tun.cfg.tlsConfig != nil). Msg("tcptunnel: opening connection") hdr := http.Header{} if rawJWT != "" { hdr.Set("Authorization", "Pomerium "+rawJWT) } req := (&http.Request{ Method: "CONNECT", URL: &url.URL{Opaque: tun.cfg.dstHost}, Host: tun.cfg.dstHost, Header: hdr, }).WithContext(ctx) var remote net.Conn var err error if tun.cfg.tlsConfig != nil { remote, err = (&tls.Dialer{Config: tun.cfg.tlsConfig}).DialContext(ctx, "tcp", tun.cfg.proxyHost) } else { remote, err = (&net.Dialer{}).DialContext(ctx, "tcp", tun.cfg.proxyHost) } if err != nil { return fmt.Errorf("tcptunnel: failed to establish connection to proxy: %w", err) } defer func() { _ = remote.Close() log.Info(ctx).Msg("tcptunnel: connection closed") }() if done := ctx.Done(); done != nil { go func() { <-done _ = remote.Close() }() } err = req.Write(remote) if err != nil { return err } br := bufio.NewReader(remote) res, err := http.ReadResponse(br, req) if err != nil { return fmt.Errorf("tcptunnel: failed to read HTTP response: %w", err) } defer func() { _ = res.Body.Close() }() switch res.StatusCode { case http.StatusOK: case http.StatusMovedPermanently, http.StatusFound, http.StatusTemporaryRedirect, http.StatusPermanentRedirect: if retryCount == 0 { _ = remote.Close() serverURL := &url.URL{ Scheme: "http", Host: tun.cfg.proxyHost, } if tun.cfg.tlsConfig != nil { serverURL.Scheme = "https" } rawJWT, err = tun.auth.GetJWT(ctx, serverURL) if err != nil { return fmt.Errorf("tcptunnel: failed to get authentication JWT: %w", err) } err = tun.cfg.jwtCache.StoreJWT(tun.jwtCacheKey(), rawJWT) if err != nil { return fmt.Errorf("tcptunnel: failed to store JWT: %w", err) } return tun.run(ctx, local, rawJWT, retryCount+1) } fallthrough default: _ = tun.cfg.jwtCache.DeleteJWT(tun.jwtCacheKey()) return fmt.Errorf("tcptunnel: invalid http response code: %d", res.StatusCode) } log.Info(ctx).Msg("tcptunnel: connection established") errc := make(chan error, 2) go func() { _, err := io.Copy(remote, local) errc <- err }() remoteReader := deBuffer(br, remote) go func() { _, err := io.Copy(local, remoteReader) errc <- err }() select { case err := <-errc: if err != nil { err = fmt.Errorf("tcptunnel: %w", err) } return err case <-ctx.Done(): return nil } } func (tun *Tunnel) jwtCacheKey() string { return fmt.Sprintf("%s|%v", tun.cfg.proxyHost, tun.cfg.tlsConfig != nil) } func deBuffer(br *bufio.Reader, underlying io.Reader) io.Reader { if br.Buffered() == 0 { return underlying } return io.MultiReader(io.LimitReader(br, int64(br.Buffered())), underlying) }