mirror of
https://github.com/pomerium/pomerium.git
synced 2025-08-06 10:21:05 +02:00
Implement the pkg/ssh.AuthInterface. Add logic for converting from the ssh stream state to an evaluator request, and for interpreting the results of policy evaluation. Refactor some of the existing authorize logic to make it easier to reuse.
487 lines
16 KiB
Go
487 lines
16 KiB
Go
package ssh
|
|
|
|
import (
|
|
"context"
|
|
"iter"
|
|
|
|
corev3 "github.com/envoyproxy/go-control-plane/envoy/config/core/v3"
|
|
gossh "golang.org/x/crypto/ssh"
|
|
"google.golang.org/grpc/codes"
|
|
"google.golang.org/grpc/status"
|
|
"google.golang.org/protobuf/types/known/anypb"
|
|
|
|
extensions_ssh "github.com/pomerium/envoy-custom/api/extensions/filters/network/ssh"
|
|
"github.com/pomerium/pomerium/config"
|
|
"github.com/pomerium/pomerium/internal/log"
|
|
"github.com/pomerium/pomerium/pkg/protoutil"
|
|
"github.com/pomerium/pomerium/pkg/slices"
|
|
)
|
|
|
|
const (
|
|
MethodPublicKey = "publickey"
|
|
MethodKeyboardInteractive = "keyboard-interactive"
|
|
)
|
|
|
|
type KeyboardInteractiveQuerier interface {
|
|
// Prompts the client and returns their responses to the given prompts.
|
|
Prompt(ctx context.Context, prompts *extensions_ssh.KeyboardInteractiveInfoPrompts) (*extensions_ssh.KeyboardInteractiveInfoPromptResponses, error)
|
|
}
|
|
|
|
type AuthMethodResponse[T any] struct {
|
|
Allow *T
|
|
RequireAdditionalMethods []string
|
|
}
|
|
|
|
type (
|
|
PublicKeyAuthMethodResponse = AuthMethodResponse[extensions_ssh.PublicKeyAllowResponse]
|
|
KeyboardInteractiveAuthMethodResponse = AuthMethodResponse[extensions_ssh.KeyboardInteractiveAllowResponse]
|
|
)
|
|
|
|
//go:generate go run go.uber.org/mock/mockgen -typed -destination ./mock/mock_auth_interface.go . AuthInterface
|
|
|
|
type AuthInterface interface {
|
|
HandlePublicKeyMethodRequest(ctx context.Context, info StreamAuthInfo, req *extensions_ssh.PublicKeyMethodRequest) (PublicKeyAuthMethodResponse, error)
|
|
HandleKeyboardInteractiveMethodRequest(ctx context.Context, info StreamAuthInfo, req *extensions_ssh.KeyboardInteractiveMethodRequest, querier KeyboardInteractiveQuerier) (KeyboardInteractiveAuthMethodResponse, error)
|
|
EvaluateDelayed(ctx context.Context, info StreamAuthInfo) error
|
|
FormatSession(ctx context.Context, info StreamAuthInfo) ([]byte, error)
|
|
DeleteSession(ctx context.Context, info StreamAuthInfo) error
|
|
}
|
|
|
|
type AuthMethodValue[T any] struct {
|
|
attempted bool
|
|
Value *T
|
|
}
|
|
|
|
func (v *AuthMethodValue[T]) Update(value *T) {
|
|
v.attempted = true
|
|
v.Value = value
|
|
}
|
|
|
|
func (v *AuthMethodValue[T]) IsValid() bool {
|
|
if v.attempted {
|
|
// method was attempted - valid iff there is a value
|
|
return v.Value != nil
|
|
}
|
|
return true // method was not attempted - valid
|
|
}
|
|
|
|
type StreamAuthInfo struct {
|
|
Username *string
|
|
Hostname *string
|
|
StreamID uint64
|
|
SourceAddress string
|
|
PublicKeyFingerprintSha256 []byte
|
|
PublicKeyAllow AuthMethodValue[extensions_ssh.PublicKeyAllowResponse]
|
|
KeyboardInteractiveAllow AuthMethodValue[extensions_ssh.KeyboardInteractiveAllowResponse]
|
|
}
|
|
|
|
func (i *StreamAuthInfo) allMethodsValid() bool {
|
|
return i.PublicKeyAllow.IsValid() && i.KeyboardInteractiveAllow.IsValid()
|
|
}
|
|
|
|
type StreamState struct {
|
|
StreamAuthInfo
|
|
DirectTcpip bool
|
|
RemainingUnauthenticatedMethods []string
|
|
DownstreamChannelInfo *extensions_ssh.SSHDownstreamChannelInfo
|
|
}
|
|
|
|
// StreamHandler handles a single SSH stream
|
|
type StreamHandler struct {
|
|
auth AuthInterface
|
|
config *config.Config
|
|
downstream *extensions_ssh.DownstreamConnectEvent
|
|
writeC chan *extensions_ssh.ServerMessage
|
|
readC chan *extensions_ssh.ClientMessage
|
|
|
|
state *StreamState
|
|
close func()
|
|
|
|
channelIDCounter uint32
|
|
expectingInternalChannel bool
|
|
}
|
|
|
|
var _ StreamHandlerInterface = (*StreamHandler)(nil)
|
|
|
|
func (sh *StreamHandler) Close() {
|
|
sh.close()
|
|
}
|
|
|
|
func (sh *StreamHandler) IsExpectingInternalChannel() bool {
|
|
return sh.expectingInternalChannel
|
|
}
|
|
|
|
func (sh *StreamHandler) ReadC() chan<- *extensions_ssh.ClientMessage {
|
|
return sh.readC
|
|
}
|
|
|
|
func (sh *StreamHandler) WriteC() <-chan *extensions_ssh.ServerMessage {
|
|
return sh.writeC
|
|
}
|
|
|
|
// Prompt implements KeyboardInteractiveQuerier.
|
|
func (sh *StreamHandler) Prompt(ctx context.Context, prompts *extensions_ssh.KeyboardInteractiveInfoPrompts) (*extensions_ssh.KeyboardInteractiveInfoPromptResponses, error) {
|
|
sh.sendInfoPrompts(prompts)
|
|
select {
|
|
case <-ctx.Done():
|
|
return nil, context.Cause(ctx)
|
|
case req := <-sh.readC:
|
|
switch msg := req.Message.(type) {
|
|
case *extensions_ssh.ClientMessage_InfoResponse:
|
|
if msg.InfoResponse.Method != "keyboard-interactive" {
|
|
return nil, status.Errorf(codes.Internal, "received invalid info response")
|
|
}
|
|
r, _ := msg.InfoResponse.Response.UnmarshalNew()
|
|
respInfo, ok := r.(*extensions_ssh.KeyboardInteractiveInfoPromptResponses)
|
|
if !ok {
|
|
return nil, status.Errorf(codes.InvalidArgument, "received invalid prompt response")
|
|
}
|
|
return respInfo, nil
|
|
default:
|
|
return nil, status.Errorf(codes.InvalidArgument, "received invalid message, expecting info response")
|
|
}
|
|
}
|
|
}
|
|
|
|
func (sh *StreamHandler) Run(ctx context.Context) error {
|
|
if sh.state != nil {
|
|
panic("Run called twice")
|
|
}
|
|
sh.state = &StreamState{
|
|
RemainingUnauthenticatedMethods: []string{MethodPublicKey},
|
|
StreamAuthInfo: StreamAuthInfo{
|
|
StreamID: sh.downstream.StreamId,
|
|
SourceAddress: sh.downstream.SourceAddress,
|
|
},
|
|
}
|
|
for {
|
|
select {
|
|
case <-ctx.Done():
|
|
return context.Cause(ctx)
|
|
case req := <-sh.readC:
|
|
switch req := req.Message.(type) {
|
|
case *extensions_ssh.ClientMessage_Event:
|
|
switch event := req.Event.Event.(type) {
|
|
case *extensions_ssh.StreamEvent_DownstreamConnected:
|
|
// this was already received as the first message in the stream
|
|
return status.Errorf(codes.Internal, "received duplicate downstream connected event")
|
|
case *extensions_ssh.StreamEvent_UpstreamConnected:
|
|
log.Ctx(ctx).Debug().
|
|
Uint64("stream-id", event.UpstreamConnected.StreamId).
|
|
Msg("ssh: upstream connected")
|
|
case *extensions_ssh.StreamEvent_DownstreamDisconnected:
|
|
log.Ctx(ctx).Debug().
|
|
Uint64("stream-id", sh.downstream.StreamId).
|
|
Str("reason", event.DownstreamDisconnected.Reason).
|
|
Msg("ssh: downstream disconnected")
|
|
case nil:
|
|
return status.Errorf(codes.Internal, "received invalid event")
|
|
}
|
|
case *extensions_ssh.ClientMessage_AuthRequest:
|
|
if err := sh.handleAuthRequest(ctx, req.AuthRequest); err != nil {
|
|
return err
|
|
}
|
|
default:
|
|
return status.Errorf(codes.Internal, "received invalid message")
|
|
}
|
|
}
|
|
}
|
|
}
|
|
|
|
func (sh *StreamHandler) ServeChannel(stream extensions_ssh.StreamManagement_ServeChannelServer) error {
|
|
// The first channel message on this stream should be a ChannelOpen
|
|
channelOpen, err := stream.Recv()
|
|
if err != nil {
|
|
return err
|
|
}
|
|
rawMsg, ok := channelOpen.GetMessage().(*extensions_ssh.ChannelMessage_RawBytes)
|
|
if !ok {
|
|
return status.Errorf(codes.InvalidArgument, "first channel message was not ChannelOpen")
|
|
}
|
|
var msg ChannelOpenMsg
|
|
if err := gossh.Unmarshal(rawMsg.RawBytes.GetValue(), &msg); err != nil {
|
|
return status.Errorf(codes.InvalidArgument, "first channel message was not ChannelOpen")
|
|
}
|
|
|
|
sh.channelIDCounter++
|
|
sh.state.DownstreamChannelInfo = &extensions_ssh.SSHDownstreamChannelInfo{
|
|
ChannelType: msg.ChanType,
|
|
DownstreamChannelId: msg.PeersID,
|
|
InternalUpstreamChannelId: sh.channelIDCounter,
|
|
InitialWindowSize: msg.PeersWindow,
|
|
MaxPacketSize: msg.MaxPacketSize,
|
|
}
|
|
channel := NewChannelImpl(sh, stream, sh.state.DownstreamChannelInfo)
|
|
switch msg.ChanType {
|
|
case "session":
|
|
if err := channel.SendMessage(ChannelOpenConfirmMsg{
|
|
PeersID: sh.state.DownstreamChannelInfo.DownstreamChannelId,
|
|
MyID: sh.state.DownstreamChannelInfo.InternalUpstreamChannelId,
|
|
MyWindow: ChannelWindowSize,
|
|
MaxPacketSize: ChannelMaxPacket,
|
|
}); err != nil {
|
|
return err
|
|
}
|
|
ch := NewChannelHandler(channel, sh.config)
|
|
return ch.Run(stream.Context())
|
|
case "direct-tcpip":
|
|
var subMsg ChannelOpenDirectMsg
|
|
if err := gossh.Unmarshal(msg.TypeSpecificData, &subMsg); err != nil {
|
|
return err
|
|
}
|
|
sh.state.DirectTcpip = true
|
|
action, err := sh.PrepareHandoff(stream.Context(), subMsg.DestAddr, nil)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
return channel.SendControlAction(action)
|
|
default:
|
|
return status.Errorf(codes.InvalidArgument, "unexpected channel type in ChannelOpen message: %s", msg.ChanType)
|
|
}
|
|
}
|
|
|
|
func (sh *StreamHandler) handleAuthRequest(ctx context.Context, req *extensions_ssh.AuthenticationRequest) error {
|
|
if req.Protocol != "ssh" {
|
|
return status.Errorf(codes.InvalidArgument, "invalid protocol: %s", req.Protocol)
|
|
}
|
|
if req.Service != "ssh-connection" {
|
|
return status.Errorf(codes.InvalidArgument, "invalid service: %s", req.Service)
|
|
}
|
|
if !slices.Contains(sh.state.RemainingUnauthenticatedMethods, req.AuthMethod) {
|
|
return status.Errorf(codes.InvalidArgument, "unexpected auth method: %s", req.AuthMethod)
|
|
}
|
|
|
|
if sh.state.Username == nil {
|
|
if req.Username == "" {
|
|
return status.Errorf(codes.InvalidArgument, "username missing")
|
|
}
|
|
sh.state.Username = &req.Username
|
|
} else if *sh.state.Username != req.Username {
|
|
return status.Errorf(codes.InvalidArgument, "inconsistent username")
|
|
}
|
|
if sh.state.Hostname == nil {
|
|
sh.state.Hostname = &req.Hostname
|
|
} else if *sh.state.Hostname != req.Hostname {
|
|
return status.Errorf(codes.InvalidArgument, "inconsistent hostname")
|
|
}
|
|
|
|
updateMethods := func(add []string) {
|
|
sh.state.RemainingUnauthenticatedMethods = slices.Remove(sh.state.RemainingUnauthenticatedMethods, req.AuthMethod)
|
|
sh.state.RemainingUnauthenticatedMethods = append(sh.state.RemainingUnauthenticatedMethods, add...)
|
|
}
|
|
log.Ctx(ctx).Debug().
|
|
Str("method", req.AuthMethod).
|
|
Str("username", *sh.state.Username).
|
|
Str("hostname", *sh.state.Hostname).
|
|
Msg("ssh: handling auth request")
|
|
|
|
var partial bool
|
|
switch req.AuthMethod {
|
|
case MethodPublicKey:
|
|
methodReq, _ := req.MethodRequest.UnmarshalNew()
|
|
pubkeyReq, ok := methodReq.(*extensions_ssh.PublicKeyMethodRequest)
|
|
if !ok {
|
|
return status.Errorf(codes.InvalidArgument, "invalid public key method request type")
|
|
}
|
|
response, err := sh.auth.HandlePublicKeyMethodRequest(ctx, sh.state.StreamAuthInfo, pubkeyReq)
|
|
if err != nil {
|
|
return err
|
|
} else if response.Allow != nil {
|
|
partial = true
|
|
sh.state.PublicKeyFingerprintSha256 = pubkeyReq.PublicKeyFingerprintSha256
|
|
}
|
|
sh.state.PublicKeyAllow.Update(response.Allow)
|
|
updateMethods(response.RequireAdditionalMethods)
|
|
case MethodKeyboardInteractive:
|
|
methodReq, _ := req.MethodRequest.UnmarshalNew()
|
|
kbiReq, ok := methodReq.(*extensions_ssh.KeyboardInteractiveMethodRequest)
|
|
if !ok {
|
|
return status.Errorf(codes.InvalidArgument, "invalid keyboard-interactive method request type")
|
|
}
|
|
response, err := sh.auth.HandleKeyboardInteractiveMethodRequest(ctx, sh.state.StreamAuthInfo, kbiReq, sh)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
partial = response.Allow != nil
|
|
sh.state.KeyboardInteractiveAllow.Update(response.Allow)
|
|
updateMethods(response.RequireAdditionalMethods)
|
|
default:
|
|
return status.Errorf(codes.Internal, "bug: server requested an unsupported auth method %q", req.AuthMethod)
|
|
}
|
|
log.Ctx(ctx).Debug().
|
|
Str("method", req.AuthMethod).
|
|
Bool("partial", partial).
|
|
Strs("methods-remaining", sh.state.RemainingUnauthenticatedMethods).
|
|
Msg("ssh: auth request complete")
|
|
|
|
if len(sh.state.RemainingUnauthenticatedMethods) == 0 && sh.state.allMethodsValid() {
|
|
// if there are no methods remaining, the user is allowed if all attempted
|
|
// methods have a valid response in the state
|
|
log.Ctx(ctx).Debug().Msg("ssh: all methods valid, sending allow response")
|
|
sh.sendAllowResponse()
|
|
} else {
|
|
log.Ctx(ctx).Debug().Msg("ssh: unauthenticated methods remain, sending deny response")
|
|
sh.sendDenyResponseWithRemainingMethods(partial)
|
|
}
|
|
return nil
|
|
}
|
|
|
|
func (sh *StreamHandler) PrepareHandoff(ctx context.Context, hostname string, ptyInfo *extensions_ssh.SSHDownstreamPTYInfo) (*extensions_ssh.SSHChannelControlAction, error) {
|
|
if hostname == "" {
|
|
return nil, status.Errorf(codes.PermissionDenied, "invalid hostname")
|
|
}
|
|
if sh.state.Hostname == nil {
|
|
panic("bug: PrepareHandoff called but state is missing a hostname")
|
|
}
|
|
if *sh.state.Hostname != "" {
|
|
panic("bug: PrepareHandoff called but previous hostname is not empty")
|
|
}
|
|
*sh.state.Hostname = hostname
|
|
err := sh.auth.EvaluateDelayed(ctx, sh.state.StreamAuthInfo)
|
|
if err != nil {
|
|
return nil, status.Error(codes.PermissionDenied, err.Error())
|
|
}
|
|
log.Ctx(ctx).Debug().
|
|
Str("hostname", *sh.state.Hostname).
|
|
Str("username", *sh.state.Username).
|
|
Msg("ssh: initiating handoff to upstream")
|
|
upstreamAllow := sh.buildUpstreamAllowResponse()
|
|
action := &extensions_ssh.SSHChannelControlAction{
|
|
Action: &extensions_ssh.SSHChannelControlAction_HandOff{
|
|
HandOff: &extensions_ssh.SSHChannelControlAction_HandOffUpstream{
|
|
DownstreamChannelInfo: sh.state.DownstreamChannelInfo,
|
|
DownstreamPtyInfo: ptyInfo,
|
|
UpstreamAuth: upstreamAllow,
|
|
},
|
|
},
|
|
}
|
|
return action, nil
|
|
}
|
|
|
|
func (sh *StreamHandler) FormatSession(ctx context.Context) ([]byte, error) {
|
|
return sh.auth.FormatSession(ctx, sh.state.StreamAuthInfo)
|
|
}
|
|
|
|
func (sh *StreamHandler) DeleteSession(ctx context.Context) error {
|
|
return sh.auth.DeleteSession(ctx, sh.state.StreamAuthInfo)
|
|
}
|
|
|
|
func (sh *StreamHandler) AllSSHRoutes() iter.Seq[*config.Policy] {
|
|
return func(yield func(*config.Policy) bool) {
|
|
for route := range sh.config.Options.GetAllPolicies() {
|
|
if route.IsSSH() {
|
|
if !yield(route) {
|
|
return
|
|
}
|
|
}
|
|
}
|
|
}
|
|
}
|
|
|
|
// DownstreamChannelID implements StreamHandlerInterface.
|
|
func (sh *StreamHandler) DownstreamChannelID() uint32 {
|
|
return sh.state.DownstreamChannelInfo.DownstreamChannelId
|
|
}
|
|
|
|
// Hostname implements StreamHandlerInterface.
|
|
func (sh *StreamHandler) Hostname() *string {
|
|
return sh.state.Hostname
|
|
}
|
|
|
|
// Username implements StreamHandlerInterface.
|
|
func (sh *StreamHandler) Username() *string {
|
|
return sh.state.Username
|
|
}
|
|
|
|
func (sh *StreamHandler) sendDenyResponseWithRemainingMethods(partial bool) {
|
|
sh.writeC <- &extensions_ssh.ServerMessage{
|
|
Message: &extensions_ssh.ServerMessage_AuthResponse{
|
|
AuthResponse: &extensions_ssh.AuthenticationResponse{
|
|
Response: &extensions_ssh.AuthenticationResponse_Deny{
|
|
Deny: &extensions_ssh.DenyResponse{
|
|
Partial: partial,
|
|
Methods: sh.state.RemainingUnauthenticatedMethods,
|
|
},
|
|
},
|
|
},
|
|
},
|
|
}
|
|
}
|
|
|
|
func (sh *StreamHandler) sendAllowResponse() {
|
|
var allow *extensions_ssh.AllowResponse
|
|
if *sh.state.Hostname == "" {
|
|
sh.expectingInternalChannel = true
|
|
allow = sh.buildInternalAllowResponse()
|
|
} else {
|
|
allow = sh.buildUpstreamAllowResponse()
|
|
}
|
|
|
|
sh.writeC <- &extensions_ssh.ServerMessage{
|
|
Message: &extensions_ssh.ServerMessage_AuthResponse{
|
|
AuthResponse: &extensions_ssh.AuthenticationResponse{
|
|
Response: &extensions_ssh.AuthenticationResponse_Allow{
|
|
Allow: allow,
|
|
},
|
|
},
|
|
},
|
|
}
|
|
}
|
|
|
|
func (sh *StreamHandler) sendInfoPrompts(prompts *extensions_ssh.KeyboardInteractiveInfoPrompts) {
|
|
sh.writeC <- &extensions_ssh.ServerMessage{
|
|
Message: &extensions_ssh.ServerMessage_AuthResponse{
|
|
AuthResponse: &extensions_ssh.AuthenticationResponse{
|
|
Response: &extensions_ssh.AuthenticationResponse_InfoRequest{
|
|
InfoRequest: &extensions_ssh.InfoRequest{
|
|
Method: MethodKeyboardInteractive,
|
|
Request: protoutil.NewAny(prompts),
|
|
},
|
|
},
|
|
},
|
|
},
|
|
}
|
|
}
|
|
|
|
func (sh *StreamHandler) buildUpstreamAllowResponse() *extensions_ssh.AllowResponse {
|
|
var allowedMethods []*extensions_ssh.AllowedMethod
|
|
if value := sh.state.PublicKeyAllow.Value; value != nil {
|
|
allowedMethods = append(allowedMethods, &extensions_ssh.AllowedMethod{
|
|
Method: MethodPublicKey,
|
|
MethodData: protoutil.NewAny(value),
|
|
})
|
|
}
|
|
if value := sh.state.KeyboardInteractiveAllow.Value; value != nil {
|
|
allowedMethods = append(allowedMethods, &extensions_ssh.AllowedMethod{
|
|
Method: MethodKeyboardInteractive,
|
|
MethodData: protoutil.NewAny(value),
|
|
})
|
|
}
|
|
return &extensions_ssh.AllowResponse{
|
|
Username: *sh.state.Username,
|
|
Target: &extensions_ssh.AllowResponse_Upstream{
|
|
Upstream: &extensions_ssh.UpstreamTarget{
|
|
Hostname: *sh.state.Hostname,
|
|
DirectTcpip: sh.state.DirectTcpip,
|
|
AllowedMethods: allowedMethods,
|
|
},
|
|
},
|
|
}
|
|
}
|
|
|
|
func (sh *StreamHandler) buildInternalAllowResponse() *extensions_ssh.AllowResponse {
|
|
return &extensions_ssh.AllowResponse{
|
|
Username: *sh.state.Username,
|
|
Target: &extensions_ssh.AllowResponse_Internal{
|
|
Internal: &extensions_ssh.InternalTarget{
|
|
SetMetadata: &corev3.Metadata{
|
|
TypedFilterMetadata: map[string]*anypb.Any{
|
|
"com.pomerium.ssh": protoutil.NewAny(&extensions_ssh.FilterMetadata{
|
|
StreamId: sh.downstream.StreamId,
|
|
}),
|
|
},
|
|
},
|
|
},
|
|
},
|
|
}
|
|
}
|