authenticate: validate origin of signout (#1876)

* authenticate: validate origin of signout

- add a debug task to kill envoy
- improve various function docs
- userinfo: return "error" page if user is logged out without redirect uri set
- remove front channel logout. There's little difference between it, and the signout function.

Signed-off-by: Bobby DeSimone <bobbydesimone@gmail.com>
This commit is contained in:
bobby 2021-02-11 21:37:54 -08:00 committed by GitHub
parent 9fd58f9b8a
commit c3e3ed9b50
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
11 changed files with 174 additions and 182 deletions

View file

@ -13,10 +13,18 @@ import (
"testing"
"time"
"github.com/golang/mock/gomock"
"github.com/golang/protobuf/ptypes"
"github.com/golang/protobuf/ptypes/empty"
"github.com/google/go-cmp/cmp"
"github.com/stretchr/testify/assert"
"golang.org/x/crypto/chacha20poly1305"
"golang.org/x/oauth2"
"google.golang.org/grpc"
"google.golang.org/grpc/codes"
"google.golang.org/grpc/status"
"google.golang.org/protobuf/types/known/emptypb"
"gopkg.in/square/go-jose.v2/jwt"
"github.com/pomerium/pomerium/config"
"github.com/pomerium/pomerium/internal/encoding"
@ -33,15 +41,6 @@ import (
"github.com/pomerium/pomerium/pkg/grpc/databroker"
"github.com/pomerium/pomerium/pkg/grpc/directory"
"github.com/pomerium/pomerium/pkg/grpc/session"
"github.com/golang/mock/gomock"
"github.com/golang/protobuf/ptypes"
"github.com/golang/protobuf/ptypes/empty"
"github.com/google/go-cmp/cmp"
"github.com/stretchr/testify/assert"
"golang.org/x/crypto/chacha20poly1305"
"golang.org/x/oauth2"
"gopkg.in/square/go-jose.v2/jwt"
)
func testAuthenticate() *Authenticate {
@ -106,7 +105,7 @@ func TestAuthenticate_Handler(t *testing.T) {
expected = fmt.Sprintf("User-agent: *\nDisallow: /")
code := rr.Code
if code != http.StatusOK {
t.Errorf("bad preflight code")
t.Errorf("bad preflight code %v", code)
}
resp := rr.Result()
body = resp.Header.Get("vary")
@ -235,6 +234,7 @@ func TestAuthenticate_SignOut(t *testing.T) {
{"failed revoke", http.MethodPost, nil, "https://corp.pomerium.io/", "", "sig", "ts", identity.MockProvider{RevokeError: errors.New("OH NO")}, &mstore.Store{Encrypted: true, Session: &sessions.State{}}, http.StatusFound, ""},
{"load session error", http.MethodPost, errors.New("error"), "https://corp.pomerium.io/", "", "sig", "ts", identity.MockProvider{RevokeError: errors.New("OH NO")}, &mstore.Store{Encrypted: true, Session: &sessions.State{}}, http.StatusFound, ""},
{"bad redirect uri", http.MethodPost, nil, "corp.pomerium.io/", "", "sig", "ts", identity.MockProvider{LogOutError: oidc.ErrSignoutNotImplemented}, &mstore.Store{Encrypted: true, Session: &sessions.State{}}, http.StatusFound, ""},
{"no redirect uri", http.MethodPost, nil, "", "", "sig", "ts", identity.MockProvider{LogOutResponse: (*uriParseHelper("https://microsoft.com"))}, &mstore.Store{Encrypted: true, Session: &sessions.State{}}, http.StatusOK, "{\"Status\":200,\"Error\":\"OK: user logged out\"}\n"},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
@ -566,7 +566,7 @@ func TestWellKnownEndpoint(t *testing.T) {
rr := httptest.NewRecorder()
h.ServeHTTP(rr, req)
body := rr.Body.String()
expected := "{\"jwks_uri\":\"https://auth.example.com/.well-known/pomerium/jwks.json\",\"authentication_callback_endpoint\":\"https://auth.example.com/oauth2/callback\",\"frontchannel_logout_uri\":\"https://auth.example.com/.pomerium/frontchannel-logout\"}\n"
expected := "{\"authentication_callback_endpoint\":\"https://auth.example.com/oauth2/callback\",\"jwks_uri\":\"https://auth.example.com/.well-known/pomerium/jwks.json\",\"frontchannel_logout_uri\":\"https://auth.example.com/.pomerium/sign_out\"}\n"
assert.Equal(t, body, expected)
}
@ -669,84 +669,6 @@ func TestAuthenticate_userInfo(t *testing.T) {
}
}
func TestAuthenticate_FrontchannelLogout(t *testing.T) {
t.Parallel()
tests := []struct {
name string
logoutIssuer string
tokenIssuer string
widthSession bool
sessionStore sessions.SessionStore
provider identity.MockProvider
wantCode int
}{
{"good", "https://idp.pomerium.io", "https://idp.pomerium.io", true, &mstore.Store{}, identity.MockProvider{AuthenticateResponse: oauth2.Token{}}, http.StatusOK},
{"good no session", "https://idp.pomerium.io", "https://idp.pomerium.io", false, &mstore.Store{SaveError: errors.New("error")}, identity.MockProvider{AuthenticateResponse: oauth2.Token{}}, http.StatusOK},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
ctrl := gomock.NewController(t)
defer ctrl.Finish()
a := &Authenticate{
state: newAtomicAuthenticateState(&authenticateState{
sessionStore: tt.sessionStore,
encryptedEncoder: mock.Encoder{},
sharedEncoder: mock.Encoder{},
dataBrokerClient: mockDataBrokerServiceClient{
delete: func(ctx context.Context, in *databroker.DeleteRequest, opts ...grpc.CallOption) (*emptypb.Empty, error) {
return nil, nil
},
get: func(ctx context.Context, in *databroker.GetRequest, opts ...grpc.CallOption) (*databroker.GetResponse, error) {
if !tt.widthSession {
return nil, nil
}
data, err := ptypes.MarshalAny(&session.Session{
Id: "SESSION_ID",
IdToken: &session.IDToken{
Issuer: tt.tokenIssuer,
},
})
if err != nil {
return nil, err
}
return &databroker.GetResponse{
Record: &databroker.Record{
Version: "0001",
Type: data.GetTypeUrl(),
Id: "SESSION_ID",
Data: data,
},
}, nil
},
},
directoryClient: new(mockDirectoryServiceClient),
}),
options: config.NewAtomicOptions(),
provider: identity.NewAtomicAuthenticator(),
}
a.provider.Store(tt.provider)
u, _ := url.Parse("/.pomerium/frontchannel-logout")
params, _ := url.ParseQuery(u.RawQuery)
params.Add("iss", tt.logoutIssuer)
u.RawQuery = params.Encode()
r := httptest.NewRequest(http.MethodGet, u.String(), nil)
w := httptest.NewRecorder()
httputil.HandlerFunc(a.FrontchannelLogout).ServeHTTP(w, r)
if status := w.Code; status != tt.wantCode {
t.Errorf("handler returned wrong status code: got %v want %v", status, tt.wantCode)
}
})
}
}
type mockDataBrokerServiceClient struct {
databroker.DataBrokerServiceClient
@ -779,3 +701,87 @@ func (m mockDirectoryServiceClient) RefreshUser(ctx context.Context, in *directo
}
return nil, status.Error(codes.Unimplemented, "")
}
func TestAuthenticate_SignOut_CSRF(t *testing.T) {
now := time.Now()
signer, err := jws.NewHS256Signer(nil)
if err != nil {
t.Fatal(err)
}
pbNow, _ := ptypes.TimestampProto(now)
a := &Authenticate{
options: config.NewAtomicOptions(),
state: newAtomicAuthenticateState(&authenticateState{
// sessionStore: tt.sessionStore,
cookieSecret: cryptutil.NewKey(),
encryptedEncoder: signer,
sharedEncoder: signer,
dataBrokerClient: mockDataBrokerServiceClient{
get: func(ctx context.Context, in *databroker.GetRequest, opts ...grpc.CallOption) (*databroker.GetResponse, error) {
data, err := ptypes.MarshalAny(&session.Session{
Id: "SESSION_ID",
UserId: "USER_ID",
IdToken: &session.IDToken{IssuedAt: pbNow},
})
if err != nil {
return nil, err
}
return &databroker.GetResponse{
Record: &databroker.Record{
Version: "0001",
Type: data.GetTypeUrl(),
Id: "SESSION_ID",
Data: data,
},
}, nil
},
},
directoryClient: new(mockDirectoryServiceClient),
}),
templates: template.Must(frontend.NewTemplates()),
}
tests := []struct {
name string
setCSRFCookie bool
method string
wantStatus int
wantBody string
}{
{"GET without CSRF should fail", false, "GET", 400, "{\"Status\":400,\"Error\":\"Bad Request: CSRF token invalid\"}\n"},
{"POST without CSRF should fail", false, "POST", 400, "{\"Status\":400,\"Error\":\"Bad Request: CSRF token invalid\"}\n"},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
s := a.Handler()
// Obtain a CSRF cookie via a GET request.
orr, err := http.NewRequest("GET", "/", nil)
if err != nil {
t.Fatal(err)
}
rr := httptest.NewRecorder()
s.ServeHTTP(rr, orr)
r, err := http.NewRequest(tt.method, "/.pomerium/sign_out", nil)
if err != nil {
t.Fatal(err)
}
if tt.setCSRFCookie {
r.Header.Set("Cookie", rr.Header().Get("Set-Cookie"))
}
r.Header.Set("Accept", "application/json")
r.Header.Set("Referer", "/")
rr = httptest.NewRecorder()
s.ServeHTTP(rr, r)
if rr.Code != tt.wantStatus {
t.Errorf("status: got %v want %v", rr.Code, tt.wantStatus)
}
body := rr.Body.String()
if diff := cmp.Diff(body, tt.wantBody); diff != "" {
t.Errorf("handler returned wrong body Body: %s", diff)
}
})
}
}