mirror of
https://github.com/pomerium/pomerium.git
synced 2025-08-06 10:21:05 +02:00
Re-evaluate ssh authorization decision on a fixed interval, or whenever the config changes. If access is no longer allowed, log a new 'authorize check' message and disconnect. Refactor the ssh.StreamManager initialization so that its lifecycle matches the Authorize lifecycle.
519 lines
17 KiB
Go
519 lines
17 KiB
Go
package ssh
|
|
|
|
import (
|
|
"context"
|
|
"iter"
|
|
"time"
|
|
|
|
corev3 "github.com/envoyproxy/go-control-plane/envoy/config/core/v3"
|
|
gossh "golang.org/x/crypto/ssh"
|
|
"google.golang.org/grpc/codes"
|
|
"google.golang.org/grpc/status"
|
|
"google.golang.org/protobuf/types/known/anypb"
|
|
|
|
extensions_ssh "github.com/pomerium/envoy-custom/api/extensions/filters/network/ssh"
|
|
"github.com/pomerium/pomerium/config"
|
|
"github.com/pomerium/pomerium/internal/log"
|
|
"github.com/pomerium/pomerium/pkg/protoutil"
|
|
"github.com/pomerium/pomerium/pkg/slices"
|
|
)
|
|
|
|
const (
|
|
MethodPublicKey = "publickey"
|
|
MethodKeyboardInteractive = "keyboard-interactive"
|
|
)
|
|
|
|
type KeyboardInteractiveQuerier interface {
|
|
// Prompts the client and returns their responses to the given prompts.
|
|
Prompt(ctx context.Context, prompts *extensions_ssh.KeyboardInteractiveInfoPrompts) (*extensions_ssh.KeyboardInteractiveInfoPromptResponses, error)
|
|
}
|
|
|
|
type AuthMethodResponse[T any] struct {
|
|
Allow *T
|
|
RequireAdditionalMethods []string
|
|
}
|
|
|
|
type (
|
|
PublicKeyAuthMethodResponse = AuthMethodResponse[extensions_ssh.PublicKeyAllowResponse]
|
|
KeyboardInteractiveAuthMethodResponse = AuthMethodResponse[extensions_ssh.KeyboardInteractiveAllowResponse]
|
|
)
|
|
|
|
//go:generate go run go.uber.org/mock/mockgen -typed -destination ./mock/mock_auth_interface.go . AuthInterface
|
|
|
|
type AuthInterface interface {
|
|
HandlePublicKeyMethodRequest(ctx context.Context, info StreamAuthInfo, req *extensions_ssh.PublicKeyMethodRequest) (PublicKeyAuthMethodResponse, error)
|
|
HandleKeyboardInteractiveMethodRequest(ctx context.Context, info StreamAuthInfo, req *extensions_ssh.KeyboardInteractiveMethodRequest, querier KeyboardInteractiveQuerier) (KeyboardInteractiveAuthMethodResponse, error)
|
|
EvaluateDelayed(ctx context.Context, info StreamAuthInfo) error
|
|
FormatSession(ctx context.Context, info StreamAuthInfo) ([]byte, error)
|
|
DeleteSession(ctx context.Context, info StreamAuthInfo) error
|
|
}
|
|
|
|
type AuthMethodValue[T any] struct {
|
|
attempted bool
|
|
Value *T
|
|
}
|
|
|
|
func (v *AuthMethodValue[T]) Update(value *T) {
|
|
v.attempted = true
|
|
v.Value = value
|
|
}
|
|
|
|
func (v *AuthMethodValue[T]) IsValid() bool {
|
|
if v.attempted {
|
|
// method was attempted - valid iff there is a value
|
|
return v.Value != nil
|
|
}
|
|
return true // method was not attempted - valid
|
|
}
|
|
|
|
type StreamAuthInfo struct {
|
|
Username *string
|
|
Hostname *string
|
|
StreamID uint64
|
|
SourceAddress string
|
|
PublicKeyFingerprintSha256 []byte
|
|
PublicKeyAllow AuthMethodValue[extensions_ssh.PublicKeyAllowResponse]
|
|
KeyboardInteractiveAllow AuthMethodValue[extensions_ssh.KeyboardInteractiveAllowResponse]
|
|
InitialAuthComplete bool
|
|
}
|
|
|
|
func (i *StreamAuthInfo) allMethodsValid() bool {
|
|
return i.PublicKeyAllow.IsValid() && i.KeyboardInteractiveAllow.IsValid()
|
|
}
|
|
|
|
type StreamState struct {
|
|
StreamAuthInfo
|
|
DirectTcpip bool
|
|
RemainingUnauthenticatedMethods []string
|
|
DownstreamChannelInfo *extensions_ssh.SSHDownstreamChannelInfo
|
|
}
|
|
|
|
// StreamHandler handles a single SSH stream
|
|
type StreamHandler struct {
|
|
auth AuthInterface
|
|
config *config.Config
|
|
downstream *extensions_ssh.DownstreamConnectEvent
|
|
writeC chan *extensions_ssh.ServerMessage
|
|
readC chan *extensions_ssh.ClientMessage
|
|
reauthC chan struct{}
|
|
|
|
state *StreamState
|
|
close func()
|
|
|
|
channelIDCounter uint32
|
|
expectingInternalChannel bool
|
|
}
|
|
|
|
var _ StreamHandlerInterface = (*StreamHandler)(nil)
|
|
|
|
func (sh *StreamHandler) Close() {
|
|
sh.close()
|
|
}
|
|
|
|
func (sh *StreamHandler) IsExpectingInternalChannel() bool {
|
|
return sh.expectingInternalChannel
|
|
}
|
|
|
|
func (sh *StreamHandler) ReadC() chan<- *extensions_ssh.ClientMessage {
|
|
return sh.readC
|
|
}
|
|
|
|
func (sh *StreamHandler) WriteC() <-chan *extensions_ssh.ServerMessage {
|
|
return sh.writeC
|
|
}
|
|
|
|
// Reauth blocks until authorization policy is reevaluated.
|
|
func (sh *StreamHandler) Reauth() {
|
|
sh.reauthC <- struct{}{}
|
|
}
|
|
|
|
func (sh *StreamHandler) periodicReauth() (cancel func()) {
|
|
t := time.NewTicker(1 * time.Minute)
|
|
go func() {
|
|
for range t.C {
|
|
sh.Reauth()
|
|
}
|
|
}()
|
|
return t.Stop
|
|
}
|
|
|
|
// Prompt implements KeyboardInteractiveQuerier.
|
|
func (sh *StreamHandler) Prompt(ctx context.Context, prompts *extensions_ssh.KeyboardInteractiveInfoPrompts) (*extensions_ssh.KeyboardInteractiveInfoPromptResponses, error) {
|
|
sh.sendInfoPrompts(prompts)
|
|
select {
|
|
case <-ctx.Done():
|
|
return nil, context.Cause(ctx)
|
|
case req := <-sh.readC:
|
|
switch msg := req.Message.(type) {
|
|
case *extensions_ssh.ClientMessage_InfoResponse:
|
|
if msg.InfoResponse.Method != "keyboard-interactive" {
|
|
return nil, status.Errorf(codes.Internal, "received invalid info response")
|
|
}
|
|
r, _ := msg.InfoResponse.Response.UnmarshalNew()
|
|
respInfo, ok := r.(*extensions_ssh.KeyboardInteractiveInfoPromptResponses)
|
|
if !ok {
|
|
return nil, status.Errorf(codes.InvalidArgument, "received invalid prompt response")
|
|
}
|
|
return respInfo, nil
|
|
default:
|
|
return nil, status.Errorf(codes.InvalidArgument, "received invalid message, expecting info response")
|
|
}
|
|
}
|
|
}
|
|
|
|
func (sh *StreamHandler) Run(ctx context.Context) error {
|
|
if sh.state != nil {
|
|
panic("Run called twice")
|
|
}
|
|
sh.state = &StreamState{
|
|
RemainingUnauthenticatedMethods: []string{MethodPublicKey},
|
|
StreamAuthInfo: StreamAuthInfo{
|
|
StreamID: sh.downstream.StreamId,
|
|
SourceAddress: sh.downstream.SourceAddress,
|
|
},
|
|
}
|
|
cancelReauth := sh.periodicReauth()
|
|
defer cancelReauth()
|
|
for {
|
|
select {
|
|
case <-ctx.Done():
|
|
return context.Cause(ctx)
|
|
case <-sh.reauthC:
|
|
if err := sh.reauth(ctx); err != nil {
|
|
return err
|
|
}
|
|
case req := <-sh.readC:
|
|
switch req := req.Message.(type) {
|
|
case *extensions_ssh.ClientMessage_Event:
|
|
switch event := req.Event.Event.(type) {
|
|
case *extensions_ssh.StreamEvent_DownstreamConnected:
|
|
// this was already received as the first message in the stream
|
|
return status.Errorf(codes.Internal, "received duplicate downstream connected event")
|
|
case *extensions_ssh.StreamEvent_UpstreamConnected:
|
|
log.Ctx(ctx).Debug().
|
|
Uint64("stream-id", event.UpstreamConnected.StreamId).
|
|
Msg("ssh: upstream connected")
|
|
case *extensions_ssh.StreamEvent_DownstreamDisconnected:
|
|
log.Ctx(ctx).Debug().
|
|
Uint64("stream-id", sh.downstream.StreamId).
|
|
Str("reason", event.DownstreamDisconnected.Reason).
|
|
Msg("ssh: downstream disconnected")
|
|
case nil:
|
|
return status.Errorf(codes.Internal, "received invalid event")
|
|
}
|
|
case *extensions_ssh.ClientMessage_AuthRequest:
|
|
if err := sh.handleAuthRequest(ctx, req.AuthRequest); err != nil {
|
|
return err
|
|
}
|
|
default:
|
|
return status.Errorf(codes.Internal, "received invalid message")
|
|
}
|
|
}
|
|
}
|
|
}
|
|
|
|
func (sh *StreamHandler) ServeChannel(stream extensions_ssh.StreamManagement_ServeChannelServer) error {
|
|
// The first channel message on this stream should be a ChannelOpen
|
|
channelOpen, err := stream.Recv()
|
|
if err != nil {
|
|
return err
|
|
}
|
|
rawMsg, ok := channelOpen.GetMessage().(*extensions_ssh.ChannelMessage_RawBytes)
|
|
if !ok {
|
|
return status.Errorf(codes.InvalidArgument, "first channel message was not ChannelOpen")
|
|
}
|
|
var msg ChannelOpenMsg
|
|
if err := gossh.Unmarshal(rawMsg.RawBytes.GetValue(), &msg); err != nil {
|
|
return status.Errorf(codes.InvalidArgument, "first channel message was not ChannelOpen")
|
|
}
|
|
|
|
sh.channelIDCounter++
|
|
sh.state.DownstreamChannelInfo = &extensions_ssh.SSHDownstreamChannelInfo{
|
|
ChannelType: msg.ChanType,
|
|
DownstreamChannelId: msg.PeersID,
|
|
InternalUpstreamChannelId: sh.channelIDCounter,
|
|
InitialWindowSize: msg.PeersWindow,
|
|
MaxPacketSize: msg.MaxPacketSize,
|
|
}
|
|
channel := NewChannelImpl(sh, stream, sh.state.DownstreamChannelInfo)
|
|
switch msg.ChanType {
|
|
case "session":
|
|
if err := channel.SendMessage(ChannelOpenConfirmMsg{
|
|
PeersID: sh.state.DownstreamChannelInfo.DownstreamChannelId,
|
|
MyID: sh.state.DownstreamChannelInfo.InternalUpstreamChannelId,
|
|
MyWindow: ChannelWindowSize,
|
|
MaxPacketSize: ChannelMaxPacket,
|
|
}); err != nil {
|
|
return err
|
|
}
|
|
ch := NewChannelHandler(channel, sh.config)
|
|
return ch.Run(stream.Context())
|
|
case "direct-tcpip":
|
|
var subMsg ChannelOpenDirectMsg
|
|
if err := gossh.Unmarshal(msg.TypeSpecificData, &subMsg); err != nil {
|
|
return err
|
|
}
|
|
sh.state.DirectTcpip = true
|
|
action, err := sh.PrepareHandoff(stream.Context(), subMsg.DestAddr, nil)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
return channel.SendControlAction(action)
|
|
default:
|
|
return status.Errorf(codes.InvalidArgument, "unexpected channel type in ChannelOpen message: %s", msg.ChanType)
|
|
}
|
|
}
|
|
|
|
func (sh *StreamHandler) handleAuthRequest(ctx context.Context, req *extensions_ssh.AuthenticationRequest) error {
|
|
if req.Protocol != "ssh" {
|
|
return status.Errorf(codes.InvalidArgument, "invalid protocol: %s", req.Protocol)
|
|
}
|
|
if req.Service != "ssh-connection" {
|
|
return status.Errorf(codes.InvalidArgument, "invalid service: %s", req.Service)
|
|
}
|
|
if !slices.Contains(sh.state.RemainingUnauthenticatedMethods, req.AuthMethod) {
|
|
return status.Errorf(codes.InvalidArgument, "unexpected auth method: %s", req.AuthMethod)
|
|
}
|
|
|
|
if sh.state.Username == nil {
|
|
if req.Username == "" {
|
|
return status.Errorf(codes.InvalidArgument, "username missing")
|
|
}
|
|
sh.state.Username = &req.Username
|
|
} else if *sh.state.Username != req.Username {
|
|
return status.Errorf(codes.InvalidArgument, "inconsistent username")
|
|
}
|
|
if sh.state.Hostname == nil {
|
|
sh.state.Hostname = &req.Hostname
|
|
} else if *sh.state.Hostname != req.Hostname {
|
|
return status.Errorf(codes.InvalidArgument, "inconsistent hostname")
|
|
}
|
|
|
|
updateMethods := func(add []string) {
|
|
sh.state.RemainingUnauthenticatedMethods = slices.Remove(sh.state.RemainingUnauthenticatedMethods, req.AuthMethod)
|
|
sh.state.RemainingUnauthenticatedMethods = append(sh.state.RemainingUnauthenticatedMethods, add...)
|
|
}
|
|
log.Ctx(ctx).Debug().
|
|
Str("method", req.AuthMethod).
|
|
Str("username", *sh.state.Username).
|
|
Str("hostname", *sh.state.Hostname).
|
|
Msg("ssh: handling auth request")
|
|
|
|
var partial bool
|
|
switch req.AuthMethod {
|
|
case MethodPublicKey:
|
|
methodReq, _ := req.MethodRequest.UnmarshalNew()
|
|
pubkeyReq, ok := methodReq.(*extensions_ssh.PublicKeyMethodRequest)
|
|
if !ok {
|
|
return status.Errorf(codes.InvalidArgument, "invalid public key method request type")
|
|
}
|
|
response, err := sh.auth.HandlePublicKeyMethodRequest(ctx, sh.state.StreamAuthInfo, pubkeyReq)
|
|
if err != nil {
|
|
return err
|
|
} else if response.Allow != nil {
|
|
partial = true
|
|
sh.state.PublicKeyFingerprintSha256 = pubkeyReq.PublicKeyFingerprintSha256
|
|
}
|
|
sh.state.PublicKeyAllow.Update(response.Allow)
|
|
updateMethods(response.RequireAdditionalMethods)
|
|
case MethodKeyboardInteractive:
|
|
methodReq, _ := req.MethodRequest.UnmarshalNew()
|
|
kbiReq, ok := methodReq.(*extensions_ssh.KeyboardInteractiveMethodRequest)
|
|
if !ok {
|
|
return status.Errorf(codes.InvalidArgument, "invalid keyboard-interactive method request type")
|
|
}
|
|
response, err := sh.auth.HandleKeyboardInteractiveMethodRequest(ctx, sh.state.StreamAuthInfo, kbiReq, sh)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
partial = response.Allow != nil
|
|
sh.state.KeyboardInteractiveAllow.Update(response.Allow)
|
|
updateMethods(response.RequireAdditionalMethods)
|
|
default:
|
|
return status.Errorf(codes.Internal, "bug: server requested an unsupported auth method %q", req.AuthMethod)
|
|
}
|
|
log.Ctx(ctx).Debug().
|
|
Str("method", req.AuthMethod).
|
|
Bool("partial", partial).
|
|
Strs("methods-remaining", sh.state.RemainingUnauthenticatedMethods).
|
|
Msg("ssh: auth request complete")
|
|
|
|
if len(sh.state.RemainingUnauthenticatedMethods) == 0 && sh.state.allMethodsValid() {
|
|
// if there are no methods remaining, the user is allowed if all attempted
|
|
// methods have a valid response in the state
|
|
sh.state.InitialAuthComplete = true
|
|
log.Ctx(ctx).Debug().Msg("ssh: all methods valid, sending allow response")
|
|
sh.sendAllowResponse()
|
|
} else {
|
|
log.Ctx(ctx).Debug().Msg("ssh: unauthenticated methods remain, sending deny response")
|
|
sh.sendDenyResponseWithRemainingMethods(partial)
|
|
}
|
|
return nil
|
|
}
|
|
|
|
func (sh *StreamHandler) reauth(ctx context.Context) error {
|
|
if !sh.state.InitialAuthComplete {
|
|
return nil
|
|
}
|
|
return sh.auth.EvaluateDelayed(ctx, sh.state.StreamAuthInfo)
|
|
}
|
|
|
|
func (sh *StreamHandler) PrepareHandoff(ctx context.Context, hostname string, ptyInfo *extensions_ssh.SSHDownstreamPTYInfo) (*extensions_ssh.SSHChannelControlAction, error) {
|
|
if hostname == "" {
|
|
return nil, status.Errorf(codes.PermissionDenied, "invalid hostname")
|
|
}
|
|
if sh.state.Hostname == nil {
|
|
panic("bug: PrepareHandoff called but state is missing a hostname")
|
|
}
|
|
if *sh.state.Hostname != "" {
|
|
panic("bug: PrepareHandoff called but previous hostname is not empty")
|
|
}
|
|
*sh.state.Hostname = hostname
|
|
err := sh.auth.EvaluateDelayed(ctx, sh.state.StreamAuthInfo)
|
|
if err != nil {
|
|
return nil, status.Error(codes.PermissionDenied, err.Error())
|
|
}
|
|
log.Ctx(ctx).Debug().
|
|
Str("hostname", *sh.state.Hostname).
|
|
Str("username", *sh.state.Username).
|
|
Msg("ssh: initiating handoff to upstream")
|
|
upstreamAllow := sh.buildUpstreamAllowResponse()
|
|
action := &extensions_ssh.SSHChannelControlAction{
|
|
Action: &extensions_ssh.SSHChannelControlAction_HandOff{
|
|
HandOff: &extensions_ssh.SSHChannelControlAction_HandOffUpstream{
|
|
DownstreamChannelInfo: sh.state.DownstreamChannelInfo,
|
|
DownstreamPtyInfo: ptyInfo,
|
|
UpstreamAuth: upstreamAllow,
|
|
},
|
|
},
|
|
}
|
|
return action, nil
|
|
}
|
|
|
|
func (sh *StreamHandler) FormatSession(ctx context.Context) ([]byte, error) {
|
|
return sh.auth.FormatSession(ctx, sh.state.StreamAuthInfo)
|
|
}
|
|
|
|
func (sh *StreamHandler) DeleteSession(ctx context.Context) error {
|
|
return sh.auth.DeleteSession(ctx, sh.state.StreamAuthInfo)
|
|
}
|
|
|
|
func (sh *StreamHandler) AllSSHRoutes() iter.Seq[*config.Policy] {
|
|
return func(yield func(*config.Policy) bool) {
|
|
for route := range sh.config.Options.GetAllPolicies() {
|
|
if route.IsSSH() {
|
|
if !yield(route) {
|
|
return
|
|
}
|
|
}
|
|
}
|
|
}
|
|
}
|
|
|
|
// DownstreamChannelID implements StreamHandlerInterface.
|
|
func (sh *StreamHandler) DownstreamChannelID() uint32 {
|
|
return sh.state.DownstreamChannelInfo.DownstreamChannelId
|
|
}
|
|
|
|
// Hostname implements StreamHandlerInterface.
|
|
func (sh *StreamHandler) Hostname() *string {
|
|
return sh.state.Hostname
|
|
}
|
|
|
|
// Username implements StreamHandlerInterface.
|
|
func (sh *StreamHandler) Username() *string {
|
|
return sh.state.Username
|
|
}
|
|
|
|
func (sh *StreamHandler) sendDenyResponseWithRemainingMethods(partial bool) {
|
|
sh.writeC <- &extensions_ssh.ServerMessage{
|
|
Message: &extensions_ssh.ServerMessage_AuthResponse{
|
|
AuthResponse: &extensions_ssh.AuthenticationResponse{
|
|
Response: &extensions_ssh.AuthenticationResponse_Deny{
|
|
Deny: &extensions_ssh.DenyResponse{
|
|
Partial: partial,
|
|
Methods: sh.state.RemainingUnauthenticatedMethods,
|
|
},
|
|
},
|
|
},
|
|
},
|
|
}
|
|
}
|
|
|
|
func (sh *StreamHandler) sendAllowResponse() {
|
|
var allow *extensions_ssh.AllowResponse
|
|
if *sh.state.Hostname == "" {
|
|
sh.expectingInternalChannel = true
|
|
allow = sh.buildInternalAllowResponse()
|
|
} else {
|
|
allow = sh.buildUpstreamAllowResponse()
|
|
}
|
|
|
|
sh.writeC <- &extensions_ssh.ServerMessage{
|
|
Message: &extensions_ssh.ServerMessage_AuthResponse{
|
|
AuthResponse: &extensions_ssh.AuthenticationResponse{
|
|
Response: &extensions_ssh.AuthenticationResponse_Allow{
|
|
Allow: allow,
|
|
},
|
|
},
|
|
},
|
|
}
|
|
}
|
|
|
|
func (sh *StreamHandler) sendInfoPrompts(prompts *extensions_ssh.KeyboardInteractiveInfoPrompts) {
|
|
sh.writeC <- &extensions_ssh.ServerMessage{
|
|
Message: &extensions_ssh.ServerMessage_AuthResponse{
|
|
AuthResponse: &extensions_ssh.AuthenticationResponse{
|
|
Response: &extensions_ssh.AuthenticationResponse_InfoRequest{
|
|
InfoRequest: &extensions_ssh.InfoRequest{
|
|
Method: MethodKeyboardInteractive,
|
|
Request: protoutil.NewAny(prompts),
|
|
},
|
|
},
|
|
},
|
|
},
|
|
}
|
|
}
|
|
|
|
func (sh *StreamHandler) buildUpstreamAllowResponse() *extensions_ssh.AllowResponse {
|
|
var allowedMethods []*extensions_ssh.AllowedMethod
|
|
if value := sh.state.PublicKeyAllow.Value; value != nil {
|
|
allowedMethods = append(allowedMethods, &extensions_ssh.AllowedMethod{
|
|
Method: MethodPublicKey,
|
|
MethodData: protoutil.NewAny(value),
|
|
})
|
|
}
|
|
if value := sh.state.KeyboardInteractiveAllow.Value; value != nil {
|
|
allowedMethods = append(allowedMethods, &extensions_ssh.AllowedMethod{
|
|
Method: MethodKeyboardInteractive,
|
|
MethodData: protoutil.NewAny(value),
|
|
})
|
|
}
|
|
return &extensions_ssh.AllowResponse{
|
|
Username: *sh.state.Username,
|
|
Target: &extensions_ssh.AllowResponse_Upstream{
|
|
Upstream: &extensions_ssh.UpstreamTarget{
|
|
Hostname: *sh.state.Hostname,
|
|
DirectTcpip: sh.state.DirectTcpip,
|
|
AllowedMethods: allowedMethods,
|
|
},
|
|
},
|
|
}
|
|
}
|
|
|
|
func (sh *StreamHandler) buildInternalAllowResponse() *extensions_ssh.AllowResponse {
|
|
return &extensions_ssh.AllowResponse{
|
|
Username: *sh.state.Username,
|
|
Target: &extensions_ssh.AllowResponse_Internal{
|
|
Internal: &extensions_ssh.InternalTarget{
|
|
SetMetadata: &corev3.Metadata{
|
|
TypedFilterMetadata: map[string]*anypb.Any{
|
|
"com.pomerium.ssh": protoutil.NewAny(&extensions_ssh.FilterMetadata{
|
|
StreamId: sh.downstream.StreamId,
|
|
}),
|
|
},
|
|
},
|
|
},
|
|
},
|
|
}
|
|
}
|