From 61ab4e4837a1460e45789d7c1e444908b43c0837 Mon Sep 17 00:00:00 2001 From: Caleb Doxsey Date: Thu, 17 Dec 2020 12:37:28 -0700 Subject: [PATCH] TCP client command (#1696) * add cli commands * add jwt cache test * add tcptunnel test * add stdin/stdout support * use cryptutil hash function * doc updates * fix log timestamp --- cmd/pomerium-cli/tcp.go | 105 ++++++++++++ go.mod | 1 + go.sum | 2 + internal/authclient/authclient.go | 135 ++++++++++++++++ internal/authclient/authclient_test.go | 74 +++++++++ internal/authclient/config.go | 27 ++++ internal/cliutil/clitutil.go | 2 + internal/cliutil/jwtcache.go | 144 +++++++++++++++++ internal/cliutil/jwtcache_test.go | 69 ++++++++ internal/tcptunnel/config.go | 60 +++++++ internal/tcptunnel/tcptunnel.go | 213 +++++++++++++++++++++++++ internal/tcptunnel/tcptunnel_test.go | 91 +++++++++++ 12 files changed, 923 insertions(+) create mode 100644 cmd/pomerium-cli/tcp.go create mode 100644 internal/authclient/authclient.go create mode 100644 internal/authclient/authclient_test.go create mode 100644 internal/authclient/config.go create mode 100644 internal/cliutil/clitutil.go create mode 100644 internal/cliutil/jwtcache.go create mode 100644 internal/cliutil/jwtcache_test.go create mode 100644 internal/tcptunnel/config.go create mode 100644 internal/tcptunnel/tcptunnel.go create mode 100644 internal/tcptunnel/tcptunnel_test.go diff --git a/cmd/pomerium-cli/tcp.go b/cmd/pomerium-cli/tcp.go new file mode 100644 index 000000000..f44794006 --- /dev/null +++ b/cmd/pomerium-cli/tcp.go @@ -0,0 +1,105 @@ +package main + +import ( + "context" + "crypto/tls" + "fmt" + "io" + "net" + "net/url" + "os" + "os/signal" + "strings" + "syscall" + + "github.com/rs/zerolog" + "github.com/spf13/cobra" + + "github.com/pomerium/pomerium/internal/log" + "github.com/pomerium/pomerium/internal/tcptunnel" +) + +var tcpCmdOptions struct { + listen string + pomeriumURL string +} + +func init() { + flags := tcpCmd.Flags() + flags.StringVar(&tcpCmdOptions.listen, "listen", "127.0.0.1:0", + "local address to start a listener on") + flags.StringVar(&tcpCmdOptions.pomeriumURL, "pomerium-url", "", + "the URL of the pomerium server to connect to") + rootCmd.AddCommand(tcpCmd) +} + +var tcpCmd = &cobra.Command{ + Use: "tcp destination", + Args: cobra.ExactArgs(1), + RunE: func(cmd *cobra.Command, args []string) error { + dstHost := args[0] + dstHostname, _, err := net.SplitHostPort(dstHost) + if err != nil { + return fmt.Errorf("invalid destination: %w", err) + } + + pomeriumURL := &url.URL{ + Scheme: "https", + Host: net.JoinHostPort(dstHostname, "443"), + } + if tcpCmdOptions.pomeriumURL != "" { + pomeriumURL, err = url.Parse(tcpCmdOptions.pomeriumURL) + if err != nil { + return fmt.Errorf("invalid pomerium URL: %w", err) + } + if !strings.Contains(pomeriumURL.Host, ":") { + if pomeriumURL.Scheme == "https" { + pomeriumURL.Host = net.JoinHostPort(pomeriumURL.Hostname(), "443") + } else { + pomeriumURL.Host = net.JoinHostPort(pomeriumURL.Hostname(), "80") + } + } + } + + var tlsConfig *tls.Config + if pomeriumURL.Scheme == "https" { + tlsConfig = new(tls.Config) + } + + l := zerolog.New(zerolog.NewConsoleWriter(func(w *zerolog.ConsoleWriter) { + w.Out = os.Stderr + })).With().Timestamp().Logger() + log.SetLogger(&l) + + c := make(chan os.Signal, 1) + signal.Notify(c, os.Interrupt, syscall.SIGTERM, syscall.SIGHUP) + ctx, cancel := context.WithCancel(context.Background()) + go func() { + <-c + cancel() + }() + + tun := tcptunnel.New( + tcptunnel.WithDestinationHost(dstHost), + tcptunnel.WithProxyHost(pomeriumURL.Host), + tcptunnel.WithTLSConfig(tlsConfig), + ) + + if tcpCmdOptions.listen == "-" { + err = tun.Run(ctx, readWriter{Reader: os.Stdin, Writer: os.Stdout}) + } else { + err = tun.RunListener(ctx, tcpCmdOptions.listen) + } + if err != nil { + _, _ = fmt.Fprintf(os.Stderr, "%s\n", err.Error()) + os.Exit(1) + } + + return nil + }, +} + +type readWriter struct { + io.Reader + io.Writer +} diff --git a/go.mod b/go.mod index c9f738150..f85b8b983 100644 --- a/go.mod +++ b/go.mod @@ -32,6 +32,7 @@ require ( github.com/hashicorp/golang-lru v0.5.4 github.com/kardianos/osext v0.0.0-20190222173326-2bc1f35cddc0 // indirect github.com/lithammer/shortuuid/v3 v3.0.5 + github.com/martinlindhe/base36 v1.1.0 github.com/mitchellh/hashstructure/v2 v2.0.1 github.com/natefinch/atomic v0.0.0-20200526193002-18c0533a5b09 github.com/nsf/jsondiff v0.0.0-20200515183724-f29ed568f4ce diff --git a/go.sum b/go.sum index 6c7c4c015..b4a5ff377 100644 --- a/go.sum +++ b/go.sum @@ -380,6 +380,8 @@ github.com/lithammer/shortuuid/v3 v3.0.5/go.mod h1:2QdoCtD4SBzugx2qs3gdR3LXY6Mcx github.com/lyft/protoc-gen-validate v0.0.13/go.mod h1:XbGvPuh87YZc5TdIa2/I4pLk0QoUACkjt2znoq26NVQ= github.com/magiconair/properties v1.8.1 h1:ZC2Vc7/ZFkGmsVC9KvOjumD+G5lXy2RtTKyzRKO2BQ4= github.com/magiconair/properties v1.8.1/go.mod h1:PppfXfuXeibc/6YijjN8zIbojt8czPbwD3XqdrwzmxQ= +github.com/martinlindhe/base36 v1.1.0 h1:cIwvvwYse/0+1CkUPYH5ZvVIYG3JrILmQEIbLuar02Y= +github.com/martinlindhe/base36 v1.1.0/go.mod h1:+AtEs8xrBpCeYgSLoY/aJ6Wf37jtBuR0s35750M27+8= github.com/mattn/go-colorable v0.0.9/go.mod h1:9vuHe8Xs5qXnSaW/c/ABM9alt+Vo+STaOChaDxuIBZU= github.com/mattn/go-colorable v0.1.4/go.mod h1:U0ppj6V5qS13XJ6of8GYAs25YV2eR4EVcfRqFIhoBtE= github.com/mattn/go-isatty v0.0.3/go.mod h1:M+lRXTBqGeGNdLjl/ufCoiOlB5xdOkqRJdNxMWT7Zi4= diff --git a/internal/authclient/authclient.go b/internal/authclient/authclient.go new file mode 100644 index 000000000..d47399963 --- /dev/null +++ b/internal/authclient/authclient.go @@ -0,0 +1,135 @@ +// Package authclient contains a CLI authentication client for Pomerium. +package authclient + +import ( + "context" + "fmt" + "io" + "io/ioutil" + "net" + "net/http" + "net/url" + "time" + + "github.com/skratchdot/open-golang/open" + "golang.org/x/sync/errgroup" +) + +var openBrowser = open.Run + +// An AuthClient retrieves an authentication JWT via the Pomerium login API. +type AuthClient struct { + cfg *config +} + +// New creates a new AuthClient. +func New(options ...Option) *AuthClient { + return &AuthClient{ + cfg: getConfig(options...), + } +} + +// GetJWT retrieves a JWT from Pomerium. +func (client *AuthClient) GetJWT(ctx context.Context, serverURL *url.URL) (rawJWT string, err error) { + li, err := net.Listen("tcp", "127.0.0.1:0") + if err != nil { + return "", fmt.Errorf("failed to start listener: %w", err) + } + defer func() { _ = li.Close() }() + + incomingJWT := make(chan string) + eg, ctx := errgroup.WithContext(ctx) + eg.Go(func() error { + return client.runHTTPServer(ctx, li, incomingJWT) + }) + eg.Go(func() error { + return client.runOpenBrowser(ctx, li, serverURL) + }) + eg.Go(func() error { + select { + case rawJWT = <-incomingJWT: + case <-ctx.Done(): + return ctx.Err() + } + return nil + }) + err = eg.Wait() + if err != nil { + return "", err + } + + return rawJWT, nil +} + +func (client *AuthClient) runHTTPServer(ctx context.Context, li net.Listener, incomingJWT chan string) error { + var srv *http.Server + srv = &http.Server{ + BaseContext: func(li net.Listener) context.Context { + return ctx + }, + Handler: http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + jwt := r.FormValue("pomerium_jwt") + if jwt == "" { + http.Error(w, "not found", http.StatusNotFound) + return + } + incomingJWT <- jwt + + w.Header().Set("Content-Type", "text/plain") + _, _ = io.WriteString(w, "login complete, you may close this page") + + go func() { _ = srv.Shutdown(ctx) }() + }), + } + // shutdown the server when ctx is done. + go func() { + <-ctx.Done() + _ = srv.Shutdown(ctx) + }() + err := srv.Serve(li) + if err == http.ErrServerClosed { + err = nil + } + return err +} + +func (client *AuthClient) runOpenBrowser(ctx context.Context, li net.Listener, serverURL *url.URL) error { + dst := serverURL.ResolveReference(&url.URL{ + Path: "/.pomerium/api/v1/login", + RawQuery: url.Values{ + "pomerium_redirect_uri": {fmt.Sprintf("http://%s", li.Addr().String())}, + }.Encode(), + }) + + ctx, clearTimeout := context.WithTimeout(ctx, 10*time.Second) + defer clearTimeout() + + req, err := http.NewRequestWithContext(ctx, "GET", dst.String(), nil) + if err != nil { + return err + } + + transport := &http.Transport{ + TLSClientConfig: client.cfg.tlsConfig, + } + hc := &http.Client{ + Transport: transport, + } + + res, err := hc.Do(req) + if err != nil { + return fmt.Errorf("failed to get login url: %w", err) + } + defer func() { _ = res.Body.Close() }() + + if res.StatusCode/100 != 2 { + return fmt.Errorf("failed to get login url: %s", res.Status) + } + + bs, err := ioutil.ReadAll(res.Body) + if err != nil { + return fmt.Errorf("failed to read login url: %w", err) + } + + return openBrowser(string(bs)) +} diff --git a/internal/authclient/authclient_test.go b/internal/authclient/authclient_test.go new file mode 100644 index 000000000..e3061aa40 --- /dev/null +++ b/internal/authclient/authclient_test.go @@ -0,0 +1,74 @@ +package authclient + +import ( + "context" + "net" + "net/http" + "net/url" + "testing" + "time" + + "github.com/go-chi/chi" + "github.com/stretchr/testify/assert" +) + +func TestAuthClient(t *testing.T) { + ctx, clearTimeout := context.WithTimeout(context.Background(), time.Second*30) + defer clearTimeout() + + li, err := net.Listen("tcp", "127.0.0.1:0") + if !assert.NoError(t, err) { + return + } + defer func() { _ = li.Close() }() + + go func() { + h := chi.NewMux() + h.Get("/.pomerium/api/v1/login", func(w http.ResponseWriter, r *http.Request) { + _, _ = w.Write([]byte(r.FormValue("pomerium_redirect_uri"))) + }) + srv := &http.Server{ + BaseContext: func(li net.Listener) context.Context { + return ctx + }, + Handler: h, + } + _ = srv.Serve(li) + }() + + origOpenBrowser := openBrowser + defer func() { + openBrowser = origOpenBrowser + }() + openBrowser = func(input string) error { + u, err := url.Parse(input) + if err != nil { + return err + } + u = u.ResolveReference(&url.URL{ + RawQuery: url.Values{ + "pomerium_jwt": {"TEST"}, + }.Encode(), + }) + + req, err := http.NewRequestWithContext(ctx, "GET", u.String(), nil) + if err != nil { + return err + } + + res, err := http.DefaultClient.Do(req) + if err != nil { + return err + } + _ = res.Body.Close() + return nil + } + + ac := New() + rawJWT, err := ac.GetJWT(ctx, &url.URL{ + Scheme: "http", + Host: li.Addr().String(), + }) + assert.NoError(t, err) + assert.Equal(t, "TEST", rawJWT) +} diff --git a/internal/authclient/config.go b/internal/authclient/config.go new file mode 100644 index 000000000..f61d5674f --- /dev/null +++ b/internal/authclient/config.go @@ -0,0 +1,27 @@ +package authclient + +import ( + "crypto/tls" +) + +type config struct { + tlsConfig *tls.Config +} + +func getConfig(options ...Option) *config { + cfg := new(config) + for _, o := range options { + o(cfg) + } + return cfg +} + +// An Option modifies the config. +type Option func(*config) + +// WithTLSConfig returns an option to configure the tls config. +func WithTLSConfig(tlsConfig *tls.Config) Option { + return func(cfg *config) { + cfg.tlsConfig = tlsConfig + } +} diff --git a/internal/cliutil/clitutil.go b/internal/cliutil/clitutil.go new file mode 100644 index 000000000..e59e63316 --- /dev/null +++ b/internal/cliutil/clitutil.go @@ -0,0 +1,2 @@ +// Package cliutil contains functionality related to CLI apps. +package cliutil diff --git a/internal/cliutil/jwtcache.go b/internal/cliutil/jwtcache.go new file mode 100644 index 000000000..89b886220 --- /dev/null +++ b/internal/cliutil/jwtcache.go @@ -0,0 +1,144 @@ +package cliutil + +import ( + "encoding/json" + "errors" + "fmt" + "io/ioutil" + "os" + "path/filepath" + "sync" + "time" + + "github.com/martinlindhe/base36" + "gopkg.in/square/go-jose.v2" + + "github.com/pomerium/pomerium/pkg/cryptutil" +) + +// predefined cache errors +var ( + ErrExpired = errors.New("expired") + ErrInvalid = errors.New("invalid") + ErrNotFound = errors.New("not found") +) + +// A JWTCache loads and stores JWTs. +type JWTCache interface { + LoadJWT(key string) (rawJWT string, err error) + StoreJWT(key string, rawJWT string) error +} + +// A LocalJWTCache stores files in the user's cache directory. +type LocalJWTCache struct { + dir string +} + +// NewLocalJWTCache creates a new LocalJWTCache. +func NewLocalJWTCache() (*LocalJWTCache, error) { + root, err := os.UserCacheDir() + if err != nil { + return nil, err + } + + dir := filepath.Join(root, "pomerium-cli", "jwts") + + err = os.MkdirAll(dir, 0755) + if err != nil { + return nil, fmt.Errorf("error creating user cache directory: %w", err) + } + + return &LocalJWTCache{ + dir: dir, + }, nil +} + +// LoadJWT loads a raw JWT from the local cache. +func (cache *LocalJWTCache) LoadJWT(key string) (rawJWT string, err error) { + path := filepath.Join(cache.dir, cache.fileName(key)) + rawBS, err := ioutil.ReadFile(path) + if os.IsNotExist(err) { + return "", ErrNotFound + } else if err != nil { + return "", err + } + rawJWT = string(rawBS) + + return rawJWT, checkExpiry(rawJWT) +} + +// StoreJWT stores a raw JWT in the local cache. +func (cache *LocalJWTCache) StoreJWT(key string, rawJWT string) error { + path := filepath.Join(cache.dir, cache.fileName(key)) + err := ioutil.WriteFile(path, []byte(rawJWT), 0600) + if err != nil { + return err + } + + return nil +} + +func (cache *LocalJWTCache) hash(str string) string { + h := cryptutil.Hash("LocalJWTCache", []byte(str)) + return base36.EncodeBytes(h) +} + +func (cache *LocalJWTCache) fileName(key string) string { + return cache.hash(key) + ".jwt" +} + +// A MemoryJWTCache stores JWTs in an in-memory map. +type MemoryJWTCache struct { + mu sync.Mutex + entries map[string]string +} + +// NewMemoryJWTCache creates a new in-memory JWT cache. +func NewMemoryJWTCache() *MemoryJWTCache { + return &MemoryJWTCache{entries: make(map[string]string)} +} + +// LoadJWT loads a JWT from the in-memory map. +func (cache *MemoryJWTCache) LoadJWT(key string) (rawJWT string, err error) { + cache.mu.Lock() + defer cache.mu.Unlock() + + rawJWT, ok := cache.entries[key] + if !ok { + return "", ErrNotFound + } + + return rawJWT, checkExpiry(rawJWT) +} + +// StoreJWT stores a JWT in the in-memory map. +func (cache *MemoryJWTCache) StoreJWT(key string, rawJWT string) error { + cache.mu.Lock() + defer cache.mu.Unlock() + + cache.entries[key] = rawJWT + + return nil +} + +func checkExpiry(rawJWT string) error { + tok, err := jose.ParseSigned(rawJWT) + if err != nil { + return ErrInvalid + } + + var claims struct { + Expiry int64 `json:"exp"` + } + err = json.Unmarshal(tok.UnsafePayloadWithoutVerification(), &claims) + if err != nil { + return ErrInvalid + } + + expiresAt := time.Unix(claims.Expiry, 0) + if expiresAt.Before(time.Now()) { + return ErrExpired + } + + return nil +} diff --git a/internal/cliutil/jwtcache_test.go b/internal/cliutil/jwtcache_test.go new file mode 100644 index 000000000..984adbbdd --- /dev/null +++ b/internal/cliutil/jwtcache_test.go @@ -0,0 +1,69 @@ +package cliutil + +import ( + "crypto/rand" + "crypto/rsa" + "fmt" + "os" + "path/filepath" + "testing" + "time" + + "github.com/google/uuid" + "github.com/stretchr/testify/assert" + "gopkg.in/square/go-jose.v2" +) + +func TestLocalJWTCache(t *testing.T) { + c := &LocalJWTCache{ + dir: filepath.Join(os.TempDir(), uuid.New().String()), + } + + err := os.MkdirAll(c.dir, 0755) + if !assert.NoError(t, err) { + return + } + defer func() { _ = os.RemoveAll(c.dir) }() + + t.Run("NotFound", func(t *testing.T) { + _, err := c.LoadJWT("NOTFOUND") + assert.Equal(t, ErrNotFound, err) + }) + t.Run("Invalid", func(t *testing.T) { + err := c.StoreJWT("INVALID", "INVALID") + if !assert.NoError(t, err) { + return + } + _, err = c.LoadJWT("INVALID") + assert.Equal(t, ErrInvalid, err) + }) + t.Run("Expired", func(t *testing.T) { + privateKey, err := rsa.GenerateKey(rand.Reader, 2048) + if !assert.NoError(t, err) { + return + } + + signer, err := jose.NewSigner(jose.SigningKey{Algorithm: jose.PS512, Key: privateKey}, nil) + if !assert.NoError(t, err) { + return + } + + object, err := signer.Sign([]byte(`{"exp": ` + fmt.Sprint(time.Now().Add(-time.Second).Unix()) + `}`)) + if !assert.NoError(t, err) { + return + } + + rawJWT, err := object.CompactSerialize() + if !assert.NoError(t, err) { + return + } + + err = c.StoreJWT("EXPIRED", rawJWT) + if !assert.NoError(t, err) { + return + } + + _, err = c.LoadJWT("EXPIRED") + assert.Equal(t, ErrExpired, err) + }) +} diff --git a/internal/tcptunnel/config.go b/internal/tcptunnel/config.go new file mode 100644 index 000000000..e5a0f7052 --- /dev/null +++ b/internal/tcptunnel/config.go @@ -0,0 +1,60 @@ +package tcptunnel + +import ( + "crypto/tls" + + "github.com/pomerium/pomerium/internal/cliutil" + "github.com/pomerium/pomerium/internal/log" +) + +type config struct { + jwtCache cliutil.JWTCache + dstHost string + proxyHost string + tlsConfig *tls.Config +} + +func getConfig(options ...Option) *config { + cfg := new(config) + if jwtCache, err := cliutil.NewLocalJWTCache(); err == nil { + WithJWTCache(jwtCache)(cfg) + } else { + log.Error().Err(err).Msg("tcptunnel: error creating local JWT cache, using in-memory JWT cache") + WithJWTCache(cliutil.NewMemoryJWTCache())(cfg) + } + for _, o := range options { + o(cfg) + } + return cfg +} + +// An Option modifies the config. +type Option func(*config) + +// WithDestinationHost returns an option to configure the destination host. +func WithDestinationHost(dstHost string) Option { + return func(cfg *config) { + cfg.dstHost = dstHost + } +} + +// WithJWTCache returns an option to configure the jwt cache. +func WithJWTCache(jwtCache cliutil.JWTCache) Option { + return func(cfg *config) { + cfg.jwtCache = jwtCache + } +} + +// WithProxyHost returns an option to configure the proxy host. +func WithProxyHost(proxyHost string) Option { + return func(cfg *config) { + cfg.proxyHost = proxyHost + } +} + +// WithTLSConfig returns an option to configure the tls config. +func WithTLSConfig(tlsConfig *tls.Config) Option { + return func(cfg *config) { + cfg.tlsConfig = tlsConfig + } +} diff --git a/internal/tcptunnel/tcptunnel.go b/internal/tcptunnel/tcptunnel.go new file mode 100644 index 000000000..7bbe9dfd7 --- /dev/null +++ b/internal/tcptunnel/tcptunnel.go @@ -0,0 +1,213 @@ +// Package tcptunnel contains an implementation of a TCP tunnel via HTTP Connect. +package tcptunnel + +import ( + "bufio" + "context" + "crypto/tls" + "errors" + "fmt" + "io" + "net" + "net/http" + "net/url" + "time" + + "github.com/pomerium/pomerium/internal/authclient" + "github.com/pomerium/pomerium/internal/cliutil" + "github.com/pomerium/pomerium/internal/log" + + backoff "github.com/cenkalti/backoff/v4" +) + +// A Tunnel represents a TCP tunnel over HTTP Connect. +type Tunnel struct { + cfg *config + auth *authclient.AuthClient +} + +// New creates a new Tunnel. +func New(options ...Option) *Tunnel { + cfg := getConfig(options...) + return &Tunnel{ + cfg: cfg, + auth: authclient.New(authclient.WithTLSConfig(cfg.tlsConfig)), + } +} + +// RunListener runs a network listener on the given address. For each +// incoming connection a new TCP tunnel is established via Run. +func (tun *Tunnel) RunListener(ctx context.Context, listenerAddress string) error { + li, err := net.Listen("tcp", listenerAddress) + if err != nil { + return err + } + defer func() { _ = li.Close() }() + log.Info().Msg("tcptunnel: listening on " + li.Addr().String()) + + go func() { + <-ctx.Done() + _ = li.Close() + }() + + bo := backoff.NewExponentialBackOff() + bo.MaxElapsedTime = 0 + + for { + conn, err := li.Accept() + if err != nil { + // canceled, so ignore the error and return + if ctx.Err() != nil { + return nil + } + + if nerr, ok := err.(net.Error); ok && nerr.Temporary() { + log.Warn().Err(err).Msg("tcptunnel: temporarily failed to accept local connection") + select { + case <-time.After(bo.NextBackOff()): + case <-ctx.Done(): + return ctx.Err() + } + continue + } + return err + } + bo.Reset() + + go func() { + defer func() { _ = conn.Close() }() + + err := tun.Run(ctx, conn) + if err != nil { + log.Error().Err(err).Msg("tcptunnel: error serving local connection") + } + }() + } +} + +// Run establishes a TCP tunnel via HTTP Connect and forwards all traffic from/to local. +func (tun *Tunnel) Run(ctx context.Context, local io.ReadWriter) error { + rawJWT, err := tun.cfg.jwtCache.LoadJWT(tun.jwtCacheKey()) + switch { + // if there is no error, or it is one of the pre-defined cliutil errors, + // then ignore and use an empty JWT + case err == nil, + errors.Is(err, cliutil.ErrExpired), + errors.Is(err, cliutil.ErrInvalid), + errors.Is(err, cliutil.ErrNotFound): + default: + return fmt.Errorf("tcptunnel: failed to load JWT: %w", err) + } + return tun.run(ctx, local, rawJWT) +} + +func (tun *Tunnel) run(ctx context.Context, local io.ReadWriter, rawJWT string) error { + log.Info(). + Str("dst", tun.cfg.dstHost). + Str("proxy", tun.cfg.proxyHost). + Bool("secure", tun.cfg.tlsConfig != nil). + Msg("tcptunnel: opening connection") + + hdr := http.Header{} + if rawJWT != "" { + hdr.Set("Authorization", "Pomerium "+rawJWT) + } + + req := (&http.Request{ + Method: "CONNECT", + URL: &url.URL{Opaque: tun.cfg.dstHost}, + Host: tun.cfg.dstHost, + Header: hdr, + }).WithContext(ctx) + + var remote net.Conn + var err error + if tun.cfg.tlsConfig != nil { + remote, err = (&tls.Dialer{Config: tun.cfg.tlsConfig}).DialContext(ctx, "tcp", tun.cfg.proxyHost) + } else { + remote, err = (&net.Dialer{}).DialContext(ctx, "tcp", tun.cfg.proxyHost) + } + if err != nil { + return fmt.Errorf("tcptunnel: failed to establish connection to proxy: %w", err) + } + defer func() { + _ = remote.Close() + log.Info().Msg("tcptunnel: connection closed") + }() + if done := ctx.Done(); done != nil { + go func() { + <-done + _ = remote.Close() + }() + } + + err = req.Write(remote) + if err != nil { + return err + } + + br := bufio.NewReader(remote) + res, err := http.ReadResponse(br, req) + if err != nil { + return fmt.Errorf("tcptunnel: failed to read HTTP response: %w", err) + } + defer func() { + _ = res.Body.Close() + }() + switch res.StatusCode { + case http.StatusOK: + case http.StatusMovedPermanently, + http.StatusFound, + http.StatusTemporaryRedirect, + http.StatusPermanentRedirect: + if rawJWT == "" { + _ = remote.Close() + + authURL, err := url.Parse(res.Header.Get("Location")) + if err != nil { + return fmt.Errorf("tcptunnel: invalid redirect location for authentication: %w", err) + } + + rawJWT, err = tun.auth.GetJWT(ctx, authURL) + if err != nil { + return fmt.Errorf("tcptunnel: failed to get authentication JWT: %w", err) + } + + err = tun.cfg.jwtCache.StoreJWT(tun.jwtCacheKey(), rawJWT) + if err != nil { + return fmt.Errorf("tcptunnel: failed to store JWT: %w", err) + } + + return tun.run(ctx, local, rawJWT) + } + fallthrough + default: + return fmt.Errorf("tcptunnel: invalid http response code: %d", res.StatusCode) + } + + log.Info().Msg("tcptunnel: connection established") + + errc := make(chan error, 2) + go func() { + _, err := io.Copy(remote, local) + errc <- err + }() + go func() { + _, err := io.Copy(local, remote) + errc <- err + }() + + select { + case err := <-errc: + if err != nil { + err = fmt.Errorf("tcptunnel: %w", err) + } + return err + case <-ctx.Done(): + return nil + } +} + +func (tun *Tunnel) jwtCacheKey() string { + return fmt.Sprintf("%s|%s|%v", tun.cfg.dstHost, tun.cfg.proxyHost, tun.cfg.tlsConfig != nil) +} diff --git a/internal/tcptunnel/tcptunnel_test.go b/internal/tcptunnel/tcptunnel_test.go new file mode 100644 index 000000000..f4ab461f0 --- /dev/null +++ b/internal/tcptunnel/tcptunnel_test.go @@ -0,0 +1,91 @@ +package tcptunnel + +import ( + "bufio" + "bytes" + "context" + "io" + "net" + "net/http" + "net/http/httptest" + "strings" + "testing" + "time" + + "github.com/stretchr/testify/assert" +) + +func TestTunnel(t *testing.T) { + ctx, clearTimeout := context.WithTimeout(context.Background(), time.Second*10) + defer clearTimeout() + + backend, err := net.Listen("tcp", "127.0.0.1:0") + if !assert.NoError(t, err) { + return + } + defer func() { _ = backend.Close() }() + + go func() { + for { + conn, err := backend.Accept() + if err != nil { + return + } + go func() { + defer func() { _ = conn.Close() }() + + ln, _, _ := bufio.NewReader(conn).ReadLine() + assert.Equal(t, "HELLO WORLD", string(ln)) + }() + } + }() + + srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + if !assert.Equal(t, "CONNECT", r.Method) { + return + } + if !assert.Equal(t, "example.com:9999", r.RequestURI) { + return + } + + w.WriteHeader(200) + + in, _, err := w.(http.Hijacker).Hijack() + if !assert.NoError(t, err) { + return + } + defer func() { _ = in.Close() }() + + out, err := net.Dial("tcp", backend.Addr().String()) + if !assert.NoError(t, err) { + return + } + defer func() { _ = out.Close() }() + + errc := make(chan error, 2) + go func() { + _, err := io.Copy(in, out) + errc <- err + }() + go func() { + _, err := io.Copy(out, in) + errc <- err + }() + <-errc + })) + defer srv.Close() + + var buf bytes.Buffer + tun := New( + WithDestinationHost("example.com:9999"), + WithProxyHost(srv.Listener.Addr().String())) + err = tun.Run(ctx, readWriter{strings.NewReader("HELLO WORLD\n"), &buf}) + if !assert.NoError(t, err) { + return + } +} + +type readWriter struct { + io.Reader + io.Writer +}