mirror of
https://github.com/pomerium/pomerium.git
synced 2025-05-17 11:07:18 +02:00
TCP client command (#1696)
* add cli commands * add jwt cache test * add tcptunnel test * add stdin/stdout support * use cryptutil hash function * doc updates * fix log timestamp
This commit is contained in:
parent
4fbbf28a16
commit
61ab4e4837
12 changed files with 923 additions and 0 deletions
105
cmd/pomerium-cli/tcp.go
Normal file
105
cmd/pomerium-cli/tcp.go
Normal file
|
@ -0,0 +1,105 @@
|
|||
package main
|
||||
|
||||
import (
|
||||
"context"
|
||||
"crypto/tls"
|
||||
"fmt"
|
||||
"io"
|
||||
"net"
|
||||
"net/url"
|
||||
"os"
|
||||
"os/signal"
|
||||
"strings"
|
||||
"syscall"
|
||||
|
||||
"github.com/rs/zerolog"
|
||||
"github.com/spf13/cobra"
|
||||
|
||||
"github.com/pomerium/pomerium/internal/log"
|
||||
"github.com/pomerium/pomerium/internal/tcptunnel"
|
||||
)
|
||||
|
||||
var tcpCmdOptions struct {
|
||||
listen string
|
||||
pomeriumURL string
|
||||
}
|
||||
|
||||
func init() {
|
||||
flags := tcpCmd.Flags()
|
||||
flags.StringVar(&tcpCmdOptions.listen, "listen", "127.0.0.1:0",
|
||||
"local address to start a listener on")
|
||||
flags.StringVar(&tcpCmdOptions.pomeriumURL, "pomerium-url", "",
|
||||
"the URL of the pomerium server to connect to")
|
||||
rootCmd.AddCommand(tcpCmd)
|
||||
}
|
||||
|
||||
var tcpCmd = &cobra.Command{
|
||||
Use: "tcp destination",
|
||||
Args: cobra.ExactArgs(1),
|
||||
RunE: func(cmd *cobra.Command, args []string) error {
|
||||
dstHost := args[0]
|
||||
dstHostname, _, err := net.SplitHostPort(dstHost)
|
||||
if err != nil {
|
||||
return fmt.Errorf("invalid destination: %w", err)
|
||||
}
|
||||
|
||||
pomeriumURL := &url.URL{
|
||||
Scheme: "https",
|
||||
Host: net.JoinHostPort(dstHostname, "443"),
|
||||
}
|
||||
if tcpCmdOptions.pomeriumURL != "" {
|
||||
pomeriumURL, err = url.Parse(tcpCmdOptions.pomeriumURL)
|
||||
if err != nil {
|
||||
return fmt.Errorf("invalid pomerium URL: %w", err)
|
||||
}
|
||||
if !strings.Contains(pomeriumURL.Host, ":") {
|
||||
if pomeriumURL.Scheme == "https" {
|
||||
pomeriumURL.Host = net.JoinHostPort(pomeriumURL.Hostname(), "443")
|
||||
} else {
|
||||
pomeriumURL.Host = net.JoinHostPort(pomeriumURL.Hostname(), "80")
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
var tlsConfig *tls.Config
|
||||
if pomeriumURL.Scheme == "https" {
|
||||
tlsConfig = new(tls.Config)
|
||||
}
|
||||
|
||||
l := zerolog.New(zerolog.NewConsoleWriter(func(w *zerolog.ConsoleWriter) {
|
||||
w.Out = os.Stderr
|
||||
})).With().Timestamp().Logger()
|
||||
log.SetLogger(&l)
|
||||
|
||||
c := make(chan os.Signal, 1)
|
||||
signal.Notify(c, os.Interrupt, syscall.SIGTERM, syscall.SIGHUP)
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
go func() {
|
||||
<-c
|
||||
cancel()
|
||||
}()
|
||||
|
||||
tun := tcptunnel.New(
|
||||
tcptunnel.WithDestinationHost(dstHost),
|
||||
tcptunnel.WithProxyHost(pomeriumURL.Host),
|
||||
tcptunnel.WithTLSConfig(tlsConfig),
|
||||
)
|
||||
|
||||
if tcpCmdOptions.listen == "-" {
|
||||
err = tun.Run(ctx, readWriter{Reader: os.Stdin, Writer: os.Stdout})
|
||||
} else {
|
||||
err = tun.RunListener(ctx, tcpCmdOptions.listen)
|
||||
}
|
||||
if err != nil {
|
||||
_, _ = fmt.Fprintf(os.Stderr, "%s\n", err.Error())
|
||||
os.Exit(1)
|
||||
}
|
||||
|
||||
return nil
|
||||
},
|
||||
}
|
||||
|
||||
type readWriter struct {
|
||||
io.Reader
|
||||
io.Writer
|
||||
}
|
1
go.mod
1
go.mod
|
@ -32,6 +32,7 @@ require (
|
|||
github.com/hashicorp/golang-lru v0.5.4
|
||||
github.com/kardianos/osext v0.0.0-20190222173326-2bc1f35cddc0 // indirect
|
||||
github.com/lithammer/shortuuid/v3 v3.0.5
|
||||
github.com/martinlindhe/base36 v1.1.0
|
||||
github.com/mitchellh/hashstructure/v2 v2.0.1
|
||||
github.com/natefinch/atomic v0.0.0-20200526193002-18c0533a5b09
|
||||
github.com/nsf/jsondiff v0.0.0-20200515183724-f29ed568f4ce
|
||||
|
|
2
go.sum
2
go.sum
|
@ -380,6 +380,8 @@ github.com/lithammer/shortuuid/v3 v3.0.5/go.mod h1:2QdoCtD4SBzugx2qs3gdR3LXY6Mcx
|
|||
github.com/lyft/protoc-gen-validate v0.0.13/go.mod h1:XbGvPuh87YZc5TdIa2/I4pLk0QoUACkjt2znoq26NVQ=
|
||||
github.com/magiconair/properties v1.8.1 h1:ZC2Vc7/ZFkGmsVC9KvOjumD+G5lXy2RtTKyzRKO2BQ4=
|
||||
github.com/magiconair/properties v1.8.1/go.mod h1:PppfXfuXeibc/6YijjN8zIbojt8czPbwD3XqdrwzmxQ=
|
||||
github.com/martinlindhe/base36 v1.1.0 h1:cIwvvwYse/0+1CkUPYH5ZvVIYG3JrILmQEIbLuar02Y=
|
||||
github.com/martinlindhe/base36 v1.1.0/go.mod h1:+AtEs8xrBpCeYgSLoY/aJ6Wf37jtBuR0s35750M27+8=
|
||||
github.com/mattn/go-colorable v0.0.9/go.mod h1:9vuHe8Xs5qXnSaW/c/ABM9alt+Vo+STaOChaDxuIBZU=
|
||||
github.com/mattn/go-colorable v0.1.4/go.mod h1:U0ppj6V5qS13XJ6of8GYAs25YV2eR4EVcfRqFIhoBtE=
|
||||
github.com/mattn/go-isatty v0.0.3/go.mod h1:M+lRXTBqGeGNdLjl/ufCoiOlB5xdOkqRJdNxMWT7Zi4=
|
||||
|
|
135
internal/authclient/authclient.go
Normal file
135
internal/authclient/authclient.go
Normal file
|
@ -0,0 +1,135 @@
|
|||
// Package authclient contains a CLI authentication client for Pomerium.
|
||||
package authclient
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"io"
|
||||
"io/ioutil"
|
||||
"net"
|
||||
"net/http"
|
||||
"net/url"
|
||||
"time"
|
||||
|
||||
"github.com/skratchdot/open-golang/open"
|
||||
"golang.org/x/sync/errgroup"
|
||||
)
|
||||
|
||||
var openBrowser = open.Run
|
||||
|
||||
// An AuthClient retrieves an authentication JWT via the Pomerium login API.
|
||||
type AuthClient struct {
|
||||
cfg *config
|
||||
}
|
||||
|
||||
// New creates a new AuthClient.
|
||||
func New(options ...Option) *AuthClient {
|
||||
return &AuthClient{
|
||||
cfg: getConfig(options...),
|
||||
}
|
||||
}
|
||||
|
||||
// GetJWT retrieves a JWT from Pomerium.
|
||||
func (client *AuthClient) GetJWT(ctx context.Context, serverURL *url.URL) (rawJWT string, err error) {
|
||||
li, err := net.Listen("tcp", "127.0.0.1:0")
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("failed to start listener: %w", err)
|
||||
}
|
||||
defer func() { _ = li.Close() }()
|
||||
|
||||
incomingJWT := make(chan string)
|
||||
eg, ctx := errgroup.WithContext(ctx)
|
||||
eg.Go(func() error {
|
||||
return client.runHTTPServer(ctx, li, incomingJWT)
|
||||
})
|
||||
eg.Go(func() error {
|
||||
return client.runOpenBrowser(ctx, li, serverURL)
|
||||
})
|
||||
eg.Go(func() error {
|
||||
select {
|
||||
case rawJWT = <-incomingJWT:
|
||||
case <-ctx.Done():
|
||||
return ctx.Err()
|
||||
}
|
||||
return nil
|
||||
})
|
||||
err = eg.Wait()
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
|
||||
return rawJWT, nil
|
||||
}
|
||||
|
||||
func (client *AuthClient) runHTTPServer(ctx context.Context, li net.Listener, incomingJWT chan string) error {
|
||||
var srv *http.Server
|
||||
srv = &http.Server{
|
||||
BaseContext: func(li net.Listener) context.Context {
|
||||
return ctx
|
||||
},
|
||||
Handler: http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
jwt := r.FormValue("pomerium_jwt")
|
||||
if jwt == "" {
|
||||
http.Error(w, "not found", http.StatusNotFound)
|
||||
return
|
||||
}
|
||||
incomingJWT <- jwt
|
||||
|
||||
w.Header().Set("Content-Type", "text/plain")
|
||||
_, _ = io.WriteString(w, "login complete, you may close this page")
|
||||
|
||||
go func() { _ = srv.Shutdown(ctx) }()
|
||||
}),
|
||||
}
|
||||
// shutdown the server when ctx is done.
|
||||
go func() {
|
||||
<-ctx.Done()
|
||||
_ = srv.Shutdown(ctx)
|
||||
}()
|
||||
err := srv.Serve(li)
|
||||
if err == http.ErrServerClosed {
|
||||
err = nil
|
||||
}
|
||||
return err
|
||||
}
|
||||
|
||||
func (client *AuthClient) runOpenBrowser(ctx context.Context, li net.Listener, serverURL *url.URL) error {
|
||||
dst := serverURL.ResolveReference(&url.URL{
|
||||
Path: "/.pomerium/api/v1/login",
|
||||
RawQuery: url.Values{
|
||||
"pomerium_redirect_uri": {fmt.Sprintf("http://%s", li.Addr().String())},
|
||||
}.Encode(),
|
||||
})
|
||||
|
||||
ctx, clearTimeout := context.WithTimeout(ctx, 10*time.Second)
|
||||
defer clearTimeout()
|
||||
|
||||
req, err := http.NewRequestWithContext(ctx, "GET", dst.String(), nil)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
transport := &http.Transport{
|
||||
TLSClientConfig: client.cfg.tlsConfig,
|
||||
}
|
||||
hc := &http.Client{
|
||||
Transport: transport,
|
||||
}
|
||||
|
||||
res, err := hc.Do(req)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to get login url: %w", err)
|
||||
}
|
||||
defer func() { _ = res.Body.Close() }()
|
||||
|
||||
if res.StatusCode/100 != 2 {
|
||||
return fmt.Errorf("failed to get login url: %s", res.Status)
|
||||
}
|
||||
|
||||
bs, err := ioutil.ReadAll(res.Body)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to read login url: %w", err)
|
||||
}
|
||||
|
||||
return openBrowser(string(bs))
|
||||
}
|
74
internal/authclient/authclient_test.go
Normal file
74
internal/authclient/authclient_test.go
Normal file
|
@ -0,0 +1,74 @@
|
|||
package authclient
|
||||
|
||||
import (
|
||||
"context"
|
||||
"net"
|
||||
"net/http"
|
||||
"net/url"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/go-chi/chi"
|
||||
"github.com/stretchr/testify/assert"
|
||||
)
|
||||
|
||||
func TestAuthClient(t *testing.T) {
|
||||
ctx, clearTimeout := context.WithTimeout(context.Background(), time.Second*30)
|
||||
defer clearTimeout()
|
||||
|
||||
li, err := net.Listen("tcp", "127.0.0.1:0")
|
||||
if !assert.NoError(t, err) {
|
||||
return
|
||||
}
|
||||
defer func() { _ = li.Close() }()
|
||||
|
||||
go func() {
|
||||
h := chi.NewMux()
|
||||
h.Get("/.pomerium/api/v1/login", func(w http.ResponseWriter, r *http.Request) {
|
||||
_, _ = w.Write([]byte(r.FormValue("pomerium_redirect_uri")))
|
||||
})
|
||||
srv := &http.Server{
|
||||
BaseContext: func(li net.Listener) context.Context {
|
||||
return ctx
|
||||
},
|
||||
Handler: h,
|
||||
}
|
||||
_ = srv.Serve(li)
|
||||
}()
|
||||
|
||||
origOpenBrowser := openBrowser
|
||||
defer func() {
|
||||
openBrowser = origOpenBrowser
|
||||
}()
|
||||
openBrowser = func(input string) error {
|
||||
u, err := url.Parse(input)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
u = u.ResolveReference(&url.URL{
|
||||
RawQuery: url.Values{
|
||||
"pomerium_jwt": {"TEST"},
|
||||
}.Encode(),
|
||||
})
|
||||
|
||||
req, err := http.NewRequestWithContext(ctx, "GET", u.String(), nil)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
res, err := http.DefaultClient.Do(req)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
_ = res.Body.Close()
|
||||
return nil
|
||||
}
|
||||
|
||||
ac := New()
|
||||
rawJWT, err := ac.GetJWT(ctx, &url.URL{
|
||||
Scheme: "http",
|
||||
Host: li.Addr().String(),
|
||||
})
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, "TEST", rawJWT)
|
||||
}
|
27
internal/authclient/config.go
Normal file
27
internal/authclient/config.go
Normal file
|
@ -0,0 +1,27 @@
|
|||
package authclient
|
||||
|
||||
import (
|
||||
"crypto/tls"
|
||||
)
|
||||
|
||||
type config struct {
|
||||
tlsConfig *tls.Config
|
||||
}
|
||||
|
||||
func getConfig(options ...Option) *config {
|
||||
cfg := new(config)
|
||||
for _, o := range options {
|
||||
o(cfg)
|
||||
}
|
||||
return cfg
|
||||
}
|
||||
|
||||
// An Option modifies the config.
|
||||
type Option func(*config)
|
||||
|
||||
// WithTLSConfig returns an option to configure the tls config.
|
||||
func WithTLSConfig(tlsConfig *tls.Config) Option {
|
||||
return func(cfg *config) {
|
||||
cfg.tlsConfig = tlsConfig
|
||||
}
|
||||
}
|
2
internal/cliutil/clitutil.go
Normal file
2
internal/cliutil/clitutil.go
Normal file
|
@ -0,0 +1,2 @@
|
|||
// Package cliutil contains functionality related to CLI apps.
|
||||
package cliutil
|
144
internal/cliutil/jwtcache.go
Normal file
144
internal/cliutil/jwtcache.go
Normal file
|
@ -0,0 +1,144 @@
|
|||
package cliutil
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"fmt"
|
||||
"io/ioutil"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/martinlindhe/base36"
|
||||
"gopkg.in/square/go-jose.v2"
|
||||
|
||||
"github.com/pomerium/pomerium/pkg/cryptutil"
|
||||
)
|
||||
|
||||
// predefined cache errors
|
||||
var (
|
||||
ErrExpired = errors.New("expired")
|
||||
ErrInvalid = errors.New("invalid")
|
||||
ErrNotFound = errors.New("not found")
|
||||
)
|
||||
|
||||
// A JWTCache loads and stores JWTs.
|
||||
type JWTCache interface {
|
||||
LoadJWT(key string) (rawJWT string, err error)
|
||||
StoreJWT(key string, rawJWT string) error
|
||||
}
|
||||
|
||||
// A LocalJWTCache stores files in the user's cache directory.
|
||||
type LocalJWTCache struct {
|
||||
dir string
|
||||
}
|
||||
|
||||
// NewLocalJWTCache creates a new LocalJWTCache.
|
||||
func NewLocalJWTCache() (*LocalJWTCache, error) {
|
||||
root, err := os.UserCacheDir()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
dir := filepath.Join(root, "pomerium-cli", "jwts")
|
||||
|
||||
err = os.MkdirAll(dir, 0755)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("error creating user cache directory: %w", err)
|
||||
}
|
||||
|
||||
return &LocalJWTCache{
|
||||
dir: dir,
|
||||
}, nil
|
||||
}
|
||||
|
||||
// LoadJWT loads a raw JWT from the local cache.
|
||||
func (cache *LocalJWTCache) LoadJWT(key string) (rawJWT string, err error) {
|
||||
path := filepath.Join(cache.dir, cache.fileName(key))
|
||||
rawBS, err := ioutil.ReadFile(path)
|
||||
if os.IsNotExist(err) {
|
||||
return "", ErrNotFound
|
||||
} else if err != nil {
|
||||
return "", err
|
||||
}
|
||||
rawJWT = string(rawBS)
|
||||
|
||||
return rawJWT, checkExpiry(rawJWT)
|
||||
}
|
||||
|
||||
// StoreJWT stores a raw JWT in the local cache.
|
||||
func (cache *LocalJWTCache) StoreJWT(key string, rawJWT string) error {
|
||||
path := filepath.Join(cache.dir, cache.fileName(key))
|
||||
err := ioutil.WriteFile(path, []byte(rawJWT), 0600)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (cache *LocalJWTCache) hash(str string) string {
|
||||
h := cryptutil.Hash("LocalJWTCache", []byte(str))
|
||||
return base36.EncodeBytes(h)
|
||||
}
|
||||
|
||||
func (cache *LocalJWTCache) fileName(key string) string {
|
||||
return cache.hash(key) + ".jwt"
|
||||
}
|
||||
|
||||
// A MemoryJWTCache stores JWTs in an in-memory map.
|
||||
type MemoryJWTCache struct {
|
||||
mu sync.Mutex
|
||||
entries map[string]string
|
||||
}
|
||||
|
||||
// NewMemoryJWTCache creates a new in-memory JWT cache.
|
||||
func NewMemoryJWTCache() *MemoryJWTCache {
|
||||
return &MemoryJWTCache{entries: make(map[string]string)}
|
||||
}
|
||||
|
||||
// LoadJWT loads a JWT from the in-memory map.
|
||||
func (cache *MemoryJWTCache) LoadJWT(key string) (rawJWT string, err error) {
|
||||
cache.mu.Lock()
|
||||
defer cache.mu.Unlock()
|
||||
|
||||
rawJWT, ok := cache.entries[key]
|
||||
if !ok {
|
||||
return "", ErrNotFound
|
||||
}
|
||||
|
||||
return rawJWT, checkExpiry(rawJWT)
|
||||
}
|
||||
|
||||
// StoreJWT stores a JWT in the in-memory map.
|
||||
func (cache *MemoryJWTCache) StoreJWT(key string, rawJWT string) error {
|
||||
cache.mu.Lock()
|
||||
defer cache.mu.Unlock()
|
||||
|
||||
cache.entries[key] = rawJWT
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func checkExpiry(rawJWT string) error {
|
||||
tok, err := jose.ParseSigned(rawJWT)
|
||||
if err != nil {
|
||||
return ErrInvalid
|
||||
}
|
||||
|
||||
var claims struct {
|
||||
Expiry int64 `json:"exp"`
|
||||
}
|
||||
err = json.Unmarshal(tok.UnsafePayloadWithoutVerification(), &claims)
|
||||
if err != nil {
|
||||
return ErrInvalid
|
||||
}
|
||||
|
||||
expiresAt := time.Unix(claims.Expiry, 0)
|
||||
if expiresAt.Before(time.Now()) {
|
||||
return ErrExpired
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
69
internal/cliutil/jwtcache_test.go
Normal file
69
internal/cliutil/jwtcache_test.go
Normal file
|
@ -0,0 +1,69 @@
|
|||
package cliutil
|
||||
|
||||
import (
|
||||
"crypto/rand"
|
||||
"crypto/rsa"
|
||||
"fmt"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/google/uuid"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"gopkg.in/square/go-jose.v2"
|
||||
)
|
||||
|
||||
func TestLocalJWTCache(t *testing.T) {
|
||||
c := &LocalJWTCache{
|
||||
dir: filepath.Join(os.TempDir(), uuid.New().String()),
|
||||
}
|
||||
|
||||
err := os.MkdirAll(c.dir, 0755)
|
||||
if !assert.NoError(t, err) {
|
||||
return
|
||||
}
|
||||
defer func() { _ = os.RemoveAll(c.dir) }()
|
||||
|
||||
t.Run("NotFound", func(t *testing.T) {
|
||||
_, err := c.LoadJWT("NOTFOUND")
|
||||
assert.Equal(t, ErrNotFound, err)
|
||||
})
|
||||
t.Run("Invalid", func(t *testing.T) {
|
||||
err := c.StoreJWT("INVALID", "INVALID")
|
||||
if !assert.NoError(t, err) {
|
||||
return
|
||||
}
|
||||
_, err = c.LoadJWT("INVALID")
|
||||
assert.Equal(t, ErrInvalid, err)
|
||||
})
|
||||
t.Run("Expired", func(t *testing.T) {
|
||||
privateKey, err := rsa.GenerateKey(rand.Reader, 2048)
|
||||
if !assert.NoError(t, err) {
|
||||
return
|
||||
}
|
||||
|
||||
signer, err := jose.NewSigner(jose.SigningKey{Algorithm: jose.PS512, Key: privateKey}, nil)
|
||||
if !assert.NoError(t, err) {
|
||||
return
|
||||
}
|
||||
|
||||
object, err := signer.Sign([]byte(`{"exp": ` + fmt.Sprint(time.Now().Add(-time.Second).Unix()) + `}`))
|
||||
if !assert.NoError(t, err) {
|
||||
return
|
||||
}
|
||||
|
||||
rawJWT, err := object.CompactSerialize()
|
||||
if !assert.NoError(t, err) {
|
||||
return
|
||||
}
|
||||
|
||||
err = c.StoreJWT("EXPIRED", rawJWT)
|
||||
if !assert.NoError(t, err) {
|
||||
return
|
||||
}
|
||||
|
||||
_, err = c.LoadJWT("EXPIRED")
|
||||
assert.Equal(t, ErrExpired, err)
|
||||
})
|
||||
}
|
60
internal/tcptunnel/config.go
Normal file
60
internal/tcptunnel/config.go
Normal file
|
@ -0,0 +1,60 @@
|
|||
package tcptunnel
|
||||
|
||||
import (
|
||||
"crypto/tls"
|
||||
|
||||
"github.com/pomerium/pomerium/internal/cliutil"
|
||||
"github.com/pomerium/pomerium/internal/log"
|
||||
)
|
||||
|
||||
type config struct {
|
||||
jwtCache cliutil.JWTCache
|
||||
dstHost string
|
||||
proxyHost string
|
||||
tlsConfig *tls.Config
|
||||
}
|
||||
|
||||
func getConfig(options ...Option) *config {
|
||||
cfg := new(config)
|
||||
if jwtCache, err := cliutil.NewLocalJWTCache(); err == nil {
|
||||
WithJWTCache(jwtCache)(cfg)
|
||||
} else {
|
||||
log.Error().Err(err).Msg("tcptunnel: error creating local JWT cache, using in-memory JWT cache")
|
||||
WithJWTCache(cliutil.NewMemoryJWTCache())(cfg)
|
||||
}
|
||||
for _, o := range options {
|
||||
o(cfg)
|
||||
}
|
||||
return cfg
|
||||
}
|
||||
|
||||
// An Option modifies the config.
|
||||
type Option func(*config)
|
||||
|
||||
// WithDestinationHost returns an option to configure the destination host.
|
||||
func WithDestinationHost(dstHost string) Option {
|
||||
return func(cfg *config) {
|
||||
cfg.dstHost = dstHost
|
||||
}
|
||||
}
|
||||
|
||||
// WithJWTCache returns an option to configure the jwt cache.
|
||||
func WithJWTCache(jwtCache cliutil.JWTCache) Option {
|
||||
return func(cfg *config) {
|
||||
cfg.jwtCache = jwtCache
|
||||
}
|
||||
}
|
||||
|
||||
// WithProxyHost returns an option to configure the proxy host.
|
||||
func WithProxyHost(proxyHost string) Option {
|
||||
return func(cfg *config) {
|
||||
cfg.proxyHost = proxyHost
|
||||
}
|
||||
}
|
||||
|
||||
// WithTLSConfig returns an option to configure the tls config.
|
||||
func WithTLSConfig(tlsConfig *tls.Config) Option {
|
||||
return func(cfg *config) {
|
||||
cfg.tlsConfig = tlsConfig
|
||||
}
|
||||
}
|
213
internal/tcptunnel/tcptunnel.go
Normal file
213
internal/tcptunnel/tcptunnel.go
Normal file
|
@ -0,0 +1,213 @@
|
|||
// Package tcptunnel contains an implementation of a TCP tunnel via HTTP Connect.
|
||||
package tcptunnel
|
||||
|
||||
import (
|
||||
"bufio"
|
||||
"context"
|
||||
"crypto/tls"
|
||||
"errors"
|
||||
"fmt"
|
||||
"io"
|
||||
"net"
|
||||
"net/http"
|
||||
"net/url"
|
||||
"time"
|
||||
|
||||
"github.com/pomerium/pomerium/internal/authclient"
|
||||
"github.com/pomerium/pomerium/internal/cliutil"
|
||||
"github.com/pomerium/pomerium/internal/log"
|
||||
|
||||
backoff "github.com/cenkalti/backoff/v4"
|
||||
)
|
||||
|
||||
// A Tunnel represents a TCP tunnel over HTTP Connect.
|
||||
type Tunnel struct {
|
||||
cfg *config
|
||||
auth *authclient.AuthClient
|
||||
}
|
||||
|
||||
// New creates a new Tunnel.
|
||||
func New(options ...Option) *Tunnel {
|
||||
cfg := getConfig(options...)
|
||||
return &Tunnel{
|
||||
cfg: cfg,
|
||||
auth: authclient.New(authclient.WithTLSConfig(cfg.tlsConfig)),
|
||||
}
|
||||
}
|
||||
|
||||
// RunListener runs a network listener on the given address. For each
|
||||
// incoming connection a new TCP tunnel is established via Run.
|
||||
func (tun *Tunnel) RunListener(ctx context.Context, listenerAddress string) error {
|
||||
li, err := net.Listen("tcp", listenerAddress)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
defer func() { _ = li.Close() }()
|
||||
log.Info().Msg("tcptunnel: listening on " + li.Addr().String())
|
||||
|
||||
go func() {
|
||||
<-ctx.Done()
|
||||
_ = li.Close()
|
||||
}()
|
||||
|
||||
bo := backoff.NewExponentialBackOff()
|
||||
bo.MaxElapsedTime = 0
|
||||
|
||||
for {
|
||||
conn, err := li.Accept()
|
||||
if err != nil {
|
||||
// canceled, so ignore the error and return
|
||||
if ctx.Err() != nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
if nerr, ok := err.(net.Error); ok && nerr.Temporary() {
|
||||
log.Warn().Err(err).Msg("tcptunnel: temporarily failed to accept local connection")
|
||||
select {
|
||||
case <-time.After(bo.NextBackOff()):
|
||||
case <-ctx.Done():
|
||||
return ctx.Err()
|
||||
}
|
||||
continue
|
||||
}
|
||||
return err
|
||||
}
|
||||
bo.Reset()
|
||||
|
||||
go func() {
|
||||
defer func() { _ = conn.Close() }()
|
||||
|
||||
err := tun.Run(ctx, conn)
|
||||
if err != nil {
|
||||
log.Error().Err(err).Msg("tcptunnel: error serving local connection")
|
||||
}
|
||||
}()
|
||||
}
|
||||
}
|
||||
|
||||
// Run establishes a TCP tunnel via HTTP Connect and forwards all traffic from/to local.
|
||||
func (tun *Tunnel) Run(ctx context.Context, local io.ReadWriter) error {
|
||||
rawJWT, err := tun.cfg.jwtCache.LoadJWT(tun.jwtCacheKey())
|
||||
switch {
|
||||
// if there is no error, or it is one of the pre-defined cliutil errors,
|
||||
// then ignore and use an empty JWT
|
||||
case err == nil,
|
||||
errors.Is(err, cliutil.ErrExpired),
|
||||
errors.Is(err, cliutil.ErrInvalid),
|
||||
errors.Is(err, cliutil.ErrNotFound):
|
||||
default:
|
||||
return fmt.Errorf("tcptunnel: failed to load JWT: %w", err)
|
||||
}
|
||||
return tun.run(ctx, local, rawJWT)
|
||||
}
|
||||
|
||||
func (tun *Tunnel) run(ctx context.Context, local io.ReadWriter, rawJWT string) error {
|
||||
log.Info().
|
||||
Str("dst", tun.cfg.dstHost).
|
||||
Str("proxy", tun.cfg.proxyHost).
|
||||
Bool("secure", tun.cfg.tlsConfig != nil).
|
||||
Msg("tcptunnel: opening connection")
|
||||
|
||||
hdr := http.Header{}
|
||||
if rawJWT != "" {
|
||||
hdr.Set("Authorization", "Pomerium "+rawJWT)
|
||||
}
|
||||
|
||||
req := (&http.Request{
|
||||
Method: "CONNECT",
|
||||
URL: &url.URL{Opaque: tun.cfg.dstHost},
|
||||
Host: tun.cfg.dstHost,
|
||||
Header: hdr,
|
||||
}).WithContext(ctx)
|
||||
|
||||
var remote net.Conn
|
||||
var err error
|
||||
if tun.cfg.tlsConfig != nil {
|
||||
remote, err = (&tls.Dialer{Config: tun.cfg.tlsConfig}).DialContext(ctx, "tcp", tun.cfg.proxyHost)
|
||||
} else {
|
||||
remote, err = (&net.Dialer{}).DialContext(ctx, "tcp", tun.cfg.proxyHost)
|
||||
}
|
||||
if err != nil {
|
||||
return fmt.Errorf("tcptunnel: failed to establish connection to proxy: %w", err)
|
||||
}
|
||||
defer func() {
|
||||
_ = remote.Close()
|
||||
log.Info().Msg("tcptunnel: connection closed")
|
||||
}()
|
||||
if done := ctx.Done(); done != nil {
|
||||
go func() {
|
||||
<-done
|
||||
_ = remote.Close()
|
||||
}()
|
||||
}
|
||||
|
||||
err = req.Write(remote)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
br := bufio.NewReader(remote)
|
||||
res, err := http.ReadResponse(br, req)
|
||||
if err != nil {
|
||||
return fmt.Errorf("tcptunnel: failed to read HTTP response: %w", err)
|
||||
}
|
||||
defer func() {
|
||||
_ = res.Body.Close()
|
||||
}()
|
||||
switch res.StatusCode {
|
||||
case http.StatusOK:
|
||||
case http.StatusMovedPermanently,
|
||||
http.StatusFound,
|
||||
http.StatusTemporaryRedirect,
|
||||
http.StatusPermanentRedirect:
|
||||
if rawJWT == "" {
|
||||
_ = remote.Close()
|
||||
|
||||
authURL, err := url.Parse(res.Header.Get("Location"))
|
||||
if err != nil {
|
||||
return fmt.Errorf("tcptunnel: invalid redirect location for authentication: %w", err)
|
||||
}
|
||||
|
||||
rawJWT, err = tun.auth.GetJWT(ctx, authURL)
|
||||
if err != nil {
|
||||
return fmt.Errorf("tcptunnel: failed to get authentication JWT: %w", err)
|
||||
}
|
||||
|
||||
err = tun.cfg.jwtCache.StoreJWT(tun.jwtCacheKey(), rawJWT)
|
||||
if err != nil {
|
||||
return fmt.Errorf("tcptunnel: failed to store JWT: %w", err)
|
||||
}
|
||||
|
||||
return tun.run(ctx, local, rawJWT)
|
||||
}
|
||||
fallthrough
|
||||
default:
|
||||
return fmt.Errorf("tcptunnel: invalid http response code: %d", res.StatusCode)
|
||||
}
|
||||
|
||||
log.Info().Msg("tcptunnel: connection established")
|
||||
|
||||
errc := make(chan error, 2)
|
||||
go func() {
|
||||
_, err := io.Copy(remote, local)
|
||||
errc <- err
|
||||
}()
|
||||
go func() {
|
||||
_, err := io.Copy(local, remote)
|
||||
errc <- err
|
||||
}()
|
||||
|
||||
select {
|
||||
case err := <-errc:
|
||||
if err != nil {
|
||||
err = fmt.Errorf("tcptunnel: %w", err)
|
||||
}
|
||||
return err
|
||||
case <-ctx.Done():
|
||||
return nil
|
||||
}
|
||||
}
|
||||
|
||||
func (tun *Tunnel) jwtCacheKey() string {
|
||||
return fmt.Sprintf("%s|%s|%v", tun.cfg.dstHost, tun.cfg.proxyHost, tun.cfg.tlsConfig != nil)
|
||||
}
|
91
internal/tcptunnel/tcptunnel_test.go
Normal file
91
internal/tcptunnel/tcptunnel_test.go
Normal file
|
@ -0,0 +1,91 @@
|
|||
package tcptunnel
|
||||
|
||||
import (
|
||||
"bufio"
|
||||
"bytes"
|
||||
"context"
|
||||
"io"
|
||||
"net"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"strings"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
)
|
||||
|
||||
func TestTunnel(t *testing.T) {
|
||||
ctx, clearTimeout := context.WithTimeout(context.Background(), time.Second*10)
|
||||
defer clearTimeout()
|
||||
|
||||
backend, err := net.Listen("tcp", "127.0.0.1:0")
|
||||
if !assert.NoError(t, err) {
|
||||
return
|
||||
}
|
||||
defer func() { _ = backend.Close() }()
|
||||
|
||||
go func() {
|
||||
for {
|
||||
conn, err := backend.Accept()
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
go func() {
|
||||
defer func() { _ = conn.Close() }()
|
||||
|
||||
ln, _, _ := bufio.NewReader(conn).ReadLine()
|
||||
assert.Equal(t, "HELLO WORLD", string(ln))
|
||||
}()
|
||||
}
|
||||
}()
|
||||
|
||||
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
if !assert.Equal(t, "CONNECT", r.Method) {
|
||||
return
|
||||
}
|
||||
if !assert.Equal(t, "example.com:9999", r.RequestURI) {
|
||||
return
|
||||
}
|
||||
|
||||
w.WriteHeader(200)
|
||||
|
||||
in, _, err := w.(http.Hijacker).Hijack()
|
||||
if !assert.NoError(t, err) {
|
||||
return
|
||||
}
|
||||
defer func() { _ = in.Close() }()
|
||||
|
||||
out, err := net.Dial("tcp", backend.Addr().String())
|
||||
if !assert.NoError(t, err) {
|
||||
return
|
||||
}
|
||||
defer func() { _ = out.Close() }()
|
||||
|
||||
errc := make(chan error, 2)
|
||||
go func() {
|
||||
_, err := io.Copy(in, out)
|
||||
errc <- err
|
||||
}()
|
||||
go func() {
|
||||
_, err := io.Copy(out, in)
|
||||
errc <- err
|
||||
}()
|
||||
<-errc
|
||||
}))
|
||||
defer srv.Close()
|
||||
|
||||
var buf bytes.Buffer
|
||||
tun := New(
|
||||
WithDestinationHost("example.com:9999"),
|
||||
WithProxyHost(srv.Listener.Addr().String()))
|
||||
err = tun.Run(ctx, readWriter{strings.NewReader("HELLO WORLD\n"), &buf})
|
||||
if !assert.NoError(t, err) {
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
type readWriter struct {
|
||||
io.Reader
|
||||
io.Writer
|
||||
}
|
Loading…
Add table
Add a link
Reference in a new issue