pomerium/pkg/ssh/auth_test.go
Kenneth Jenkins 9678e6a231
ssh: implement authorization policy evaluation (#5665)
Implement the pkg/ssh.AuthInterface. Add logic for converting from the
ssh stream state to an evaluator request, and for interpreting the
results of policy evaluation. Refactor some of the existing authorize
logic to make it easier to reuse.
2025-07-01 12:04:00 -07:00

469 lines
16 KiB
Go

package ssh
import (
"context"
"errors"
"testing"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"google.golang.org/grpc"
"google.golang.org/grpc/codes"
"google.golang.org/grpc/status"
"google.golang.org/protobuf/types/known/timestamppb"
extensions_ssh "github.com/pomerium/envoy-custom/api/extensions/filters/network/ssh"
"github.com/pomerium/pomerium/authorize/evaluator"
"github.com/pomerium/pomerium/config"
"github.com/pomerium/pomerium/internal/atomicutil"
"github.com/pomerium/pomerium/internal/testutil/mockidp"
"github.com/pomerium/pomerium/pkg/grpc/databroker"
"github.com/pomerium/pomerium/pkg/grpc/session"
"github.com/pomerium/pomerium/pkg/identity"
"github.com/pomerium/pomerium/pkg/policy/criteria"
"github.com/pomerium/pomerium/pkg/protoutil"
)
// TestHandlePublicKeyMethodRequest covers the public key authentication step:
// fingerprint validation, propagation of policy evaluation errors, and how the
// allow/deny rule results (and their reasons) map onto the response's Allow
// and RequireAdditionalMethods fields. The "internal command" subtests use an
// empty hostname together with a fake databroker client.
//
// require.NoError is used wherever the response is dereferenced afterwards, so
// an unexpected error fails the subtest instead of panicking on a nil result.
func TestHandlePublicKeyMethodRequest(t *testing.T) {
	t.Run("no public key fingerprint", func(t *testing.T) {
		var a Auth
		var req extensions_ssh.PublicKeyMethodRequest
		_, err := a.handlePublicKeyMethodRequest(t.Context(), StreamAuthInfo{}, &req)
		assert.ErrorContains(t, err, "invalid public key fingerprint")
	})
	t.Run("evaluate error", func(t *testing.T) {
		info := StreamAuthInfo{
			Username: ptr("username"),
			Hostname: ptr("hostname"),
		}
		var req extensions_ssh.PublicKeyMethodRequest
		req.PublicKeyFingerprintSha256 = []byte("ABCDEFGHIJKLMNOPQRSTUVWXYZ123456")
		pe := policyEvaluatorFunc(func(context.Context, *Request) (*evaluator.Result, error) {
			return nil, errors.New("error evaluating policy")
		})
		a := NewAuth(pe, nil, nil, nil)
		_, err := a.handlePublicKeyMethodRequest(t.Context(), info, &req)
		assert.ErrorContains(t, err, "error evaluating policy")
	})
	t.Run("allow", func(t *testing.T) {
		info := StreamAuthInfo{
			Username: ptr("username"),
			Hostname: ptr("hostname"),
		}
		var req extensions_ssh.PublicKeyMethodRequest
		req.PublicKeyFingerprintSha256 = []byte("ABCDEFGHIJKLMNOPQRSTUVWXYZ123456")
		fakePublicKey := []byte("fake-public-key")
		req.PublicKey = fakePublicKey
		pe := policyEvaluatorFunc(func(_ context.Context, r *Request) (*evaluator.Result, error) {
			// The expected session ID is "sshkey-SHA256:" followed by the
			// standard base64 encoding of the fingerprint bytes above.
			assert.Equal(t, &Request{
				Username:  "username",
				Hostname:  "hostname",
				PublicKey: fakePublicKey,
				SessionID: "sshkey-SHA256:QUJDREVGR0hJSktMTU5PUFFSU1RVVldYWVoxMjM0NTY=",
			}, r)
			return &evaluator.Result{
				Allow: evaluator.NewRuleResult(true),
				Deny:  evaluator.NewRuleResult(false),
			}, nil
		})
		a := NewAuth(pe, nil, nil, nil)
		res, err := a.HandlePublicKeyMethodRequest(t.Context(), info, &req)
		require.NoError(t, err)
		assert.Empty(t, res.RequireAdditionalMethods)
		require.NotNil(t, res.Allow)
		assert.Equal(t, fakePublicKey, res.Allow.PublicKey)
	})
	t.Run("deny", func(t *testing.T) {
		info := StreamAuthInfo{
			Username: ptr("username"),
			Hostname: ptr("hostname"),
		}
		var req extensions_ssh.PublicKeyMethodRequest
		req.PublicKeyFingerprintSha256 = []byte("ABCDEFGHIJKLMNOPQRSTUVWXYZ123456")
		// Deny takes precedence even though Allow is also true.
		pe := policyEvaluatorFunc(func(_ context.Context, _ *Request) (*evaluator.Result, error) {
			return &evaluator.Result{
				Allow: evaluator.NewRuleResult(true),
				Deny:  evaluator.NewRuleResult(true),
			}, nil
		})
		a := NewAuth(pe, nil, nil, nil)
		res, err := a.HandlePublicKeyMethodRequest(t.Context(), info, &req)
		require.NoError(t, err)
		assert.Nil(t, res.Allow)
		assert.Empty(t, res.RequireAdditionalMethods)
	})
	t.Run("public key unauthorized", func(t *testing.T) {
		info := StreamAuthInfo{
			Username: ptr("username"),
			Hostname: ptr("hostname"),
		}
		var req extensions_ssh.PublicKeyMethodRequest
		req.PublicKeyFingerprintSha256 = []byte("ABCDEFGHIJKLMNOPQRSTUVWXYZ123456")
		pe := policyEvaluatorFunc(func(_ context.Context, _ *Request) (*evaluator.Result, error) {
			return &evaluator.Result{
				Allow: evaluator.NewRuleResult(false, criteria.ReasonSSHPublickeyUnauthorized),
				Deny:  evaluator.NewRuleResult(false),
			}, nil
		})
		a := NewAuth(pe, nil, nil, nil)
		res, err := a.HandlePublicKeyMethodRequest(t.Context(), info, &req)
		require.NoError(t, err)
		assert.Nil(t, res.Allow)
		// The client is asked to try a different public key.
		assert.Equal(t, []string{MethodPublicKey}, res.RequireAdditionalMethods)
	})
	t.Run("needs login", func(t *testing.T) {
		info := StreamAuthInfo{
			Username: ptr("username"),
			Hostname: ptr("hostname"),
		}
		var req extensions_ssh.PublicKeyMethodRequest
		req.PublicKeyFingerprintSha256 = []byte("ABCDEFGHIJKLMNOPQRSTUVWXYZ123456")
		pe := policyEvaluatorFunc(func(_ context.Context, _ *Request) (*evaluator.Result, error) {
			return &evaluator.Result{
				Allow: evaluator.NewRuleResult(false),
				Deny:  evaluator.NewRuleResult(false, criteria.ReasonUserUnauthenticated),
			}, nil
		})
		a := NewAuth(pe, nil, nil, nil)
		res, err := a.HandlePublicKeyMethodRequest(t.Context(), info, &req)
		require.NoError(t, err)
		// The key is accepted, but keyboard-interactive login must follow.
		assert.NotNil(t, res.Allow)
		assert.Equal(t, []string{MethodKeyboardInteractive}, res.RequireAdditionalMethods)
	})
	t.Run("internal command no session", func(t *testing.T) {
		client := fakeDataBrokerServiceClient{
			get: func(
				_ context.Context, _ *databroker.GetRequest, _ ...grpc.CallOption,
			) (*databroker.GetResponse, error) {
				return nil, status.Error(codes.NotFound, "not found")
			},
		}
		info := StreamAuthInfo{
			Username: ptr("username"),
			Hostname: ptr(""),
		}
		var req extensions_ssh.PublicKeyMethodRequest
		req.PublicKeyFingerprintSha256 = []byte("ABCDEFGHIJKLMNOPQRSTUVWXYZ123456")
		pe := policyEvaluatorFunc(func(_ context.Context, _ *Request) (*evaluator.Result, error) {
			return &evaluator.Result{
				Allow: evaluator.NewRuleResult(false),
				Deny:  evaluator.NewRuleResult(false, criteria.ReasonUserUnauthenticated),
			}, nil
		})
		a := NewAuth(pe, client, nil, nil)
		res, err := a.HandlePublicKeyMethodRequest(t.Context(), info, &req)
		require.NoError(t, err)
		assert.NotNil(t, res.Allow)
		assert.Equal(t, []string{MethodKeyboardInteractive}, res.RequireAdditionalMethods)
	})
	t.Run("internal command with session", func(t *testing.T) {
		client := fakeDataBrokerServiceClient{
			get: func(
				_ context.Context, _ *databroker.GetRequest, _ ...grpc.CallOption,
			) (*databroker.GetResponse, error) {
				return &databroker.GetResponse{
					Record: &databroker.Record{
						Type: "type.googleapis.com/session.Session",
						Id:   "abc",
						Data: protoutil.NewAny(&session.Session{
							Id:     "abc",
							UserId: "USER-ID",
						}),
					},
				}, nil
			},
		}
		info := StreamAuthInfo{
			Username: ptr("username"),
			Hostname: ptr(""),
		}
		var req extensions_ssh.PublicKeyMethodRequest
		req.PublicKeyFingerprintSha256 = []byte("ABCDEFGHIJKLMNOPQRSTUVWXYZ123456")
		pe := policyEvaluatorFunc(func(_ context.Context, _ *Request) (*evaluator.Result, error) {
			return &evaluator.Result{
				Allow: evaluator.NewRuleResult(true),
				Deny:  evaluator.NewRuleResult(false),
			}, nil
		})
		a := NewAuth(pe, client, nil, nil)
		res, err := a.HandlePublicKeyMethodRequest(t.Context(), info, &req)
		require.NoError(t, err)
		assert.NotNil(t, res.Allow)
		assert.Empty(t, res.RequireAdditionalMethods)
	})
	t.Run("internal command databroker error", func(t *testing.T) {
		client := fakeDataBrokerServiceClient{
			get: func(
				_ context.Context, _ *databroker.GetRequest, _ ...grpc.CallOption,
			) (*databroker.GetResponse, error) {
				return nil, status.Error(codes.Unknown, "unknown")
			},
		}
		info := StreamAuthInfo{
			Username: ptr("username"),
			Hostname: ptr(""),
		}
		var req extensions_ssh.PublicKeyMethodRequest
		req.PublicKeyFingerprintSha256 = []byte("ABCDEFGHIJKLMNOPQRSTUVWXYZ123456")
		pe := policyEvaluatorFunc(func(_ context.Context, _ *Request) (*evaluator.Result, error) {
			return &evaluator.Result{
				Allow: evaluator.NewRuleResult(true),
				Deny:  evaluator.NewRuleResult(false),
			}, nil
		})
		a := NewAuth(pe, client, nil, nil)
		_, err := a.HandlePublicKeyMethodRequest(t.Context(), info, &req)
		assert.ErrorContains(t, err, "internal error")
	})
}
func TestHandleKeyboardInteractiveMethodRequest(t *testing.T) {
t.Run("no public key", func(t *testing.T) {
var a Auth
_, err := a.handleKeyboardInteractiveMethodRequest(t.Context(), StreamAuthInfo{}, nil)
assert.ErrorContains(t, err, "expected PublicKeyAllow message not to be nil")
})
t.Run("ok", func(t *testing.T) {
pe := policyEvaluatorFunc(func(_ context.Context, _ *Request) (*evaluator.Result, error) {
return &evaluator.Result{
Allow: evaluator.NewRuleResult(true),
Deny: evaluator.NewRuleResult(false),
}, nil
})
var putRecords []*databroker.Record
client := fakeDataBrokerServiceClient{
get: func(
_ context.Context, _ *databroker.GetRequest, _ ...grpc.CallOption,
) (*databroker.GetResponse, error) {
return nil, status.Error(codes.NotFound, "not found")
},
put: func(
_ context.Context, in *databroker.PutRequest, _ ...grpc.CallOption,
) (*databroker.PutResponse, error) {
putRecords = append(putRecords, in.Records...)
return &databroker.PutResponse{
Records: in.Records,
}, nil
},
}
cfg := config.Config{
Options: config.NewDefaultOptions(),
}
mockIDP := mockidp.New(mockidp.Config{EnableDeviceAuth: true})
idpURL := mockIDP.Start(t)
cfg.Options.Provider = "oidc"
cfg.Options.ProviderURL = idpURL
cfg.Options.ClientID = "client-id"
cfg.Options.ClientSecret = "client-secret"
a := NewAuth(pe, client, atomicutil.NewValue(&cfg), nil)
info := StreamAuthInfo{
Username: ptr("username"),
Hostname: ptr("hostname"),
PublicKeyAllow: AuthMethodValue[extensions_ssh.PublicKeyAllowResponse]{
Value: &extensions_ssh.PublicKeyAllowResponse{
PublicKey: []byte("fake-public-key"),
},
},
PublicKeyFingerprintSha256: []byte("ABCDEFGHIJKLMNOPQRSTUVWXYZ123456"),
}
res, err := a.HandleKeyboardInteractiveMethodRequest(t.Context(), info, nil, noopQuerier{})
require.NoError(t, err)
assert.NotNil(t, res.Allow)
assert.Empty(t, res.RequireAdditionalMethods)
// A new Session and User record should have been saved to the databroker.
assert.Len(t, putRecords, 2)
assert.Equal(t, "type.googleapis.com/user.User", putRecords[0].Type)
assert.Equal(t, "fake.user@example.com", putRecords[0].Id)
assert.Equal(t, "type.googleapis.com/session.Session", putRecords[1].Type)
assert.Equal(t, "sshkey-SHA256:QUJDREVGR0hJSktMTU5PUFFSU1RVVldYWVoxMjM0NTY=", putRecords[1].Id)
})
t.Run("denied", func(t *testing.T) {
pe := policyEvaluatorFunc(func(_ context.Context, _ *Request) (*evaluator.Result, error) {
return &evaluator.Result{
Allow: evaluator.NewRuleResult(false),
Deny: evaluator.NewRuleResult(false),
}, nil
})
client := fakeDataBrokerServiceClient{
get: func(
_ context.Context, _ *databroker.GetRequest, _ ...grpc.CallOption,
) (*databroker.GetResponse, error) {
return nil, status.Error(codes.NotFound, "not found")
},
put: func(
_ context.Context, in *databroker.PutRequest, _ ...grpc.CallOption,
) (*databroker.PutResponse, error) {
return &databroker.PutResponse{
Records: in.Records,
}, nil
},
}
cfg := config.Config{
Options: config.NewDefaultOptions(),
}
mockIDP := mockidp.New(mockidp.Config{EnableDeviceAuth: true})
idpURL := mockIDP.Start(t)
cfg.Options.Provider = "oidc"
cfg.Options.ProviderURL = idpURL
cfg.Options.ClientID = "client-id"
cfg.Options.ClientSecret = "client-secret"
a := NewAuth(pe, client, atomicutil.NewValue(&cfg), nil)
info := StreamAuthInfo{
Username: ptr("username"),
Hostname: ptr("hostname"),
PublicKeyAllow: AuthMethodValue[extensions_ssh.PublicKeyAllowResponse]{
Value: &extensions_ssh.PublicKeyAllowResponse{
PublicKey: []byte("fake-public-key"),
},
},
PublicKeyFingerprintSha256: []byte("ABCDEFGHIJKLMNOPQRSTUVWXYZ123456"),
}
res, err := a.HandleKeyboardInteractiveMethodRequest(t.Context(), info, nil, noopQuerier{})
require.NoError(t, err)
assert.Nil(t, res.Allow)
assert.Empty(t, res.RequireAdditionalMethods)
})
t.Run("invalid fingerprint", func(t *testing.T) {
cfg := config.Config{
Options: config.NewDefaultOptions(),
}
mockIDP := mockidp.New(mockidp.Config{EnableDeviceAuth: true})
idpURL := mockIDP.Start(t)
cfg.Options.Provider = "oidc"
cfg.Options.ProviderURL = idpURL
cfg.Options.ClientID = "client-id"
cfg.Options.ClientSecret = "client-secret"
a := NewAuth(nil, nil, atomicutil.NewValue(&cfg), nil)
info := StreamAuthInfo{
Username: ptr("username"),
Hostname: ptr("hostname"),
PublicKeyAllow: AuthMethodValue[extensions_ssh.PublicKeyAllowResponse]{
Value: &extensions_ssh.PublicKeyAllowResponse{
PublicKey: []byte("fake-public-key"),
},
},
}
_, err := a.handleKeyboardInteractiveMethodRequest(t.Context(), info, noopQuerier{})
assert.ErrorContains(t, err, "invalid public key fingerprint")
})
}
// TestFormatSession checks that FormatSession rejects a malformed fingerprint
// and renders the stored session (user ID, session ID, expiry and flattened
// claims) fetched from the databroker under the fingerprint-derived ID.
func TestFormatSession(t *testing.T) {
	t.Run("invalid fingerprint", func(t *testing.T) {
		var a Auth
		info := StreamAuthInfo{
			PublicKeyFingerprintSha256: []byte("wrong-length"),
		}
		_, err := a.FormatSession(t.Context(), info)
		assert.ErrorContains(t, err, "invalid public key fingerprint")
	})
	t.Run("ok", func(t *testing.T) {
		client := fakeDataBrokerServiceClient{
			get: func(
				_ context.Context, in *databroker.GetRequest, _ ...grpc.CallOption,
			) (*databroker.GetResponse, error) {
				const expectedID = "sshkey-SHA256:QUJDREVGR0hJSktMTU5PUFFSU1RVVldYWVoxMjM0NTY="
				assert.Equal(t, "type.googleapis.com/session.Session", in.Type)
				assert.Equal(t, expectedID, in.Id)
				claims := identity.FlattenedClaims{
					"foo":  []any{"bar", "baz"},
					"quux": []any{42},
				}
				return &databroker.GetResponse{
					Record: &databroker.Record{
						Type: "type.googleapis.com/session.Session",
						Id:   expectedID,
						Data: protoutil.NewAny(&session.Session{
							Id:        expectedID,
							UserId:    "USER-ID",
							ExpiresAt: &timestamppb.Timestamp{Seconds: 1750965358},
							Claims:    claims.ToPB(),
						}),
					},
				}, nil
			},
		}
		a := NewAuth(nil, client, nil, nil)
		info := StreamAuthInfo{
			PublicKeyFingerprintSha256: []byte("ABCDEFGHIJKLMNOPQRSTUVWXYZ123456"),
		}
		b, err := a.FormatSession(t.Context(), info)
		require.NoError(t, err)
		assert.Equal(t, `
User ID: USER-ID
Session ID: sshkey-SHA256:QUJDREVGR0hJSktMTU5PUFFSU1RVVldYWVoxMjM0NTY=
Expires at: 2025-06-26 19:15:58 +0000 UTC
Claims:
foo: [bar baz]
quux: [42]
`, string(b))
	})
}
// TestDeleteSession checks that DeleteSession rejects a malformed fingerprint.
// It also checks that DeleteSession issues a single Put whose record has the
// fingerprint-derived session ID and a non-nil DeletedAt, and that the
// databroker's error reaches the caller.
func TestDeleteSession(t *testing.T) {
	t.Run("invalid fingerprint", func(t *testing.T) {
		var a Auth
		info := StreamAuthInfo{
			PublicKeyFingerprintSha256: []byte("wrong-length"),
		}
		err := a.DeleteSession(t.Context(), info)
		assert.ErrorContains(t, err, "invalid public key fingerprint")
	})
	t.Run("ok", func(t *testing.T) {
		putError := errors.New("sentinel")
		client := fakeDataBrokerServiceClient{
			put: func(
				_ context.Context, in *databroker.PutRequest, _ ...grpc.CallOption,
			) (*databroker.PutResponse, error) {
				require.Len(t, in.Records, 1)
				assert.Equal(t, "sshkey-SHA256:QUJDREVGR0hJSktMTU5PUFFSU1RVVldYWVoxMjM0NTY=", in.Records[0].Id)
				assert.NotNil(t, in.Records[0].DeletedAt)
				// A sentinel error lets the test observe that the databroker
				// error is propagated back to the caller.
				return nil, putError
			},
		}
		a := NewAuth(nil, client, nil, nil)
		info := StreamAuthInfo{
			PublicKeyFingerprintSha256: []byte("ABCDEFGHIJKLMNOPQRSTUVWXYZ123456"),
		}
		err := a.DeleteSession(t.Context(), info)
		assert.ErrorIs(t, err, putError)
	})
}
// policyEvaluatorFunc lets a test supply its SSH policy evaluation as an
// inline function literal, in the style of http.HandlerFunc.
type policyEvaluatorFunc func(ctx context.Context, req *Request) (*evaluator.Result, error)

// EvaluateSSH invokes the wrapped function and returns whatever it returns.
func (fn policyEvaluatorFunc) EvaluateSSH(ctx context.Context, req *Request) (*evaluator.Result, error) {
	result, err := fn(ctx, req)
	return result, err
}
// fakeDataBrokerServiceClient is a test double for
// databroker.DataBrokerServiceClient. Get and Put delegate to the get and put
// function fields. Every other method comes from the embedded interface. That
// interface is left nil in this file's tests, so calling any other method would
// panic.
type fakeDataBrokerServiceClient struct {
	databroker.DataBrokerServiceClient

	get func(ctx context.Context, in *databroker.GetRequest, opts ...grpc.CallOption) (*databroker.GetResponse, error)
	put func(ctx context.Context, in *databroker.PutRequest, opts ...grpc.CallOption) (*databroker.PutResponse, error)
}

// Get forwards to the get field; it panics if get was not provided.
func (c fakeDataBrokerServiceClient) Get(
	ctx context.Context, in *databroker.GetRequest, opts ...grpc.CallOption,
) (*databroker.GetResponse, error) {
	handler := c.get
	return handler(ctx, in, opts...)
}

// Put forwards to the put field; it panics if put was not provided.
func (c fakeDataBrokerServiceClient) Put(
	ctx context.Context, in *databroker.PutRequest, opts ...grpc.CallOption,
) (*databroker.PutResponse, error) {
	handler := c.put
	return handler(ctx, in, opts...)
}
// noopQuerier is a keyboard-interactive querier that never answers: every
// prompt yields a nil response and a nil error.
type noopQuerier struct{}

// Prompt ignores the prompts it is given and reports no responses.
func (noopQuerier) Prompt(context.Context, *extensions_ssh.KeyboardInteractiveInfoPrompts) (*extensions_ssh.KeyboardInteractiveInfoPromptResponses, error) {
	var responses *extensions_ssh.KeyboardInteractiveInfoPromptResponses
	return responses, nil
}
// ptr returns a pointer to a new copy of v, which is convenient for filling
// pointer-typed struct fields from literal values.
func ptr[T any](v T) *T {
	p := new(T)
	*p = v
	return p
}