pomerium/pkg/ssh/stream.go
Kenneth Jenkins 9678e6a231
ssh: implement authorization policy evaluation (#5665)
Implement the pkg/ssh.AuthInterface. Add logic for converting from the
ssh stream state to an evaluator request, and for interpreting the
results of policy evaluation. Refactor some of the existing authorize
logic to make it easier to reuse.
2025-07-01 12:04:00 -07:00

487 lines
16 KiB
Go

package ssh
import (
"context"
"iter"
corev3 "github.com/envoyproxy/go-control-plane/envoy/config/core/v3"
gossh "golang.org/x/crypto/ssh"
"google.golang.org/grpc/codes"
"google.golang.org/grpc/status"
"google.golang.org/protobuf/types/known/anypb"
extensions_ssh "github.com/pomerium/envoy-custom/api/extensions/filters/network/ssh"
"github.com/pomerium/pomerium/config"
"github.com/pomerium/pomerium/internal/log"
"github.com/pomerium/pomerium/pkg/protoutil"
"github.com/pomerium/pomerium/pkg/slices"
)
const (
MethodPublicKey = "publickey"
MethodKeyboardInteractive = "keyboard-interactive"
)
type KeyboardInteractiveQuerier interface {
// Prompts the client and returns their responses to the given prompts.
Prompt(ctx context.Context, prompts *extensions_ssh.KeyboardInteractiveInfoPrompts) (*extensions_ssh.KeyboardInteractiveInfoPromptResponses, error)
}
type AuthMethodResponse[T any] struct {
Allow *T
RequireAdditionalMethods []string
}
type (
PublicKeyAuthMethodResponse = AuthMethodResponse[extensions_ssh.PublicKeyAllowResponse]
KeyboardInteractiveAuthMethodResponse = AuthMethodResponse[extensions_ssh.KeyboardInteractiveAllowResponse]
)
//go:generate go run go.uber.org/mock/mockgen -typed -destination ./mock/mock_auth_interface.go . AuthInterface
type AuthInterface interface {
HandlePublicKeyMethodRequest(ctx context.Context, info StreamAuthInfo, req *extensions_ssh.PublicKeyMethodRequest) (PublicKeyAuthMethodResponse, error)
HandleKeyboardInteractiveMethodRequest(ctx context.Context, info StreamAuthInfo, req *extensions_ssh.KeyboardInteractiveMethodRequest, querier KeyboardInteractiveQuerier) (KeyboardInteractiveAuthMethodResponse, error)
EvaluateDelayed(ctx context.Context, info StreamAuthInfo) error
FormatSession(ctx context.Context, info StreamAuthInfo) ([]byte, error)
DeleteSession(ctx context.Context, info StreamAuthInfo) error
}
type AuthMethodValue[T any] struct {
attempted bool
Value *T
}
func (v *AuthMethodValue[T]) Update(value *T) {
v.attempted = true
v.Value = value
}
func (v *AuthMethodValue[T]) IsValid() bool {
if v.attempted {
// method was attempted - valid iff there is a value
return v.Value != nil
}
return true // method was not attempted - valid
}
type StreamAuthInfo struct {
Username *string
Hostname *string
StreamID uint64
SourceAddress string
PublicKeyFingerprintSha256 []byte
PublicKeyAllow AuthMethodValue[extensions_ssh.PublicKeyAllowResponse]
KeyboardInteractiveAllow AuthMethodValue[extensions_ssh.KeyboardInteractiveAllowResponse]
}
func (i *StreamAuthInfo) allMethodsValid() bool {
return i.PublicKeyAllow.IsValid() && i.KeyboardInteractiveAllow.IsValid()
}
type StreamState struct {
StreamAuthInfo
DirectTcpip bool
RemainingUnauthenticatedMethods []string
DownstreamChannelInfo *extensions_ssh.SSHDownstreamChannelInfo
}
// StreamHandler handles a single SSH stream
type StreamHandler struct {
auth AuthInterface
config *config.Config
downstream *extensions_ssh.DownstreamConnectEvent
writeC chan *extensions_ssh.ServerMessage
readC chan *extensions_ssh.ClientMessage
state *StreamState
close func()
channelIDCounter uint32
expectingInternalChannel bool
}
var _ StreamHandlerInterface = (*StreamHandler)(nil)
func (sh *StreamHandler) Close() {
sh.close()
}
func (sh *StreamHandler) IsExpectingInternalChannel() bool {
return sh.expectingInternalChannel
}
func (sh *StreamHandler) ReadC() chan<- *extensions_ssh.ClientMessage {
return sh.readC
}
func (sh *StreamHandler) WriteC() <-chan *extensions_ssh.ServerMessage {
return sh.writeC
}
// Prompt implements KeyboardInteractiveQuerier.
func (sh *StreamHandler) Prompt(ctx context.Context, prompts *extensions_ssh.KeyboardInteractiveInfoPrompts) (*extensions_ssh.KeyboardInteractiveInfoPromptResponses, error) {
sh.sendInfoPrompts(prompts)
select {
case <-ctx.Done():
return nil, context.Cause(ctx)
case req := <-sh.readC:
switch msg := req.Message.(type) {
case *extensions_ssh.ClientMessage_InfoResponse:
if msg.InfoResponse.Method != "keyboard-interactive" {
return nil, status.Errorf(codes.Internal, "received invalid info response")
}
r, _ := msg.InfoResponse.Response.UnmarshalNew()
respInfo, ok := r.(*extensions_ssh.KeyboardInteractiveInfoPromptResponses)
if !ok {
return nil, status.Errorf(codes.InvalidArgument, "received invalid prompt response")
}
return respInfo, nil
default:
return nil, status.Errorf(codes.InvalidArgument, "received invalid message, expecting info response")
}
}
}
func (sh *StreamHandler) Run(ctx context.Context) error {
if sh.state != nil {
panic("Run called twice")
}
sh.state = &StreamState{
RemainingUnauthenticatedMethods: []string{MethodPublicKey},
StreamAuthInfo: StreamAuthInfo{
StreamID: sh.downstream.StreamId,
SourceAddress: sh.downstream.SourceAddress,
},
}
for {
select {
case <-ctx.Done():
return context.Cause(ctx)
case req := <-sh.readC:
switch req := req.Message.(type) {
case *extensions_ssh.ClientMessage_Event:
switch event := req.Event.Event.(type) {
case *extensions_ssh.StreamEvent_DownstreamConnected:
// this was already received as the first message in the stream
return status.Errorf(codes.Internal, "received duplicate downstream connected event")
case *extensions_ssh.StreamEvent_UpstreamConnected:
log.Ctx(ctx).Debug().
Uint64("stream-id", event.UpstreamConnected.StreamId).
Msg("ssh: upstream connected")
case *extensions_ssh.StreamEvent_DownstreamDisconnected:
log.Ctx(ctx).Debug().
Uint64("stream-id", sh.downstream.StreamId).
Str("reason", event.DownstreamDisconnected.Reason).
Msg("ssh: downstream disconnected")
case nil:
return status.Errorf(codes.Internal, "received invalid event")
}
case *extensions_ssh.ClientMessage_AuthRequest:
if err := sh.handleAuthRequest(ctx, req.AuthRequest); err != nil {
return err
}
default:
return status.Errorf(codes.Internal, "received invalid message")
}
}
}
}
func (sh *StreamHandler) ServeChannel(stream extensions_ssh.StreamManagement_ServeChannelServer) error {
// The first channel message on this stream should be a ChannelOpen
channelOpen, err := stream.Recv()
if err != nil {
return err
}
rawMsg, ok := channelOpen.GetMessage().(*extensions_ssh.ChannelMessage_RawBytes)
if !ok {
return status.Errorf(codes.InvalidArgument, "first channel message was not ChannelOpen")
}
var msg ChannelOpenMsg
if err := gossh.Unmarshal(rawMsg.RawBytes.GetValue(), &msg); err != nil {
return status.Errorf(codes.InvalidArgument, "first channel message was not ChannelOpen")
}
sh.channelIDCounter++
sh.state.DownstreamChannelInfo = &extensions_ssh.SSHDownstreamChannelInfo{
ChannelType: msg.ChanType,
DownstreamChannelId: msg.PeersID,
InternalUpstreamChannelId: sh.channelIDCounter,
InitialWindowSize: msg.PeersWindow,
MaxPacketSize: msg.MaxPacketSize,
}
channel := NewChannelImpl(sh, stream, sh.state.DownstreamChannelInfo)
switch msg.ChanType {
case "session":
if err := channel.SendMessage(ChannelOpenConfirmMsg{
PeersID: sh.state.DownstreamChannelInfo.DownstreamChannelId,
MyID: sh.state.DownstreamChannelInfo.InternalUpstreamChannelId,
MyWindow: ChannelWindowSize,
MaxPacketSize: ChannelMaxPacket,
}); err != nil {
return err
}
ch := NewChannelHandler(channel, sh.config)
return ch.Run(stream.Context())
case "direct-tcpip":
var subMsg ChannelOpenDirectMsg
if err := gossh.Unmarshal(msg.TypeSpecificData, &subMsg); err != nil {
return err
}
sh.state.DirectTcpip = true
action, err := sh.PrepareHandoff(stream.Context(), subMsg.DestAddr, nil)
if err != nil {
return err
}
return channel.SendControlAction(action)
default:
return status.Errorf(codes.InvalidArgument, "unexpected channel type in ChannelOpen message: %s", msg.ChanType)
}
}
func (sh *StreamHandler) handleAuthRequest(ctx context.Context, req *extensions_ssh.AuthenticationRequest) error {
if req.Protocol != "ssh" {
return status.Errorf(codes.InvalidArgument, "invalid protocol: %s", req.Protocol)
}
if req.Service != "ssh-connection" {
return status.Errorf(codes.InvalidArgument, "invalid service: %s", req.Service)
}
if !slices.Contains(sh.state.RemainingUnauthenticatedMethods, req.AuthMethod) {
return status.Errorf(codes.InvalidArgument, "unexpected auth method: %s", req.AuthMethod)
}
if sh.state.Username == nil {
if req.Username == "" {
return status.Errorf(codes.InvalidArgument, "username missing")
}
sh.state.Username = &req.Username
} else if *sh.state.Username != req.Username {
return status.Errorf(codes.InvalidArgument, "inconsistent username")
}
if sh.state.Hostname == nil {
sh.state.Hostname = &req.Hostname
} else if *sh.state.Hostname != req.Hostname {
return status.Errorf(codes.InvalidArgument, "inconsistent hostname")
}
updateMethods := func(add []string) {
sh.state.RemainingUnauthenticatedMethods = slices.Remove(sh.state.RemainingUnauthenticatedMethods, req.AuthMethod)
sh.state.RemainingUnauthenticatedMethods = append(sh.state.RemainingUnauthenticatedMethods, add...)
}
log.Ctx(ctx).Debug().
Str("method", req.AuthMethod).
Str("username", *sh.state.Username).
Str("hostname", *sh.state.Hostname).
Msg("ssh: handling auth request")
var partial bool
switch req.AuthMethod {
case MethodPublicKey:
methodReq, _ := req.MethodRequest.UnmarshalNew()
pubkeyReq, ok := methodReq.(*extensions_ssh.PublicKeyMethodRequest)
if !ok {
return status.Errorf(codes.InvalidArgument, "invalid public key method request type")
}
response, err := sh.auth.HandlePublicKeyMethodRequest(ctx, sh.state.StreamAuthInfo, pubkeyReq)
if err != nil {
return err
} else if response.Allow != nil {
partial = true
sh.state.PublicKeyFingerprintSha256 = pubkeyReq.PublicKeyFingerprintSha256
}
sh.state.PublicKeyAllow.Update(response.Allow)
updateMethods(response.RequireAdditionalMethods)
case MethodKeyboardInteractive:
methodReq, _ := req.MethodRequest.UnmarshalNew()
kbiReq, ok := methodReq.(*extensions_ssh.KeyboardInteractiveMethodRequest)
if !ok {
return status.Errorf(codes.InvalidArgument, "invalid keyboard-interactive method request type")
}
response, err := sh.auth.HandleKeyboardInteractiveMethodRequest(ctx, sh.state.StreamAuthInfo, kbiReq, sh)
if err != nil {
return err
}
partial = response.Allow != nil
sh.state.KeyboardInteractiveAllow.Update(response.Allow)
updateMethods(response.RequireAdditionalMethods)
default:
return status.Errorf(codes.Internal, "bug: server requested an unsupported auth method %q", req.AuthMethod)
}
log.Ctx(ctx).Debug().
Str("method", req.AuthMethod).
Bool("partial", partial).
Strs("methods-remaining", sh.state.RemainingUnauthenticatedMethods).
Msg("ssh: auth request complete")
if len(sh.state.RemainingUnauthenticatedMethods) == 0 && sh.state.allMethodsValid() {
// if there are no methods remaining, the user is allowed if all attempted
// methods have a valid response in the state
log.Ctx(ctx).Debug().Msg("ssh: all methods valid, sending allow response")
sh.sendAllowResponse()
} else {
log.Ctx(ctx).Debug().Msg("ssh: unauthenticated methods remain, sending deny response")
sh.sendDenyResponseWithRemainingMethods(partial)
}
return nil
}
func (sh *StreamHandler) PrepareHandoff(ctx context.Context, hostname string, ptyInfo *extensions_ssh.SSHDownstreamPTYInfo) (*extensions_ssh.SSHChannelControlAction, error) {
if hostname == "" {
return nil, status.Errorf(codes.PermissionDenied, "invalid hostname")
}
if sh.state.Hostname == nil {
panic("bug: PrepareHandoff called but state is missing a hostname")
}
if *sh.state.Hostname != "" {
panic("bug: PrepareHandoff called but previous hostname is not empty")
}
*sh.state.Hostname = hostname
err := sh.auth.EvaluateDelayed(ctx, sh.state.StreamAuthInfo)
if err != nil {
return nil, status.Error(codes.PermissionDenied, err.Error())
}
log.Ctx(ctx).Debug().
Str("hostname", *sh.state.Hostname).
Str("username", *sh.state.Username).
Msg("ssh: initiating handoff to upstream")
upstreamAllow := sh.buildUpstreamAllowResponse()
action := &extensions_ssh.SSHChannelControlAction{
Action: &extensions_ssh.SSHChannelControlAction_HandOff{
HandOff: &extensions_ssh.SSHChannelControlAction_HandOffUpstream{
DownstreamChannelInfo: sh.state.DownstreamChannelInfo,
DownstreamPtyInfo: ptyInfo,
UpstreamAuth: upstreamAllow,
},
},
}
return action, nil
}
func (sh *StreamHandler) FormatSession(ctx context.Context) ([]byte, error) {
return sh.auth.FormatSession(ctx, sh.state.StreamAuthInfo)
}
func (sh *StreamHandler) DeleteSession(ctx context.Context) error {
return sh.auth.DeleteSession(ctx, sh.state.StreamAuthInfo)
}
func (sh *StreamHandler) AllSSHRoutes() iter.Seq[*config.Policy] {
return func(yield func(*config.Policy) bool) {
for route := range sh.config.Options.GetAllPolicies() {
if route.IsSSH() {
if !yield(route) {
return
}
}
}
}
}
// DownstreamChannelID implements StreamHandlerInterface.
func (sh *StreamHandler) DownstreamChannelID() uint32 {
return sh.state.DownstreamChannelInfo.DownstreamChannelId
}
// Hostname implements StreamHandlerInterface.
func (sh *StreamHandler) Hostname() *string {
return sh.state.Hostname
}
// Username implements StreamHandlerInterface.
func (sh *StreamHandler) Username() *string {
return sh.state.Username
}
func (sh *StreamHandler) sendDenyResponseWithRemainingMethods(partial bool) {
sh.writeC <- &extensions_ssh.ServerMessage{
Message: &extensions_ssh.ServerMessage_AuthResponse{
AuthResponse: &extensions_ssh.AuthenticationResponse{
Response: &extensions_ssh.AuthenticationResponse_Deny{
Deny: &extensions_ssh.DenyResponse{
Partial: partial,
Methods: sh.state.RemainingUnauthenticatedMethods,
},
},
},
},
}
}
func (sh *StreamHandler) sendAllowResponse() {
var allow *extensions_ssh.AllowResponse
if *sh.state.Hostname == "" {
sh.expectingInternalChannel = true
allow = sh.buildInternalAllowResponse()
} else {
allow = sh.buildUpstreamAllowResponse()
}
sh.writeC <- &extensions_ssh.ServerMessage{
Message: &extensions_ssh.ServerMessage_AuthResponse{
AuthResponse: &extensions_ssh.AuthenticationResponse{
Response: &extensions_ssh.AuthenticationResponse_Allow{
Allow: allow,
},
},
},
}
}
func (sh *StreamHandler) sendInfoPrompts(prompts *extensions_ssh.KeyboardInteractiveInfoPrompts) {
sh.writeC <- &extensions_ssh.ServerMessage{
Message: &extensions_ssh.ServerMessage_AuthResponse{
AuthResponse: &extensions_ssh.AuthenticationResponse{
Response: &extensions_ssh.AuthenticationResponse_InfoRequest{
InfoRequest: &extensions_ssh.InfoRequest{
Method: MethodKeyboardInteractive,
Request: protoutil.NewAny(prompts),
},
},
},
},
}
}
func (sh *StreamHandler) buildUpstreamAllowResponse() *extensions_ssh.AllowResponse {
var allowedMethods []*extensions_ssh.AllowedMethod
if value := sh.state.PublicKeyAllow.Value; value != nil {
allowedMethods = append(allowedMethods, &extensions_ssh.AllowedMethod{
Method: MethodPublicKey,
MethodData: protoutil.NewAny(value),
})
}
if value := sh.state.KeyboardInteractiveAllow.Value; value != nil {
allowedMethods = append(allowedMethods, &extensions_ssh.AllowedMethod{
Method: MethodKeyboardInteractive,
MethodData: protoutil.NewAny(value),
})
}
return &extensions_ssh.AllowResponse{
Username: *sh.state.Username,
Target: &extensions_ssh.AllowResponse_Upstream{
Upstream: &extensions_ssh.UpstreamTarget{
Hostname: *sh.state.Hostname,
DirectTcpip: sh.state.DirectTcpip,
AllowedMethods: allowedMethods,
},
},
}
}
func (sh *StreamHandler) buildInternalAllowResponse() *extensions_ssh.AllowResponse {
return &extensions_ssh.AllowResponse{
Username: *sh.state.Username,
Target: &extensions_ssh.AllowResponse_Internal{
Internal: &extensions_ssh.InternalTarget{
SetMetadata: &corev3.Metadata{
TypedFilterMetadata: map[string]*anypb.Any{
"com.pomerium.ssh": protoutil.NewAny(&extensions_ssh.FilterMetadata{
StreamId: sh.downstream.StreamId,
}),
},
},
},
},
}
}