rm cli code (#2824)

This commit is contained in:
Denis Mishin 2021-12-15 16:25:21 -05:00 committed by GitHub
parent 41877e166b
commit 9466d7ef53
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
19 changed files with 2 additions and 1405 deletions

View file

@ -12,8 +12,3 @@ analyze:
type: go
target: github.com/pomerium/pomerium/cmd/pomerium
path: cmd/pomerium
- name: github.com/pomerium/pomerium/cmd/pomerium-cli
type: go
target: github.com/pomerium/pomerium/cmd/pomerium-cli
path: cmd/pomerium-cli

View file

@ -164,7 +164,6 @@ jobs:
run: |
make build-deps
make build
make build NAME=pomerium-cli
- name: save binary
uses: actions/upload-artifact@v2

View file

@ -1,105 +0,0 @@
package main
import (
"crypto/sha256"
"encoding/hex"
"encoding/json"
"fmt"
"io/fs"
"os"
"path/filepath"
"time"
)
func cachePath() string {
root, err := os.UserCacheDir()
if err != nil {
fatalf("error getting user cache dir: %v", err)
}
return filepath.Join(root, "pomerium-cli", "exec-credential")
}
func cachedCredentialPath(serverURL string) string {
h := sha256.New()
_, _ = h.Write([]byte(serverURL))
id := hex.EncodeToString(h.Sum(nil))
return filepath.Join(cachePath(), id+".json")
}
func clearAllCachedCredentials() {
_ = filepath.Walk(cachePath(), func(p string, fi fs.FileInfo, err error) error {
if err != nil {
return err
}
if fi.IsDir() {
return nil
}
return os.Remove(p)
})
}
func clearCachedCredential(serverURL string) {
fn := cachedCredentialPath(serverURL)
_ = os.Remove(fn)
}
func loadCachedCredential(serverURL string) *ExecCredential {
fn := cachedCredentialPath(serverURL)
f, err := os.Open(fn)
if err != nil {
return nil
}
defer func() { _ = f.Close() }()
var creds ExecCredential
err = json.NewDecoder(f).Decode(&creds)
if err != nil {
_ = os.Remove(fn)
return nil
}
if creds.Status == nil {
_ = os.Remove(fn)
return nil
}
ts := creds.Status.ExpirationTimestamp
if ts.IsZero() || ts.Before(time.Now()) {
_ = os.Remove(fn)
return nil
}
return &creds
}
func saveCachedCredential(serverURL string, creds *ExecCredential) {
fn := cachedCredentialPath(serverURL)
err := os.MkdirAll(filepath.Dir(fn), 0o755)
if err != nil {
fmt.Fprintf(os.Stderr, "failed to create cache directory: %v", err)
return
}
f, err := os.Create(fn)
if err != nil {
fmt.Fprintf(os.Stderr, "failed to create cache file: %v", err)
return
}
err = json.NewEncoder(f).Encode(creds)
if err != nil {
_ = f.Close()
fmt.Fprintf(os.Stderr, "failed to encode credentials to cache file: %v", err)
return
}
err = f.Close()
if err != nil {
fmt.Fprintf(os.Stderr, "failed to close cache file: %v", err)
return
}
}

View file

@ -1,172 +0,0 @@
package main
import (
"context"
"crypto/tls"
"encoding/json"
"fmt"
"net/url"
"os"
"time"
"github.com/go-jose/go-jose/v3"
"github.com/spf13/cobra"
"github.com/pomerium/pomerium/internal/authclient"
)
func init() {
addBrowserFlags(kubernetesExecCredentialCmd)
addTLSFlags(kubernetesExecCredentialCmd)
kubernetesCmd.AddCommand(kubernetesExecCredentialCmd)
kubernetesCmd.AddCommand(kubernetesFlushCredentialsCmd)
rootCmd.AddCommand(kubernetesCmd)
}
var kubernetesCmd = &cobra.Command{
Use: "k8s",
}
var kubernetesFlushCredentialsCmd = &cobra.Command{
Use: "flush-credentials [API Server URL]",
RunE: func(cmd *cobra.Command, args []string) error {
if len(args) == 0 {
clearAllCachedCredentials()
} else {
clearCachedCredential(args[0])
}
return nil
},
}
var kubernetesExecCredentialCmd = &cobra.Command{
Use: "exec-credential",
RunE: func(cmd *cobra.Command, args []string) error {
if len(args) < 1 {
return fmt.Errorf("server url is required")
}
serverURL, err := url.Parse(args[0])
if err != nil {
return fmt.Errorf("invalid server url: %v", err)
}
creds := loadCachedCredential(serverURL.String())
if creds != nil {
printCreds(creds)
return nil
}
var tlsConfig *tls.Config
if serverURL.Scheme == "https" {
tlsConfig = getTLSConfig()
}
ac := authclient.New(
authclient.WithBrowserCommand(browserOptions.command),
authclient.WithTLSConfig(tlsConfig))
rawJWT, err := ac.GetJWT(context.Background(), serverURL)
if err != nil {
fatalf("%s", err)
}
creds, err = parseToken(rawJWT)
if err != nil {
fatalf("%s", err)
}
saveCachedCredential(serverURL.String(), creds)
printCreds(creds)
return nil
},
}
func parseToken(rawjwt string) (*ExecCredential, error) {
tok, err := jose.ParseSigned(rawjwt)
if err != nil {
return nil, err
}
var claims struct {
Expiry int64 `json:"exp"`
}
err = json.Unmarshal(tok.UnsafePayloadWithoutVerification(), &claims)
if err != nil {
return nil, err
}
expiresAt := time.Unix(claims.Expiry, 0)
if expiresAt.IsZero() {
expiresAt = time.Now().Add(time.Hour)
}
return &ExecCredential{
TypeMeta: TypeMeta{
APIVersion: "client.authentication.k8s.io/v1beta1",
Kind: "ExecCredential",
},
Status: &ExecCredentialStatus{
ExpirationTimestamp: expiresAt,
Token: "Pomerium-" + rawjwt,
},
}, nil
}
func printCreds(creds *ExecCredential) {
bs, err := json.Marshal(creds)
if err != nil {
fmt.Fprintf(os.Stderr, "failed to encode credentials: %v\n", err)
}
fmt.Println(string(bs))
}
// TypeMeta describes an individual object in an API response or request
// with strings representing the type of the object and its API schema version.
// Structures that are versioned or persisted should inline TypeMeta.
//
// +k8s:deepcopy-gen=false
type TypeMeta struct {
// Kind is a string value representing the REST resource this object represents.
// Servers may infer this from the endpoint the client submits requests to.
// Cannot be updated.
// In CamelCase.
// More info: https://git.k8s.io/community/contributors/devel/sig-architecture/api-conventions.md#types-kinds
// +optional
Kind string `json:"kind,omitempty" protobuf:"bytes,1,opt,name=kind"`
// APIVersion defines the versioned schema of this representation of an object.
// Servers should convert recognized schemas to the latest internal value, and
// may reject unrecognized values.
// More info: https://git.k8s.io/community/contributors/devel/sig-architecture/api-conventions.md#resources
// +optional
APIVersion string `json:"apiVersion,omitempty" protobuf:"bytes,2,opt,name=apiVersion"`
}
// ExecCredential is used by exec-based plugins to communicate credentials to
// HTTP transports.
type ExecCredential struct {
TypeMeta `json:",inline"`
// Status is filled in by the plugin and holds the credentials that the transport
// should use to contact the API.
// +optional
Status *ExecCredentialStatus `json:"status,omitempty"`
}
// ExecCredentialStatus holds credentials for the transport to use.
//
// Token and ClientKeyData are sensitive fields. This data should only be
// transmitted in-memory between client and exec plugin process. Exec plugin
// itself should at least be protected via file permissions.
type ExecCredentialStatus struct {
// ExpirationTimestamp indicates a time when the provided credentials expire.
// +optional
ExpirationTimestamp time.Time `json:"expirationTimestamp,omitempty"`
// Token is a bearer token used by the client for request authentication.
Token string `json:"token,omitempty"`
// PEM-encoded client TLS certificates (including intermediates, if any).
ClientCertificateData string `json:"clientCertificateData,omitempty"`
// PEM-encoded private key for the above certificate.
ClientKeyData string `json:"clientKeyData,omitempty"`
}

View file

@ -1,69 +0,0 @@
// Package main implements the pomerium-cli.
package main
import (
"crypto/tls"
"fmt"
"os"
"github.com/spf13/cobra"
"github.com/pomerium/pomerium/pkg/cryptutil"
)
var rootCmd = &cobra.Command{
Use: "pomerium-cli",
}
func main() {
err := rootCmd.Execute()
if err != nil {
fatalf("%s", err.Error())
}
}
func fatalf(msg string, args ...interface{}) {
fmt.Fprintf(os.Stderr, msg+"\n", args...)
os.Exit(1)
}
var tlsOptions struct {
disableTLSVerification bool
alternateCAPath string
caCert string
}
func addTLSFlags(cmd *cobra.Command) {
flags := cmd.Flags()
flags.BoolVar(&tlsOptions.disableTLSVerification, "disable-tls-verification", false,
"disables TLS verification")
flags.StringVar(&tlsOptions.alternateCAPath, "alternate-ca-path", "",
"path to CA certificate to use for HTTP requests")
flags.StringVar(&tlsOptions.caCert, "ca-cert", "",
"base64-encoded CA TLS certificate to use for HTTP requests")
}
func getTLSConfig() *tls.Config {
cfg := new(tls.Config)
if tlsOptions.disableTLSVerification {
cfg.InsecureSkipVerify = true
}
if tlsOptions.caCert != "" {
var err error
cfg.RootCAs, err = cryptutil.GetCertPool(tlsOptions.caCert, tlsOptions.alternateCAPath)
if err != nil {
fatalf("%s", err)
}
}
return cfg
}
var browserOptions struct {
command string
}
func addBrowserFlags(cmd *cobra.Command) {
flags := cmd.Flags()
flags.StringVar(&browserOptions.command, "browser-cmd", "",
"custom browser command to run when opening a URL")
}

View file

@ -1,113 +0,0 @@
package main
import (
"context"
"crypto/tls"
"fmt"
"io"
"net"
"net/url"
"os"
"os/signal"
"strings"
"syscall"
"time"
"github.com/rs/zerolog"
"github.com/spf13/cobra"
"golang.org/x/term"
"github.com/pomerium/pomerium/internal/log"
"github.com/pomerium/pomerium/internal/tcptunnel"
)
var tcpCmdOptions struct {
listen string
pomeriumURL string
}
func init() {
addTLSFlags(tcpCmd)
flags := tcpCmd.Flags()
flags.StringVar(&tcpCmdOptions.listen, "listen", "127.0.0.1:0",
"local address to start a listener on")
flags.StringVar(&tcpCmdOptions.pomeriumURL, "pomerium-url", "",
"the URL of the pomerium server to connect to")
rootCmd.AddCommand(tcpCmd)
}
var tcpCmd = &cobra.Command{
Use: "tcp destination",
Args: cobra.ExactArgs(1),
RunE: func(cmd *cobra.Command, args []string) error {
dstHost := args[0]
dstHostname, _, err := net.SplitHostPort(dstHost)
if err != nil {
return fmt.Errorf("invalid destination: %w", err)
}
pomeriumURL := &url.URL{
Scheme: "https",
Host: net.JoinHostPort(dstHostname, "443"),
}
if tcpCmdOptions.pomeriumURL != "" {
pomeriumURL, err = url.Parse(tcpCmdOptions.pomeriumURL)
if err != nil {
return fmt.Errorf("invalid pomerium URL: %w", err)
}
if !strings.Contains(pomeriumURL.Host, ":") {
if pomeriumURL.Scheme == "https" {
pomeriumURL.Host = net.JoinHostPort(pomeriumURL.Hostname(), "443")
} else {
pomeriumURL.Host = net.JoinHostPort(pomeriumURL.Hostname(), "80")
}
}
}
var tlsConfig *tls.Config
if pomeriumURL.Scheme == "https" {
tlsConfig = getTLSConfig()
}
l := zerolog.New(zerolog.NewConsoleWriter(func(w *zerolog.ConsoleWriter) {
w.Out = os.Stderr
w.TimeFormat = time.RFC3339
if !term.IsTerminal(int(os.Stdin.Fd())) {
w.NoColor = !term.IsTerminal(int(os.Stdin.Fd()))
}
})).With().Timestamp().Logger()
log.SetLogger(&l)
c := make(chan os.Signal, 1)
signal.Notify(c, os.Interrupt, syscall.SIGTERM, syscall.SIGHUP)
ctx, cancel := context.WithCancel(context.Background())
go func() {
<-c
cancel()
}()
tun := tcptunnel.New(
tcptunnel.WithBrowserCommand(browserOptions.command),
tcptunnel.WithDestinationHost(dstHost),
tcptunnel.WithProxyHost(pomeriumURL.Host),
tcptunnel.WithTLSConfig(tlsConfig),
)
if tcpCmdOptions.listen == "-" {
err = tun.Run(ctx, readWriter{Reader: os.Stdin, Writer: os.Stdout})
} else {
err = tun.RunListener(ctx, tcpCmdOptions.listen)
}
if err != nil {
_, _ = fmt.Fprintf(os.Stderr, "%s\n", err.Error())
os.Exit(1)
}
return nil
},
}
type readWriter struct {
io.Reader
io.Writer
}

View file

@ -1,22 +0,0 @@
package main
import (
"fmt"
"github.com/spf13/cobra"
"github.com/pomerium/pomerium/internal/version"
)
func init() {
rootCmd.AddCommand(versionCmd)
}
var versionCmd = &cobra.Command{
Use: "version",
Short: "version",
Long: `Print the cli version.`,
Run: func(cmd *cobra.Command, args []string) {
fmt.Println("pomerium:", version.FullVersion())
},
}

4
go.mod
View file

@ -52,8 +52,7 @@ require (
github.com/rs/zerolog v1.26.0
github.com/scylladb/go-set v1.0.2
github.com/shirou/gopsutil/v3 v3.21.11
github.com/skratchdot/open-golang v0.0.0-20200116055534-eef842397966
github.com/spf13/cobra v1.2.1
github.com/spf13/cobra v1.2.1 // indirect
github.com/spf13/viper v1.10.0
github.com/stretchr/testify v1.7.0
github.com/tniswong/go.rfcx v0.0.0-20181019234604-07783c52761f
@ -67,7 +66,6 @@ require (
golang.org/x/net v0.0.0-20211111083644-e5c967477495
golang.org/x/oauth2 v0.0.0-20211104180415-d3ed0bb246c8
golang.org/x/sync v0.0.0-20210220032951-036812b2e83c
golang.org/x/term v0.0.0-20210615171337-6886f2dfbf5b
google.golang.org/api v0.62.0
google.golang.org/genproto v0.0.0-20211208223120-3a66f561d7aa
google.golang.org/grpc v1.42.0

4
go.sum
View file

@ -1236,8 +1236,6 @@ github.com/sirupsen/logrus v1.8.1 h1:dJKuHgqk1NNQlqoA6BTlM1Wf9DOH3NBjQyu0h9+AZZE
github.com/sirupsen/logrus v1.8.1/go.mod h1:yWOB1SBYBC5VeMP7gHvWumXLIWorT60ONWic61uBYv0=
github.com/sivchari/tenv v1.4.7 h1:FdTpgRlTue5eb5nXIYgS/lyVXSjugU8UUVDwhP1NLU8=
github.com/sivchari/tenv v1.4.7/go.mod h1:5nF+bITvkebQVanjU6IuMbvIot/7ReNsUV7I5NbprB0=
github.com/skratchdot/open-golang v0.0.0-20200116055534-eef842397966 h1:JIAuq3EEf9cgbU6AtGPK4CTG3Zf6CKMNqf0MHTggAUA=
github.com/skratchdot/open-golang v0.0.0-20200116055534-eef842397966/go.mod h1:sUM3LWHvSMaG192sy56D9F7CNvL7jUJVXoqM1QKLnog=
github.com/smartystreets/assertions v0.0.0-20180927180507-b2de0cb4f26d/go.mod h1:OnSkiWE9lh6wB0YB77sQom3nweQdgAjqCqsofrRNTgc=
github.com/smartystreets/goconvey v0.0.0-20190330032615-68dc04aab96a/go.mod h1:syvi0/a8iFYH4r/RixwvyeAJjdLS9QV7WQ/tjFTllLA=
github.com/smartystreets/goconvey v1.6.4/go.mod h1:syvi0/a8iFYH4r/RixwvyeAJjdLS9QV7WQ/tjFTllLA=
@ -1728,8 +1726,6 @@ golang.org/x/sys v0.0.0-20211205182925-97ca703d548d h1:FjkYO/PPp4Wi0EAUOVLxePm7q
golang.org/x/sys v0.0.0-20211205182925-97ca703d548d/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/term v0.0.0-20201117132131-f5c789dd3221/go.mod h1:Nr5EML6q2oocZ2LXRh80K7BxOlk5/8JxuGnuhpl+muw=
golang.org/x/term v0.0.0-20201126162022-7de9c90e9dd1/go.mod h1:bj7SfCRtBDWHUb9snDiAeCFNEtKQo2Wmx5Cou7ajbmo=
golang.org/x/term v0.0.0-20210615171337-6886f2dfbf5b h1:9zKuko04nR4gjZ4+DNjHqRlAJqbJETHwiNKDqTfOjfE=
golang.org/x/term v0.0.0-20210615171337-6886f2dfbf5b/go.mod h1:jbD1KX2456YbFQfuXm/mYQcufACuNUgVhRMnK/tPxf8=
golang.org/x/text v0.0.0-20170915032832-14c0d48ead0c/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ=
golang.org/x/text v0.3.0/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ=
golang.org/x/text v0.3.1-0.20180807135948-17ff2d5776d2/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ=

View file

@ -1,149 +0,0 @@
// Package authclient contains a CLI authentication client for Pomerium.
package authclient
import (
"context"
"fmt"
"io"
"io/ioutil"
"net"
"net/http"
"net/url"
"os"
"time"
"golang.org/x/sync/errgroup"
)
// An AuthClient retrieves an authentication JWT via the Pomerium login API.
type AuthClient struct {
cfg *config
}
// New creates a new AuthClient.
func New(options ...Option) *AuthClient {
return &AuthClient{
cfg: getConfig(options...),
}
}
// GetJWT retrieves a JWT from Pomerium.
func (client *AuthClient) GetJWT(ctx context.Context, serverURL *url.URL) (rawJWT string, err error) {
li, err := net.Listen("tcp", "127.0.0.1:0")
if err != nil {
return "", fmt.Errorf("failed to start listener: %w", err)
}
defer func() { _ = li.Close() }()
incomingJWT := make(chan string)
eg, ctx := errgroup.WithContext(ctx)
eg.Go(func() error {
return client.runHTTPServer(ctx, li, incomingJWT)
})
eg.Go(func() error {
return client.runOpenBrowser(ctx, li, serverURL)
})
eg.Go(func() error {
select {
case rawJWT = <-incomingJWT:
case <-ctx.Done():
return ctx.Err()
}
return nil
})
err = eg.Wait()
if err != nil {
return "", err
}
return rawJWT, nil
}
func (client *AuthClient) runHTTPServer(ctx context.Context, li net.Listener, incomingJWT chan string) error {
var srv *http.Server
srv = &http.Server{
BaseContext: func(li net.Listener) context.Context {
return ctx
},
Handler: http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
jwt := r.FormValue("pomerium_jwt")
if jwt == "" {
http.Error(w, "not found", http.StatusNotFound)
return
}
incomingJWT <- jwt
w.Header().Set("Content-Type", "text/plain")
_, _ = io.WriteString(w, "login complete, you may close this page")
go func() { _ = srv.Shutdown(ctx) }()
}),
}
// shutdown the server when ctx is done.
go func() {
<-ctx.Done()
_ = srv.Shutdown(ctx)
}()
err := srv.Serve(li)
if err == http.ErrServerClosed {
err = nil
}
return err
}
func (client *AuthClient) runOpenBrowser(ctx context.Context, li net.Listener, serverURL *url.URL) error {
browserURL := new(url.URL)
*browserURL = *serverURL
// remove unnecessary ports to avoid HMAC error
if browserURL.Scheme == "http" && browserURL.Host == browserURL.Hostname()+":80" {
browserURL.Host = browserURL.Hostname()
} else if browserURL.Scheme == "https" && browserURL.Host == browserURL.Hostname()+":443" {
browserURL.Host = browserURL.Hostname()
}
dst := browserURL.ResolveReference(&url.URL{
Path: "/.pomerium/api/v1/login",
RawQuery: url.Values{
"pomerium_redirect_uri": {fmt.Sprintf("http://%s", li.Addr().String())},
}.Encode(),
})
ctx, clearTimeout := context.WithTimeout(ctx, 10*time.Second)
defer clearTimeout()
req, err := http.NewRequestWithContext(ctx, "GET", dst.String(), nil)
if err != nil {
return err
}
transport := http.DefaultTransport.(*http.Transport).Clone()
transport.TLSClientConfig = client.cfg.tlsConfig
hc := &http.Client{
Transport: transport,
}
res, err := hc.Do(req)
if err != nil {
return fmt.Errorf("failed to get login url: %w", err)
}
defer func() { _ = res.Body.Close() }()
if res.StatusCode/100 != 2 {
return fmt.Errorf("failed to get login url: %s", res.Status)
}
bs, err := ioutil.ReadAll(res.Body)
if err != nil {
return fmt.Errorf("failed to read login url: %w", err)
}
err = client.cfg.open(string(bs))
if err != nil {
return fmt.Errorf("failed to open browser url: %w", err)
}
_, _ = fmt.Fprintf(os.Stderr, "Your browser has been opened to visit:\n\n%s\n\n", string(bs))
return nil
}

View file

@ -1,70 +0,0 @@
package authclient
import (
"context"
"net"
"net/http"
"net/url"
"testing"
"time"
"github.com/go-chi/chi"
"github.com/stretchr/testify/assert"
)
func TestAuthClient(t *testing.T) {
ctx, clearTimeout := context.WithTimeout(context.Background(), time.Second*30)
defer clearTimeout()
li, err := net.Listen("tcp", "127.0.0.1:0")
if !assert.NoError(t, err) {
return
}
defer func() { _ = li.Close() }()
go func() {
h := chi.NewMux()
h.Get("/.pomerium/api/v1/login", func(w http.ResponseWriter, r *http.Request) {
_, _ = w.Write([]byte(r.FormValue("pomerium_redirect_uri")))
})
srv := &http.Server{
BaseContext: func(li net.Listener) context.Context {
return ctx
},
Handler: h,
}
_ = srv.Serve(li)
}()
ac := New()
ac.cfg.open = func(input string) error {
u, err := url.Parse(input)
if err != nil {
return err
}
u = u.ResolveReference(&url.URL{
RawQuery: url.Values{
"pomerium_jwt": {"TEST"},
}.Encode(),
})
req, err := http.NewRequestWithContext(ctx, "GET", u.String(), nil)
if err != nil {
return err
}
res, err := http.DefaultClient.Do(req)
if err != nil {
return err
}
_ = res.Body.Close()
return nil
}
rawJWT, err := ac.GetJWT(ctx, &url.URL{
Scheme: "http",
Host: li.Addr().String(),
})
assert.NoError(t, err)
assert.Equal(t, "TEST", rawJWT)
}

View file

@ -1,44 +0,0 @@
package authclient
import (
"crypto/tls"
"github.com/skratchdot/open-golang/open"
)
type config struct {
open func(rawURL string) error
tlsConfig *tls.Config
}
func getConfig(options ...Option) *config {
cfg := new(config)
WithBrowserCommand("")(cfg)
for _, o := range options {
o(cfg)
}
return cfg
}
// An Option modifies the config.
type Option func(*config)
// WithBrowserCommand returns an option to configure the browser command.
func WithBrowserCommand(browserCommand string) Option {
return func(cfg *config) {
if browserCommand == "" {
cfg.open = open.Run
} else {
cfg.open = func(rawURL string) error {
return open.RunWith(rawURL, browserCommand)
}
}
}
}
// WithTLSConfig returns an option to configure the tls config.
func WithTLSConfig(tlsConfig *tls.Config) Option {
return func(cfg *config) {
cfg.tlsConfig = tlsConfig.Clone()
}
}

View file

@ -1,2 +0,0 @@
// Package cliutil contains functionality related to CLI apps.
package cliutil

View file

@ -1,164 +0,0 @@
package cliutil
import (
"encoding/json"
"errors"
"fmt"
"io/ioutil"
"os"
"path/filepath"
"sync"
"time"
"github.com/go-jose/go-jose/v3"
"github.com/martinlindhe/base36"
"github.com/pomerium/pomerium/pkg/cryptutil"
)
// predefined cache errors
var (
ErrExpired = errors.New("expired")
ErrInvalid = errors.New("invalid")
ErrNotFound = errors.New("not found")
)
// A JWTCache loads and stores JWTs.
type JWTCache interface {
DeleteJWT(key string) error
LoadJWT(key string) (rawJWT string, err error)
StoreJWT(key string, rawJWT string) error
}
// A LocalJWTCache stores files in the user's cache directory.
type LocalJWTCache struct {
dir string
}
// NewLocalJWTCache creates a new LocalJWTCache.
func NewLocalJWTCache() (*LocalJWTCache, error) {
root, err := os.UserCacheDir()
if err != nil {
return nil, err
}
dir := filepath.Join(root, "pomerium-cli", "jwts")
err = os.MkdirAll(dir, 0o755)
if err != nil {
return nil, fmt.Errorf("error creating user cache directory: %w", err)
}
return &LocalJWTCache{
dir: dir,
}, nil
}
// DeleteJWT deletes a raw JWT from the local cache.
func (cache *LocalJWTCache) DeleteJWT(key string) error {
path := filepath.Join(cache.dir, cache.fileName(key))
err := os.Remove(path)
if os.IsNotExist(err) {
err = nil
}
return err
}
// LoadJWT loads a raw JWT from the local cache.
func (cache *LocalJWTCache) LoadJWT(key string) (rawJWT string, err error) {
path := filepath.Join(cache.dir, cache.fileName(key))
rawBS, err := ioutil.ReadFile(path)
if os.IsNotExist(err) {
return "", ErrNotFound
} else if err != nil {
return "", err
}
rawJWT = string(rawBS)
return rawJWT, checkExpiry(rawJWT)
}
// StoreJWT stores a raw JWT in the local cache.
func (cache *LocalJWTCache) StoreJWT(key string, rawJWT string) error {
path := filepath.Join(cache.dir, cache.fileName(key))
err := ioutil.WriteFile(path, []byte(rawJWT), 0o600)
if err != nil {
return err
}
return nil
}
func (cache *LocalJWTCache) hash(str string) string {
h := cryptutil.Hash("LocalJWTCache", []byte(str))
return base36.EncodeBytes(h)
}
func (cache *LocalJWTCache) fileName(key string) string {
return cache.hash(key) + ".jwt"
}
// A MemoryJWTCache stores JWTs in an in-memory map.
type MemoryJWTCache struct {
mu sync.Mutex
entries map[string]string
}
// NewMemoryJWTCache creates a new in-memory JWT cache.
func NewMemoryJWTCache() *MemoryJWTCache {
return &MemoryJWTCache{entries: make(map[string]string)}
}
// DeleteJWT deletes a JWT from the in-memory map.
func (cache *MemoryJWTCache) DeleteJWT(key string) error {
cache.mu.Lock()
defer cache.mu.Unlock()
delete(cache.entries, key)
return nil
}
// LoadJWT loads a JWT from the in-memory map.
func (cache *MemoryJWTCache) LoadJWT(key string) (rawJWT string, err error) {
cache.mu.Lock()
defer cache.mu.Unlock()
rawJWT, ok := cache.entries[key]
if !ok {
return "", ErrNotFound
}
return rawJWT, checkExpiry(rawJWT)
}
// StoreJWT stores a JWT in the in-memory map.
func (cache *MemoryJWTCache) StoreJWT(key string, rawJWT string) error {
cache.mu.Lock()
defer cache.mu.Unlock()
cache.entries[key] = rawJWT
return nil
}
func checkExpiry(rawJWT string) error {
tok, err := jose.ParseSigned(rawJWT)
if err != nil {
return ErrInvalid
}
var claims struct {
Expiry int64 `json:"exp"`
}
err = json.Unmarshal(tok.UnsafePayloadWithoutVerification(), &claims)
if err != nil {
return ErrInvalid
}
expiresAt := time.Unix(claims.Expiry, 0)
if expiresAt.Before(time.Now()) {
return ErrExpired
}
return nil
}

View file

@ -1,69 +0,0 @@
package cliutil
import (
"crypto/rand"
"crypto/rsa"
"fmt"
"os"
"path/filepath"
"testing"
"time"
"github.com/go-jose/go-jose/v3"
"github.com/google/uuid"
"github.com/stretchr/testify/assert"
)
func TestLocalJWTCache(t *testing.T) {
c := &LocalJWTCache{
dir: filepath.Join(os.TempDir(), uuid.New().String()),
}
err := os.MkdirAll(c.dir, 0o755)
if !assert.NoError(t, err) {
return
}
defer func() { _ = os.RemoveAll(c.dir) }()
t.Run("NotFound", func(t *testing.T) {
_, err := c.LoadJWT("NOTFOUND")
assert.Equal(t, ErrNotFound, err)
})
t.Run("Invalid", func(t *testing.T) {
err := c.StoreJWT("INVALID", "INVALID")
if !assert.NoError(t, err) {
return
}
_, err = c.LoadJWT("INVALID")
assert.Equal(t, ErrInvalid, err)
})
t.Run("Expired", func(t *testing.T) {
privateKey, err := rsa.GenerateKey(rand.Reader, 2048)
if !assert.NoError(t, err) {
return
}
signer, err := jose.NewSigner(jose.SigningKey{Algorithm: jose.PS512, Key: privateKey}, nil)
if !assert.NoError(t, err) {
return
}
object, err := signer.Sign([]byte(`{"exp": ` + fmt.Sprint(time.Now().Add(-time.Second).Unix()) + `}`))
if !assert.NoError(t, err) {
return
}
rawJWT, err := object.CompactSerialize()
if !assert.NoError(t, err) {
return
}
err = c.StoreJWT("EXPIRED", rawJWT)
if !assert.NoError(t, err) {
return
}
_, err = c.LoadJWT("EXPIRED")
assert.Equal(t, ErrExpired, err)
})
}

View file

@ -1,73 +0,0 @@
package tcptunnel
import (
"context"
"crypto/tls"
"github.com/pomerium/pomerium/internal/cliutil"
"github.com/pomerium/pomerium/internal/log"
)
type config struct {
jwtCache cliutil.JWTCache
dstHost string
proxyHost string
tlsConfig *tls.Config
browserConfig string
}
func getConfig(options ...Option) *config {
cfg := new(config)
if jwtCache, err := cliutil.NewLocalJWTCache(); err == nil {
WithJWTCache(jwtCache)(cfg)
} else {
log.Error(context.TODO()).Err(err).Msg("tcptunnel: error creating local JWT cache, using in-memory JWT cache")
WithJWTCache(cliutil.NewMemoryJWTCache())(cfg)
}
for _, o := range options {
o(cfg)
}
return cfg
}
// An Option modifies the config.
type Option func(*config)
// WithBrowserCommand returns an option to configure the browser command.
func WithBrowserCommand(browserCommand string) Option {
return func(cfg *config) {
cfg.browserConfig = browserCommand
}
}
// WithDestinationHost returns an option to configure the destination host.
func WithDestinationHost(dstHost string) Option {
return func(cfg *config) {
cfg.dstHost = dstHost
}
}
// WithJWTCache returns an option to configure the jwt cache.
func WithJWTCache(jwtCache cliutil.JWTCache) Option {
return func(cfg *config) {
cfg.jwtCache = jwtCache
}
}
// WithProxyHost returns an option to configure the proxy host.
func WithProxyHost(proxyHost string) Option {
return func(cfg *config) {
cfg.proxyHost = proxyHost
}
}
// WithTLSConfig returns an option to configure the tls config.
func WithTLSConfig(tlsConfig *tls.Config) Option {
return func(cfg *config) {
if tlsConfig != nil {
tlsConfig = tlsConfig.Clone()
tlsConfig.NextProtos = []string{"http/1.1"} // disable http/2 in ALPN
}
cfg.tlsConfig = tlsConfig
}
}

View file

@ -1,227 +0,0 @@
// Package tcptunnel contains an implementation of a TCP tunnel via HTTP Connect.
package tcptunnel
import (
"bufio"
"context"
"crypto/tls"
"errors"
"fmt"
"io"
"net"
"net/http"
"net/url"
"time"
"github.com/pomerium/pomerium/internal/authclient"
"github.com/pomerium/pomerium/internal/cliutil"
"github.com/pomerium/pomerium/internal/log"
backoff "github.com/cenkalti/backoff/v4"
)
// A Tunnel represents a TCP tunnel over HTTP Connect.
type Tunnel struct {
cfg *config
auth *authclient.AuthClient
}
// New creates a new Tunnel.
func New(options ...Option) *Tunnel {
cfg := getConfig(options...)
return &Tunnel{
cfg: cfg,
auth: authclient.New(
authclient.WithBrowserCommand(cfg.browserConfig),
authclient.WithTLSConfig(cfg.tlsConfig)),
}
}
// RunListener runs a network listener on the given address. For each
// incoming connection a new TCP tunnel is established via Run.
func (tun *Tunnel) RunListener(ctx context.Context, listenerAddress string) error {
li, err := net.Listen("tcp", listenerAddress)
if err != nil {
return err
}
defer func() { _ = li.Close() }()
log.Info(ctx).Msg("tcptunnel: listening on " + li.Addr().String())
go func() {
<-ctx.Done()
_ = li.Close()
}()
bo := backoff.NewExponentialBackOff()
bo.MaxElapsedTime = 0
for {
conn, err := li.Accept()
if err != nil {
// canceled, so ignore the error and return
if ctx.Err() != nil {
return nil
}
if nerr, ok := err.(net.Error); ok && nerr.Temporary() {
log.Warn(ctx).Err(err).Msg("tcptunnel: temporarily failed to accept local connection")
select {
case <-time.After(bo.NextBackOff()):
case <-ctx.Done():
return ctx.Err()
}
continue
}
return err
}
bo.Reset()
go func() {
defer func() { _ = conn.Close() }()
err := tun.Run(ctx, conn)
if err != nil {
log.Error(ctx).Err(err).Msg("tcptunnel: error serving local connection")
}
}()
}
}
// Run establishes a TCP tunnel via HTTP Connect and forwards all traffic from/to local.
func (tun *Tunnel) Run(ctx context.Context, local io.ReadWriter) error {
rawJWT, err := tun.cfg.jwtCache.LoadJWT(tun.jwtCacheKey())
switch {
// if there is no error, or it is one of the pre-defined cliutil errors,
// then ignore and use an empty JWT
case err == nil,
errors.Is(err, cliutil.ErrExpired),
errors.Is(err, cliutil.ErrInvalid),
errors.Is(err, cliutil.ErrNotFound):
default:
return fmt.Errorf("tcptunnel: failed to load JWT: %w", err)
}
return tun.run(ctx, local, rawJWT, 0)
}
func (tun *Tunnel) run(ctx context.Context, local io.ReadWriter, rawJWT string, retryCount int) error {
log.Info(ctx).
Str("dst", tun.cfg.dstHost).
Str("proxy", tun.cfg.proxyHost).
Bool("secure", tun.cfg.tlsConfig != nil).
Msg("tcptunnel: opening connection")
hdr := http.Header{}
if rawJWT != "" {
hdr.Set("Authorization", "Pomerium "+rawJWT)
}
req := (&http.Request{
Method: "CONNECT",
URL: &url.URL{Opaque: tun.cfg.dstHost},
Host: tun.cfg.dstHost,
Header: hdr,
}).WithContext(ctx)
var remote net.Conn
var err error
if tun.cfg.tlsConfig != nil {
remote, err = (&tls.Dialer{Config: tun.cfg.tlsConfig}).DialContext(ctx, "tcp", tun.cfg.proxyHost)
} else {
remote, err = (&net.Dialer{}).DialContext(ctx, "tcp", tun.cfg.proxyHost)
}
if err != nil {
return fmt.Errorf("tcptunnel: failed to establish connection to proxy: %w", err)
}
defer func() {
_ = remote.Close()
log.Info(ctx).Msg("tcptunnel: connection closed")
}()
if done := ctx.Done(); done != nil {
go func() {
<-done
_ = remote.Close()
}()
}
err = req.Write(remote)
if err != nil {
return err
}
br := bufio.NewReader(remote)
res, err := http.ReadResponse(br, req)
if err != nil {
return fmt.Errorf("tcptunnel: failed to read HTTP response: %w", err)
}
defer func() {
_ = res.Body.Close()
}()
switch res.StatusCode {
case http.StatusOK:
case http.StatusMovedPermanently,
http.StatusFound,
http.StatusTemporaryRedirect,
http.StatusPermanentRedirect:
if retryCount == 0 {
_ = remote.Close()
serverURL := &url.URL{
Scheme: "http",
Host: tun.cfg.proxyHost,
}
if tun.cfg.tlsConfig != nil {
serverURL.Scheme = "https"
}
rawJWT, err = tun.auth.GetJWT(ctx, serverURL)
if err != nil {
return fmt.Errorf("tcptunnel: failed to get authentication JWT: %w", err)
}
err = tun.cfg.jwtCache.StoreJWT(tun.jwtCacheKey(), rawJWT)
if err != nil {
return fmt.Errorf("tcptunnel: failed to store JWT: %w", err)
}
return tun.run(ctx, local, rawJWT, retryCount+1)
}
fallthrough
default:
_ = tun.cfg.jwtCache.DeleteJWT(tun.jwtCacheKey())
return fmt.Errorf("tcptunnel: invalid http response code: %d", res.StatusCode)
}
log.Info(ctx).Msg("tcptunnel: connection established")
errc := make(chan error, 2)
go func() {
_, err := io.Copy(remote, local)
errc <- err
}()
remoteReader := deBuffer(br, remote)
go func() {
_, err := io.Copy(local, remoteReader)
errc <- err
}()
select {
case err := <-errc:
if err != nil {
err = fmt.Errorf("tcptunnel: %w", err)
}
return err
case <-ctx.Done():
return nil
}
}
func (tun *Tunnel) jwtCacheKey() string {
return fmt.Sprintf("%s|%v", tun.cfg.proxyHost, tun.cfg.tlsConfig != nil)
}
func deBuffer(br *bufio.Reader, underlying io.Reader) io.Reader {
if br.Buffered() == 0 {
return underlying
}
return io.MultiReader(io.LimitReader(br, int64(br.Buffered())), underlying)
}

View file

@ -1,112 +0,0 @@
package tcptunnel
import (
"bufio"
"bytes"
"context"
"crypto/tls"
"io"
"net"
"net/http"
"net/http/httptest"
"strings"
"testing"
"time"
"github.com/stretchr/testify/assert"
)
func TestTunnel(t *testing.T) {
ctx, clearTimeout := context.WithTimeout(context.Background(), time.Second*10)
defer clearTimeout()
backend, err := net.Listen("tcp", "127.0.0.1:0")
if !assert.NoError(t, err) {
return
}
defer func() { _ = backend.Close() }()
go func() {
for {
conn, err := backend.Accept()
if err != nil {
return
}
go func() {
defer func() { _ = conn.Close() }()
ln, _, _ := bufio.NewReader(conn).ReadLine()
assert.Equal(t, "HELLO WORLD", string(ln))
}()
}
}()
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
if !assert.Equal(t, "CONNECT", r.Method) {
return
}
if !assert.Equal(t, "example.com:9999", r.RequestURI) {
return
}
w.WriteHeader(200)
in, brw, err := w.(http.Hijacker).Hijack()
if !assert.NoError(t, err) {
return
}
defer func() { _ = in.Close() }()
out, err := net.Dial("tcp", backend.Addr().String())
if !assert.NoError(t, err) {
return
}
defer func() { _ = out.Close() }()
errc := make(chan error, 2)
go func() {
_, err := io.Copy(in, out)
errc <- err
}()
go func() {
_, err := io.Copy(out, deBuffer(brw.Reader, in))
errc <- err
}()
<-errc
}))
defer srv.Close()
var buf bytes.Buffer
tun := New(
WithDestinationHost("example.com:9999"),
WithProxyHost(srv.Listener.Addr().String()))
err = tun.Run(ctx, readWriter{strings.NewReader("HELLO WORLD\n"), &buf})
if !assert.NoError(t, err) {
return
}
}
type readWriter struct {
io.Reader
io.Writer
}
func TestForceHTTP1(t *testing.T) {
tunnel := New(WithTLSConfig(&tls.Config{
InsecureSkipVerify: true,
}))
var protocol string
srv := httptest.NewTLSServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
protocol = r.Proto
}))
client := &http.Client{
Transport: &http.Transport{
TLSClientConfig: tunnel.cfg.tlsConfig,
},
}
_, _ = client.Get(srv.URL)
assert.Equal(t, "HTTP/1.1", protocol)
}

View file

@ -11,7 +11,7 @@ _target="${TARGET:-"$(go env GOOS)-$(go env GOARCH)"}"
if [ "$_target" == "darwin-arm64" ]; then
echo "Using local envoy distribution for Apple M1"
cp `which envoy` "$_dir/envoy-$_target"
cp -f `which envoy` "$_dir/envoy-$_target"
(cd internal/envoy/files && sha256sum "$_dir/envoy-$_target" > "$_dir/envoy-$_target.sha256")
echo "1.21.0-dev" >"$_dir/envoy-$_target.version"
exit 0