mirror of
https://github.com/pomerium/pomerium.git
synced 2025-04-28 09:56:31 +02:00
149 lines
3.5 KiB
Go
149 lines
3.5 KiB
Go
// Package authclient contains a CLI authentication client for Pomerium.
|
|
package authclient
|
|
|
|
import (
|
|
"context"
|
|
"fmt"
|
|
"io"
|
|
"io/ioutil"
|
|
"net"
|
|
"net/http"
|
|
"net/url"
|
|
"os"
|
|
"time"
|
|
|
|
"golang.org/x/sync/errgroup"
|
|
)
|
|
|
|
// An AuthClient retrieves an authentication JWT via the Pomerium login API.
|
|
type AuthClient struct {
|
|
cfg *config
|
|
}
|
|
|
|
// New creates a new AuthClient.
|
|
func New(options ...Option) *AuthClient {
|
|
return &AuthClient{
|
|
cfg: getConfig(options...),
|
|
}
|
|
}
|
|
|
|
// GetJWT retrieves a JWT from Pomerium.
|
|
func (client *AuthClient) GetJWT(ctx context.Context, serverURL *url.URL) (rawJWT string, err error) {
|
|
li, err := net.Listen("tcp", "127.0.0.1:0")
|
|
if err != nil {
|
|
return "", fmt.Errorf("failed to start listener: %w", err)
|
|
}
|
|
defer func() { _ = li.Close() }()
|
|
|
|
incomingJWT := make(chan string)
|
|
eg, ctx := errgroup.WithContext(ctx)
|
|
eg.Go(func() error {
|
|
return client.runHTTPServer(ctx, li, incomingJWT)
|
|
})
|
|
eg.Go(func() error {
|
|
return client.runOpenBrowser(ctx, li, serverURL)
|
|
})
|
|
eg.Go(func() error {
|
|
select {
|
|
case rawJWT = <-incomingJWT:
|
|
case <-ctx.Done():
|
|
return ctx.Err()
|
|
}
|
|
return nil
|
|
})
|
|
err = eg.Wait()
|
|
if err != nil {
|
|
return "", err
|
|
}
|
|
|
|
return rawJWT, nil
|
|
}
|
|
|
|
func (client *AuthClient) runHTTPServer(ctx context.Context, li net.Listener, incomingJWT chan string) error {
|
|
var srv *http.Server
|
|
srv = &http.Server{
|
|
BaseContext: func(li net.Listener) context.Context {
|
|
return ctx
|
|
},
|
|
Handler: http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
|
jwt := r.FormValue("pomerium_jwt")
|
|
if jwt == "" {
|
|
http.Error(w, "not found", http.StatusNotFound)
|
|
return
|
|
}
|
|
incomingJWT <- jwt
|
|
|
|
w.Header().Set("Content-Type", "text/plain")
|
|
_, _ = io.WriteString(w, "login complete, you may close this page")
|
|
|
|
go func() { _ = srv.Shutdown(ctx) }()
|
|
}),
|
|
}
|
|
// shutdown the server when ctx is done.
|
|
go func() {
|
|
<-ctx.Done()
|
|
_ = srv.Shutdown(ctx)
|
|
}()
|
|
err := srv.Serve(li)
|
|
if err == http.ErrServerClosed {
|
|
err = nil
|
|
}
|
|
return err
|
|
}
|
|
|
|
func (client *AuthClient) runOpenBrowser(ctx context.Context, li net.Listener, serverURL *url.URL) error {
|
|
browserURL := new(url.URL)
|
|
*browserURL = *serverURL
|
|
|
|
// remove unnecessary ports to avoid HMAC error
|
|
if browserURL.Scheme == "http" && browserURL.Host == browserURL.Hostname()+":80" {
|
|
browserURL.Host = browserURL.Hostname()
|
|
} else if browserURL.Scheme == "https" && browserURL.Host == browserURL.Hostname()+":443" {
|
|
browserURL.Host = browserURL.Hostname()
|
|
}
|
|
|
|
dst := browserURL.ResolveReference(&url.URL{
|
|
Path: "/.pomerium/api/v1/login",
|
|
RawQuery: url.Values{
|
|
"pomerium_redirect_uri": {fmt.Sprintf("http://%s", li.Addr().String())},
|
|
}.Encode(),
|
|
})
|
|
|
|
ctx, clearTimeout := context.WithTimeout(ctx, 10*time.Second)
|
|
defer clearTimeout()
|
|
|
|
req, err := http.NewRequestWithContext(ctx, "GET", dst.String(), nil)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
|
|
transport := http.DefaultTransport.(*http.Transport).Clone()
|
|
transport.TLSClientConfig = client.cfg.tlsConfig
|
|
|
|
hc := &http.Client{
|
|
Transport: transport,
|
|
}
|
|
|
|
res, err := hc.Do(req)
|
|
if err != nil {
|
|
return fmt.Errorf("failed to get login url: %w", err)
|
|
}
|
|
defer func() { _ = res.Body.Close() }()
|
|
|
|
if res.StatusCode/100 != 2 {
|
|
return fmt.Errorf("failed to get login url: %s", res.Status)
|
|
}
|
|
|
|
bs, err := ioutil.ReadAll(res.Body)
|
|
if err != nil {
|
|
return fmt.Errorf("failed to read login url: %w", err)
|
|
}
|
|
|
|
err = client.cfg.open(string(bs))
|
|
if err != nil {
|
|
return fmt.Errorf("failed to open browser url: %w", err)
|
|
}
|
|
|
|
_, _ = fmt.Fprintf(os.Stderr, "Your browser has been opened to visit:\n\n%s\n\n", string(bs))
|
|
return nil
|
|
}
|