remove :443 or :80 from proxy URLs in authclient (#1733)

* remove :443 or :80 from proxy URLs in authclient

* handle buffered bytes
This commit is contained in:
Caleb Doxsey 2021-01-04 16:06:24 -07:00 committed by GitHub
parent f837c92741
commit 672b9c7a72
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
3 changed files with 21 additions and 4 deletions

View file

@ -94,7 +94,17 @@ func (client *AuthClient) runHTTPServer(ctx context.Context, li net.Listener, in
} }
func (client *AuthClient) runOpenBrowser(ctx context.Context, li net.Listener, serverURL *url.URL) error { func (client *AuthClient) runOpenBrowser(ctx context.Context, li net.Listener, serverURL *url.URL) error {
dst := serverURL.ResolveReference(&url.URL{ browserURL := new(url.URL)
*browserURL = *serverURL
// remove unnecessary ports to avoid HMAC error
if browserURL.Scheme == "http" && browserURL.Host == browserURL.Hostname()+":80" {
browserURL.Host = browserURL.Hostname()
} else if browserURL.Scheme == "https" && browserURL.Host == browserURL.Hostname()+":443" {
browserURL.Host = browserURL.Hostname()
}
dst := browserURL.ResolveReference(&url.URL{
Path: "/.pomerium/api/v1/login", Path: "/.pomerium/api/v1/login",
RawQuery: url.Values{ RawQuery: url.Values{
"pomerium_redirect_uri": {fmt.Sprintf("http://%s", li.Addr().String())}, "pomerium_redirect_uri": {fmt.Sprintf("http://%s", li.Addr().String())},

View file

@ -197,7 +197,7 @@ func (tun *Tunnel) run(ctx context.Context, local io.ReadWriter, rawJWT string,
errc <- err errc <- err
}() }()
go func() { go func() {
_, err := io.Copy(local, remote) _, err := io.Copy(local, deBuffer(br, remote))
errc <- err errc <- err
}() }()
@ -215,3 +215,10 @@ func (tun *Tunnel) run(ctx context.Context, local io.ReadWriter, rawJWT string,
func (tun *Tunnel) jwtCacheKey() string { func (tun *Tunnel) jwtCacheKey() string {
return fmt.Sprintf("%s|%s|%v", tun.cfg.dstHost, tun.cfg.proxyHost, tun.cfg.tlsConfig != nil) return fmt.Sprintf("%s|%s|%v", tun.cfg.dstHost, tun.cfg.proxyHost, tun.cfg.tlsConfig != nil)
} }
func deBuffer(br *bufio.Reader, underlying io.Reader) io.Reader {
if br.Buffered() == 0 {
return underlying
}
return io.MultiReader(io.LimitReader(br, int64(br.Buffered())), underlying)
}

View file

@ -50,7 +50,7 @@ func TestTunnel(t *testing.T) {
w.WriteHeader(200) w.WriteHeader(200)
in, _, err := w.(http.Hijacker).Hijack() in, brw, err := w.(http.Hijacker).Hijack()
if !assert.NoError(t, err) { if !assert.NoError(t, err) {
return return
} }
@ -68,7 +68,7 @@ func TestTunnel(t *testing.T) {
errc <- err errc <- err
}() }()
go func() { go func() {
_, err := io.Copy(out, in) _, err := io.Copy(out, deBuffer(brw.Reader, in))
errc <- err errc <- err
}() }()
<-errc <-errc