TCP client command (#1696)

* add cli commands

* add jwt cache test

* add tcptunnel test

* add stdin/stdout support

* use cryptutil hash function

* doc updates

* fix log timestamp
This commit is contained in:
Caleb Doxsey 2020-12-17 12:37:28 -07:00 committed by GitHub
parent 4fbbf28a16
commit 61ab4e4837
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
12 changed files with 923 additions and 0 deletions

105
cmd/pomerium-cli/tcp.go Normal file
View file

@ -0,0 +1,105 @@
package main
import (
"context"
"crypto/tls"
"fmt"
"io"
"net"
"net/url"
"os"
"os/signal"
"strings"
"syscall"
"github.com/rs/zerolog"
"github.com/spf13/cobra"
"github.com/pomerium/pomerium/internal/log"
"github.com/pomerium/pomerium/internal/tcptunnel"
)
var tcpCmdOptions struct {
listen string
pomeriumURL string
}
func init() {
flags := tcpCmd.Flags()
flags.StringVar(&tcpCmdOptions.listen, "listen", "127.0.0.1:0",
"local address to start a listener on")
flags.StringVar(&tcpCmdOptions.pomeriumURL, "pomerium-url", "",
"the URL of the pomerium server to connect to")
rootCmd.AddCommand(tcpCmd)
}
var tcpCmd = &cobra.Command{
Use: "tcp destination",
Args: cobra.ExactArgs(1),
RunE: func(cmd *cobra.Command, args []string) error {
dstHost := args[0]
dstHostname, _, err := net.SplitHostPort(dstHost)
if err != nil {
return fmt.Errorf("invalid destination: %w", err)
}
pomeriumURL := &url.URL{
Scheme: "https",
Host: net.JoinHostPort(dstHostname, "443"),
}
if tcpCmdOptions.pomeriumURL != "" {
pomeriumURL, err = url.Parse(tcpCmdOptions.pomeriumURL)
if err != nil {
return fmt.Errorf("invalid pomerium URL: %w", err)
}
if !strings.Contains(pomeriumURL.Host, ":") {
if pomeriumURL.Scheme == "https" {
pomeriumURL.Host = net.JoinHostPort(pomeriumURL.Hostname(), "443")
} else {
pomeriumURL.Host = net.JoinHostPort(pomeriumURL.Hostname(), "80")
}
}
}
var tlsConfig *tls.Config
if pomeriumURL.Scheme == "https" {
tlsConfig = new(tls.Config)
}
l := zerolog.New(zerolog.NewConsoleWriter(func(w *zerolog.ConsoleWriter) {
w.Out = os.Stderr
})).With().Timestamp().Logger()
log.SetLogger(&l)
c := make(chan os.Signal, 1)
signal.Notify(c, os.Interrupt, syscall.SIGTERM, syscall.SIGHUP)
ctx, cancel := context.WithCancel(context.Background())
go func() {
<-c
cancel()
}()
tun := tcptunnel.New(
tcptunnel.WithDestinationHost(dstHost),
tcptunnel.WithProxyHost(pomeriumURL.Host),
tcptunnel.WithTLSConfig(tlsConfig),
)
if tcpCmdOptions.listen == "-" {
err = tun.Run(ctx, readWriter{Reader: os.Stdin, Writer: os.Stdout})
} else {
err = tun.RunListener(ctx, tcpCmdOptions.listen)
}
if err != nil {
_, _ = fmt.Fprintf(os.Stderr, "%s\n", err.Error())
os.Exit(1)
}
return nil
},
}
type readWriter struct {
io.Reader
io.Writer
}

1
go.mod
View file

@ -32,6 +32,7 @@ require (
github.com/hashicorp/golang-lru v0.5.4
github.com/kardianos/osext v0.0.0-20190222173326-2bc1f35cddc0 // indirect
github.com/lithammer/shortuuid/v3 v3.0.5
github.com/martinlindhe/base36 v1.1.0
github.com/mitchellh/hashstructure/v2 v2.0.1
github.com/natefinch/atomic v0.0.0-20200526193002-18c0533a5b09
github.com/nsf/jsondiff v0.0.0-20200515183724-f29ed568f4ce

2
go.sum
View file

@ -380,6 +380,8 @@ github.com/lithammer/shortuuid/v3 v3.0.5/go.mod h1:2QdoCtD4SBzugx2qs3gdR3LXY6Mcx
github.com/lyft/protoc-gen-validate v0.0.13/go.mod h1:XbGvPuh87YZc5TdIa2/I4pLk0QoUACkjt2znoq26NVQ=
github.com/magiconair/properties v1.8.1 h1:ZC2Vc7/ZFkGmsVC9KvOjumD+G5lXy2RtTKyzRKO2BQ4=
github.com/magiconair/properties v1.8.1/go.mod h1:PppfXfuXeibc/6YijjN8zIbojt8czPbwD3XqdrwzmxQ=
github.com/martinlindhe/base36 v1.1.0 h1:cIwvvwYse/0+1CkUPYH5ZvVIYG3JrILmQEIbLuar02Y=
github.com/martinlindhe/base36 v1.1.0/go.mod h1:+AtEs8xrBpCeYgSLoY/aJ6Wf37jtBuR0s35750M27+8=
github.com/mattn/go-colorable v0.0.9/go.mod h1:9vuHe8Xs5qXnSaW/c/ABM9alt+Vo+STaOChaDxuIBZU=
github.com/mattn/go-colorable v0.1.4/go.mod h1:U0ppj6V5qS13XJ6of8GYAs25YV2eR4EVcfRqFIhoBtE=
github.com/mattn/go-isatty v0.0.3/go.mod h1:M+lRXTBqGeGNdLjl/ufCoiOlB5xdOkqRJdNxMWT7Zi4=

View file

@ -0,0 +1,135 @@
// Package authclient contains a CLI authentication client for Pomerium.
package authclient
import (
"context"
"fmt"
"io"
"io/ioutil"
"net"
"net/http"
"net/url"
"time"
"github.com/skratchdot/open-golang/open"
"golang.org/x/sync/errgroup"
)
var openBrowser = open.Run
// An AuthClient retrieves an authentication JWT via the Pomerium login API.
type AuthClient struct {
cfg *config
}
// New creates a new AuthClient.
func New(options ...Option) *AuthClient {
return &AuthClient{
cfg: getConfig(options...),
}
}
// GetJWT retrieves a JWT from Pomerium.
func (client *AuthClient) GetJWT(ctx context.Context, serverURL *url.URL) (rawJWT string, err error) {
li, err := net.Listen("tcp", "127.0.0.1:0")
if err != nil {
return "", fmt.Errorf("failed to start listener: %w", err)
}
defer func() { _ = li.Close() }()
incomingJWT := make(chan string)
eg, ctx := errgroup.WithContext(ctx)
eg.Go(func() error {
return client.runHTTPServer(ctx, li, incomingJWT)
})
eg.Go(func() error {
return client.runOpenBrowser(ctx, li, serverURL)
})
eg.Go(func() error {
select {
case rawJWT = <-incomingJWT:
case <-ctx.Done():
return ctx.Err()
}
return nil
})
err = eg.Wait()
if err != nil {
return "", err
}
return rawJWT, nil
}
func (client *AuthClient) runHTTPServer(ctx context.Context, li net.Listener, incomingJWT chan string) error {
var srv *http.Server
srv = &http.Server{
BaseContext: func(li net.Listener) context.Context {
return ctx
},
Handler: http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
jwt := r.FormValue("pomerium_jwt")
if jwt == "" {
http.Error(w, "not found", http.StatusNotFound)
return
}
incomingJWT <- jwt
w.Header().Set("Content-Type", "text/plain")
_, _ = io.WriteString(w, "login complete, you may close this page")
go func() { _ = srv.Shutdown(ctx) }()
}),
}
// shutdown the server when ctx is done.
go func() {
<-ctx.Done()
_ = srv.Shutdown(ctx)
}()
err := srv.Serve(li)
if err == http.ErrServerClosed {
err = nil
}
return err
}
func (client *AuthClient) runOpenBrowser(ctx context.Context, li net.Listener, serverURL *url.URL) error {
dst := serverURL.ResolveReference(&url.URL{
Path: "/.pomerium/api/v1/login",
RawQuery: url.Values{
"pomerium_redirect_uri": {fmt.Sprintf("http://%s", li.Addr().String())},
}.Encode(),
})
ctx, clearTimeout := context.WithTimeout(ctx, 10*time.Second)
defer clearTimeout()
req, err := http.NewRequestWithContext(ctx, "GET", dst.String(), nil)
if err != nil {
return err
}
transport := &http.Transport{
TLSClientConfig: client.cfg.tlsConfig,
}
hc := &http.Client{
Transport: transport,
}
res, err := hc.Do(req)
if err != nil {
return fmt.Errorf("failed to get login url: %w", err)
}
defer func() { _ = res.Body.Close() }()
if res.StatusCode/100 != 2 {
return fmt.Errorf("failed to get login url: %s", res.Status)
}
bs, err := ioutil.ReadAll(res.Body)
if err != nil {
return fmt.Errorf("failed to read login url: %w", err)
}
return openBrowser(string(bs))
}

View file

@ -0,0 +1,74 @@
package authclient
import (
"context"
"net"
"net/http"
"net/url"
"testing"
"time"
"github.com/go-chi/chi"
"github.com/stretchr/testify/assert"
)
func TestAuthClient(t *testing.T) {
ctx, clearTimeout := context.WithTimeout(context.Background(), time.Second*30)
defer clearTimeout()
li, err := net.Listen("tcp", "127.0.0.1:0")
if !assert.NoError(t, err) {
return
}
defer func() { _ = li.Close() }()
go func() {
h := chi.NewMux()
h.Get("/.pomerium/api/v1/login", func(w http.ResponseWriter, r *http.Request) {
_, _ = w.Write([]byte(r.FormValue("pomerium_redirect_uri")))
})
srv := &http.Server{
BaseContext: func(li net.Listener) context.Context {
return ctx
},
Handler: h,
}
_ = srv.Serve(li)
}()
origOpenBrowser := openBrowser
defer func() {
openBrowser = origOpenBrowser
}()
openBrowser = func(input string) error {
u, err := url.Parse(input)
if err != nil {
return err
}
u = u.ResolveReference(&url.URL{
RawQuery: url.Values{
"pomerium_jwt": {"TEST"},
}.Encode(),
})
req, err := http.NewRequestWithContext(ctx, "GET", u.String(), nil)
if err != nil {
return err
}
res, err := http.DefaultClient.Do(req)
if err != nil {
return err
}
_ = res.Body.Close()
return nil
}
ac := New()
rawJWT, err := ac.GetJWT(ctx, &url.URL{
Scheme: "http",
Host: li.Addr().String(),
})
assert.NoError(t, err)
assert.Equal(t, "TEST", rawJWT)
}

View file

@ -0,0 +1,27 @@
package authclient
import (
"crypto/tls"
)
type config struct {
tlsConfig *tls.Config
}
func getConfig(options ...Option) *config {
cfg := new(config)
for _, o := range options {
o(cfg)
}
return cfg
}
// An Option modifies the config.
type Option func(*config)
// WithTLSConfig returns an option to configure the tls config.
func WithTLSConfig(tlsConfig *tls.Config) Option {
return func(cfg *config) {
cfg.tlsConfig = tlsConfig
}
}

View file

@ -0,0 +1,2 @@
// Package cliutil contains functionality related to CLI apps.
package cliutil

View file

@ -0,0 +1,144 @@
package cliutil
import (
"encoding/json"
"errors"
"fmt"
"io/ioutil"
"os"
"path/filepath"
"sync"
"time"
"github.com/martinlindhe/base36"
"gopkg.in/square/go-jose.v2"
"github.com/pomerium/pomerium/pkg/cryptutil"
)
// predefined cache errors
var (
ErrExpired = errors.New("expired")
ErrInvalid = errors.New("invalid")
ErrNotFound = errors.New("not found")
)
// A JWTCache loads and stores JWTs.
type JWTCache interface {
LoadJWT(key string) (rawJWT string, err error)
StoreJWT(key string, rawJWT string) error
}
// A LocalJWTCache stores files in the user's cache directory.
type LocalJWTCache struct {
dir string
}
// NewLocalJWTCache creates a new LocalJWTCache.
func NewLocalJWTCache() (*LocalJWTCache, error) {
root, err := os.UserCacheDir()
if err != nil {
return nil, err
}
dir := filepath.Join(root, "pomerium-cli", "jwts")
err = os.MkdirAll(dir, 0755)
if err != nil {
return nil, fmt.Errorf("error creating user cache directory: %w", err)
}
return &LocalJWTCache{
dir: dir,
}, nil
}
// LoadJWT loads a raw JWT from the local cache.
func (cache *LocalJWTCache) LoadJWT(key string) (rawJWT string, err error) {
path := filepath.Join(cache.dir, cache.fileName(key))
rawBS, err := ioutil.ReadFile(path)
if os.IsNotExist(err) {
return "", ErrNotFound
} else if err != nil {
return "", err
}
rawJWT = string(rawBS)
return rawJWT, checkExpiry(rawJWT)
}
// StoreJWT stores a raw JWT in the local cache.
func (cache *LocalJWTCache) StoreJWT(key string, rawJWT string) error {
path := filepath.Join(cache.dir, cache.fileName(key))
err := ioutil.WriteFile(path, []byte(rawJWT), 0600)
if err != nil {
return err
}
return nil
}
func (cache *LocalJWTCache) hash(str string) string {
h := cryptutil.Hash("LocalJWTCache", []byte(str))
return base36.EncodeBytes(h)
}
func (cache *LocalJWTCache) fileName(key string) string {
return cache.hash(key) + ".jwt"
}
// A MemoryJWTCache stores JWTs in an in-memory map.
type MemoryJWTCache struct {
mu sync.Mutex
entries map[string]string
}
// NewMemoryJWTCache creates a new in-memory JWT cache.
func NewMemoryJWTCache() *MemoryJWTCache {
return &MemoryJWTCache{entries: make(map[string]string)}
}
// LoadJWT loads a JWT from the in-memory map.
func (cache *MemoryJWTCache) LoadJWT(key string) (rawJWT string, err error) {
cache.mu.Lock()
defer cache.mu.Unlock()
rawJWT, ok := cache.entries[key]
if !ok {
return "", ErrNotFound
}
return rawJWT, checkExpiry(rawJWT)
}
// StoreJWT stores a JWT in the in-memory map.
func (cache *MemoryJWTCache) StoreJWT(key string, rawJWT string) error {
cache.mu.Lock()
defer cache.mu.Unlock()
cache.entries[key] = rawJWT
return nil
}
func checkExpiry(rawJWT string) error {
tok, err := jose.ParseSigned(rawJWT)
if err != nil {
return ErrInvalid
}
var claims struct {
Expiry int64 `json:"exp"`
}
err = json.Unmarshal(tok.UnsafePayloadWithoutVerification(), &claims)
if err != nil {
return ErrInvalid
}
expiresAt := time.Unix(claims.Expiry, 0)
if expiresAt.Before(time.Now()) {
return ErrExpired
}
return nil
}

View file

@ -0,0 +1,69 @@
package cliutil
import (
"crypto/rand"
"crypto/rsa"
"fmt"
"os"
"path/filepath"
"testing"
"time"
"github.com/google/uuid"
"github.com/stretchr/testify/assert"
"gopkg.in/square/go-jose.v2"
)
func TestLocalJWTCache(t *testing.T) {
c := &LocalJWTCache{
dir: filepath.Join(os.TempDir(), uuid.New().String()),
}
err := os.MkdirAll(c.dir, 0755)
if !assert.NoError(t, err) {
return
}
defer func() { _ = os.RemoveAll(c.dir) }()
t.Run("NotFound", func(t *testing.T) {
_, err := c.LoadJWT("NOTFOUND")
assert.Equal(t, ErrNotFound, err)
})
t.Run("Invalid", func(t *testing.T) {
err := c.StoreJWT("INVALID", "INVALID")
if !assert.NoError(t, err) {
return
}
_, err = c.LoadJWT("INVALID")
assert.Equal(t, ErrInvalid, err)
})
t.Run("Expired", func(t *testing.T) {
privateKey, err := rsa.GenerateKey(rand.Reader, 2048)
if !assert.NoError(t, err) {
return
}
signer, err := jose.NewSigner(jose.SigningKey{Algorithm: jose.PS512, Key: privateKey}, nil)
if !assert.NoError(t, err) {
return
}
object, err := signer.Sign([]byte(`{"exp": ` + fmt.Sprint(time.Now().Add(-time.Second).Unix()) + `}`))
if !assert.NoError(t, err) {
return
}
rawJWT, err := object.CompactSerialize()
if !assert.NoError(t, err) {
return
}
err = c.StoreJWT("EXPIRED", rawJWT)
if !assert.NoError(t, err) {
return
}
_, err = c.LoadJWT("EXPIRED")
assert.Equal(t, ErrExpired, err)
})
}

View file

@ -0,0 +1,60 @@
package tcptunnel
import (
"crypto/tls"
"github.com/pomerium/pomerium/internal/cliutil"
"github.com/pomerium/pomerium/internal/log"
)
type config struct {
jwtCache cliutil.JWTCache
dstHost string
proxyHost string
tlsConfig *tls.Config
}
func getConfig(options ...Option) *config {
cfg := new(config)
if jwtCache, err := cliutil.NewLocalJWTCache(); err == nil {
WithJWTCache(jwtCache)(cfg)
} else {
log.Error().Err(err).Msg("tcptunnel: error creating local JWT cache, using in-memory JWT cache")
WithJWTCache(cliutil.NewMemoryJWTCache())(cfg)
}
for _, o := range options {
o(cfg)
}
return cfg
}
// An Option modifies the config.
type Option func(*config)
// WithDestinationHost returns an option to configure the destination host.
func WithDestinationHost(dstHost string) Option {
return func(cfg *config) {
cfg.dstHost = dstHost
}
}
// WithJWTCache returns an option to configure the jwt cache.
func WithJWTCache(jwtCache cliutil.JWTCache) Option {
return func(cfg *config) {
cfg.jwtCache = jwtCache
}
}
// WithProxyHost returns an option to configure the proxy host.
func WithProxyHost(proxyHost string) Option {
return func(cfg *config) {
cfg.proxyHost = proxyHost
}
}
// WithTLSConfig returns an option to configure the tls config.
func WithTLSConfig(tlsConfig *tls.Config) Option {
return func(cfg *config) {
cfg.tlsConfig = tlsConfig
}
}

View file

@ -0,0 +1,213 @@
// Package tcptunnel contains an implementation of a TCP tunnel via HTTP Connect.
package tcptunnel
import (
"bufio"
"context"
"crypto/tls"
"errors"
"fmt"
"io"
"net"
"net/http"
"net/url"
"time"
"github.com/pomerium/pomerium/internal/authclient"
"github.com/pomerium/pomerium/internal/cliutil"
"github.com/pomerium/pomerium/internal/log"
backoff "github.com/cenkalti/backoff/v4"
)
// A Tunnel represents a TCP tunnel over HTTP Connect.
type Tunnel struct {
cfg *config
auth *authclient.AuthClient
}
// New creates a new Tunnel.
func New(options ...Option) *Tunnel {
cfg := getConfig(options...)
return &Tunnel{
cfg: cfg,
auth: authclient.New(authclient.WithTLSConfig(cfg.tlsConfig)),
}
}
// RunListener runs a network listener on the given address. For each
// incoming connection a new TCP tunnel is established via Run.
func (tun *Tunnel) RunListener(ctx context.Context, listenerAddress string) error {
li, err := net.Listen("tcp", listenerAddress)
if err != nil {
return err
}
defer func() { _ = li.Close() }()
log.Info().Msg("tcptunnel: listening on " + li.Addr().String())
go func() {
<-ctx.Done()
_ = li.Close()
}()
bo := backoff.NewExponentialBackOff()
bo.MaxElapsedTime = 0
for {
conn, err := li.Accept()
if err != nil {
// canceled, so ignore the error and return
if ctx.Err() != nil {
return nil
}
if nerr, ok := err.(net.Error); ok && nerr.Temporary() {
log.Warn().Err(err).Msg("tcptunnel: temporarily failed to accept local connection")
select {
case <-time.After(bo.NextBackOff()):
case <-ctx.Done():
return ctx.Err()
}
continue
}
return err
}
bo.Reset()
go func() {
defer func() { _ = conn.Close() }()
err := tun.Run(ctx, conn)
if err != nil {
log.Error().Err(err).Msg("tcptunnel: error serving local connection")
}
}()
}
}
// Run establishes a TCP tunnel via HTTP Connect and forwards all traffic from/to local.
func (tun *Tunnel) Run(ctx context.Context, local io.ReadWriter) error {
rawJWT, err := tun.cfg.jwtCache.LoadJWT(tun.jwtCacheKey())
switch {
// if there is no error, or it is one of the pre-defined cliutil errors,
// then ignore and use an empty JWT
case err == nil,
errors.Is(err, cliutil.ErrExpired),
errors.Is(err, cliutil.ErrInvalid),
errors.Is(err, cliutil.ErrNotFound):
default:
return fmt.Errorf("tcptunnel: failed to load JWT: %w", err)
}
return tun.run(ctx, local, rawJWT)
}
func (tun *Tunnel) run(ctx context.Context, local io.ReadWriter, rawJWT string) error {
log.Info().
Str("dst", tun.cfg.dstHost).
Str("proxy", tun.cfg.proxyHost).
Bool("secure", tun.cfg.tlsConfig != nil).
Msg("tcptunnel: opening connection")
hdr := http.Header{}
if rawJWT != "" {
hdr.Set("Authorization", "Pomerium "+rawJWT)
}
req := (&http.Request{
Method: "CONNECT",
URL: &url.URL{Opaque: tun.cfg.dstHost},
Host: tun.cfg.dstHost,
Header: hdr,
}).WithContext(ctx)
var remote net.Conn
var err error
if tun.cfg.tlsConfig != nil {
remote, err = (&tls.Dialer{Config: tun.cfg.tlsConfig}).DialContext(ctx, "tcp", tun.cfg.proxyHost)
} else {
remote, err = (&net.Dialer{}).DialContext(ctx, "tcp", tun.cfg.proxyHost)
}
if err != nil {
return fmt.Errorf("tcptunnel: failed to establish connection to proxy: %w", err)
}
defer func() {
_ = remote.Close()
log.Info().Msg("tcptunnel: connection closed")
}()
if done := ctx.Done(); done != nil {
go func() {
<-done
_ = remote.Close()
}()
}
err = req.Write(remote)
if err != nil {
return err
}
br := bufio.NewReader(remote)
res, err := http.ReadResponse(br, req)
if err != nil {
return fmt.Errorf("tcptunnel: failed to read HTTP response: %w", err)
}
defer func() {
_ = res.Body.Close()
}()
switch res.StatusCode {
case http.StatusOK:
case http.StatusMovedPermanently,
http.StatusFound,
http.StatusTemporaryRedirect,
http.StatusPermanentRedirect:
if rawJWT == "" {
_ = remote.Close()
authURL, err := url.Parse(res.Header.Get("Location"))
if err != nil {
return fmt.Errorf("tcptunnel: invalid redirect location for authentication: %w", err)
}
rawJWT, err = tun.auth.GetJWT(ctx, authURL)
if err != nil {
return fmt.Errorf("tcptunnel: failed to get authentication JWT: %w", err)
}
err = tun.cfg.jwtCache.StoreJWT(tun.jwtCacheKey(), rawJWT)
if err != nil {
return fmt.Errorf("tcptunnel: failed to store JWT: %w", err)
}
return tun.run(ctx, local, rawJWT)
}
fallthrough
default:
return fmt.Errorf("tcptunnel: invalid http response code: %d", res.StatusCode)
}
log.Info().Msg("tcptunnel: connection established")
errc := make(chan error, 2)
go func() {
_, err := io.Copy(remote, local)
errc <- err
}()
go func() {
_, err := io.Copy(local, remote)
errc <- err
}()
select {
case err := <-errc:
if err != nil {
err = fmt.Errorf("tcptunnel: %w", err)
}
return err
case <-ctx.Done():
return nil
}
}
func (tun *Tunnel) jwtCacheKey() string {
return fmt.Sprintf("%s|%s|%v", tun.cfg.dstHost, tun.cfg.proxyHost, tun.cfg.tlsConfig != nil)
}

View file

@ -0,0 +1,91 @@
package tcptunnel
import (
"bufio"
"bytes"
"context"
"io"
"net"
"net/http"
"net/http/httptest"
"strings"
"testing"
"time"
"github.com/stretchr/testify/assert"
)
func TestTunnel(t *testing.T) {
ctx, clearTimeout := context.WithTimeout(context.Background(), time.Second*10)
defer clearTimeout()
backend, err := net.Listen("tcp", "127.0.0.1:0")
if !assert.NoError(t, err) {
return
}
defer func() { _ = backend.Close() }()
go func() {
for {
conn, err := backend.Accept()
if err != nil {
return
}
go func() {
defer func() { _ = conn.Close() }()
ln, _, _ := bufio.NewReader(conn).ReadLine()
assert.Equal(t, "HELLO WORLD", string(ln))
}()
}
}()
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
if !assert.Equal(t, "CONNECT", r.Method) {
return
}
if !assert.Equal(t, "example.com:9999", r.RequestURI) {
return
}
w.WriteHeader(200)
in, _, err := w.(http.Hijacker).Hijack()
if !assert.NoError(t, err) {
return
}
defer func() { _ = in.Close() }()
out, err := net.Dial("tcp", backend.Addr().String())
if !assert.NoError(t, err) {
return
}
defer func() { _ = out.Close() }()
errc := make(chan error, 2)
go func() {
_, err := io.Copy(in, out)
errc <- err
}()
go func() {
_, err := io.Copy(out, in)
errc <- err
}()
<-errc
}))
defer srv.Close()
var buf bytes.Buffer
tun := New(
WithDestinationHost("example.com:9999"),
WithProxyHost(srv.Listener.Addr().String()))
err = tun.Run(ctx, readWriter{strings.NewReader("HELLO WORLD\n"), &buf})
if !assert.NoError(t, err) {
return
}
}
type readWriter struct {
io.Reader
io.Writer
}