mirror of
https://github.com/pomerium/pomerium.git
synced 2025-05-19 12:07:18 +02:00
authenticate/proxy: add backend refresh (#438)
This commit is contained in:
parent
9a330613aa
commit
ec029c679b
35 changed files with 1226 additions and 445 deletions
|
@ -1,7 +1,11 @@
|
|||
package proxy // import "github.com/pomerium/pomerium/proxy"
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"fmt"
|
||||
"io"
|
||||
"io/ioutil"
|
||||
"net/http"
|
||||
|
||||
"github.com/pomerium/pomerium/internal/encoding"
|
||||
|
@ -30,23 +34,82 @@ func (p *Proxy) AuthenticateSession(next http.Handler) http.Handler {
|
|||
ctx, span := trace.StartSpan(r.Context(), "proxy.AuthenticateSession")
|
||||
defer span.End()
|
||||
|
||||
if s, err := sessions.FromContext(ctx); err != nil {
|
||||
log.FromRequest(r).Debug().Err(err).Msg("proxy: authenticate session")
|
||||
p.sessionStore.ClearSession(w, r)
|
||||
if s != nil && s.Programmatic {
|
||||
return httputil.NewError(http.StatusUnauthorized, err)
|
||||
_, err := sessions.FromContext(ctx)
|
||||
if errors.Is(err, sessions.ErrExpired) {
|
||||
ctx, err = p.refresh(ctx, w, r)
|
||||
if err != nil {
|
||||
log.FromRequest(r).Warn().Err(err).Msg("proxy: refresh failed")
|
||||
return p.redirectToSignin(w, r)
|
||||
}
|
||||
signinURL := *p.authenticateSigninURL
|
||||
q := signinURL.Query()
|
||||
q.Set(urlutil.QueryRedirectURI, urlutil.GetAbsoluteURL(r).String())
|
||||
signinURL.RawQuery = q.Encode()
|
||||
httputil.Redirect(w, r, urlutil.NewSignedURL(p.SharedKey, &signinURL).String(), http.StatusFound)
|
||||
log.FromRequest(r).Info().Msg("proxy: refresh success")
|
||||
} else if err != nil {
|
||||
log.FromRequest(r).Debug().Err(err).Msg("proxy: session state")
|
||||
return p.redirectToSignin(w, r)
|
||||
}
|
||||
p.addPomeriumHeaders(w, r)
|
||||
next.ServeHTTP(w, r.WithContext(ctx))
|
||||
return nil
|
||||
})
|
||||
}
|
||||
|
||||
func (p *Proxy) refresh(ctx context.Context, w http.ResponseWriter, r *http.Request) (context.Context, error) {
|
||||
ctx, span := trace.StartSpan(ctx, "proxy.AuthenticateSession/refresh")
|
||||
defer span.End()
|
||||
s, err := sessions.FromContext(ctx)
|
||||
if !errors.Is(err, sessions.ErrExpired) || s == nil {
|
||||
return nil, errors.New("proxy: unexpected session state for refresh")
|
||||
}
|
||||
// 1 - build a signed url to call refresh on authenticate service
|
||||
refreshURI := *p.authenticateRefreshURL
|
||||
q := refreshURI.Query()
|
||||
q.Set("ati", s.AccessTokenID) // hash value points to parent token
|
||||
q.Set("aud", urlutil.StripPort(r.Host)) // request's audience, this route
|
||||
refreshURI.RawQuery = q.Encode()
|
||||
signedRefreshURL := urlutil.NewSignedURL(p.SharedKey, &refreshURI).String()
|
||||
|
||||
// 2 - http call to authenticate service
|
||||
req, err := http.NewRequestWithContext(ctx, http.MethodGet, signedRefreshURL, nil)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("proxy: backend refresh: new request: %v", err)
|
||||
}
|
||||
res, err := httputil.DefaultClient.Do(req)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("proxy: fetch %v: %w", signedRefreshURL, err)
|
||||
}
|
||||
defer res.Body.Close()
|
||||
jwtBytes, err := ioutil.ReadAll(io.LimitReader(res.Body, 4<<10))
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// 3 - save refreshed session to the client's session store
|
||||
if err = p.sessionStore.SaveSession(w, r, jwtBytes); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
// 4 - add refreshed session to the current request context
|
||||
var state sessions.State
|
||||
if err := p.encoder.Unmarshal(jwtBytes, &state); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if err := state.Verify(urlutil.StripPort(r.Host)); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return sessions.NewContext(r.Context(), &state, err), nil
|
||||
}
|
||||
|
||||
func (p *Proxy) redirectToSignin(w http.ResponseWriter, r *http.Request) error {
|
||||
s, err := sessions.FromContext(r.Context())
|
||||
if s != nil && err != nil && s.Programmatic {
|
||||
return httputil.NewError(http.StatusUnauthorized, err)
|
||||
}
|
||||
p.sessionStore.ClearSession(w, r)
|
||||
signinURL := *p.authenticateSigninURL
|
||||
q := signinURL.Query()
|
||||
q.Set(urlutil.QueryRedirectURI, urlutil.GetAbsoluteURL(r).String())
|
||||
signinURL.RawQuery = q.Encode()
|
||||
log.FromRequest(r).Debug().Str("url", signinURL.String()).Msg("proxy: redirectToSignin")
|
||||
httputil.Redirect(w, r, urlutil.NewSignedURL(p.SharedKey, &signinURL).String(), http.StatusFound)
|
||||
return nil
|
||||
}
|
||||
|
||||
func (p *Proxy) addPomeriumHeaders(w http.ResponseWriter, r *http.Request) {
|
||||
|
@ -61,8 +124,8 @@ func (p *Proxy) addPomeriumHeaders(w http.ResponseWriter, r *http.Request) {
|
|||
}
|
||||
}
|
||||
|
||||
// AuthorizeSession is middleware to enforce a user is authorized for a request
|
||||
// session state is retrieved from the users's request context.
|
||||
// AuthorizeSession is middleware to enforce a user is authorized for a request.
|
||||
// Session state is retrieved from the users's request context.
|
||||
func (p *Proxy) AuthorizeSession(next http.Handler) http.Handler {
|
||||
return httputil.HandlerFunc(func(w http.ResponseWriter, r *http.Request) error {
|
||||
ctx, span := trace.StartSpan(r.Context(), "proxy.AuthorizeSession")
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue