pomerium/authorize/ssh_grpc.go
Kenneth Jenkins 9678e6a231
ssh: implement authorization policy evaluation (#5665)
Implement the pkg/ssh.AuthInterface. Add logic for converting from the
ssh stream state to an evaluator request, and for interpreting the
results of policy evaluation. Refactor some of the existing authorize
logic to make it easier to reuse.
2025-07-01 12:04:00 -07:00

133 lines
3.2 KiB
Go

package authorize
import (
"context"
"errors"
"io"
"golang.org/x/sync/errgroup"
"google.golang.org/grpc/codes"
"google.golang.org/grpc/status"
extensions_ssh "github.com/pomerium/envoy-custom/api/extensions/filters/network/ssh"
"github.com/pomerium/pomerium/authorize/evaluator"
"github.com/pomerium/pomerium/internal/log"
"github.com/pomerium/pomerium/pkg/grpc/user"
"github.com/pomerium/pomerium/pkg/ssh"
)
// ManageStream handles the ManageStream RPC for a single SSH connection.
//
// The first message received on the stream must be a downstream connected
// event; otherwise an Internal status error is returned. A stream handler is
// then created for the connection and two goroutines pump messages between
// the gRPC stream and the handler's read/write channels while handler.Run
// drives the connection. The handler is closed when this method returns.
//
// The returned error is whatever handler.Run returns. eg.Wait is not called,
// so errors from the pump goroutines surface only indirectly: a non-nil error
// from either goroutine cancels ctx, which is the context passed to
// handler.Run.
func (a *Authorize) ManageStream(stream extensions_ssh.StreamManagement_ManageStreamServer) error {
	event, err := stream.Recv()
	if err != nil {
		return err
	}

	// first message should be a downstream connected event
	downstream := event.GetEvent().GetDownstreamConnected()
	if downstream == nil {
		return status.Errorf(codes.Internal, "first message was not a downstream connected event")
	}

	state := a.state.Load()
	handler := state.ssh.NewStreamHandler(
		a.currentConfig.Load(),
		ssh.NewAuth(a, state.dataBrokerClient, a.currentConfig, a.tracerProvider),
		downstream,
	)
	defer handler.Close()

	eg, ctx := errgroup.WithContext(stream.Context())

	// stream -> handler: forward every received message to the handler.
	eg.Go(func() error {
		for {
			req, err := stream.Recv()
			if err != nil {
				if errors.Is(err, io.EOF) {
					return nil
				}
				return err
			}
			// Do not block forever on the send: once ctx is done the handler
			// may no longer be draining its read channel, and an unconditional
			// send would leak this goroutine.
			select {
			case <-ctx.Done():
				return nil
			case handler.ReadC() <- req:
			}
		}
	})

	// handler -> stream: forward every message the handler produces.
	eg.Go(func() error {
		for {
			select {
			case <-ctx.Done():
				return nil
			case msg := <-handler.WriteC():
				if err := stream.Send(msg); err != nil {
					if errors.Is(err, io.EOF) {
						return nil
					}
					return err
				}
			}
		}
	})

	// NOTE(review): eg.Wait() is not called here. The receive goroutine is
	// normally parked in stream.Recv(), which presumably only unblocks once
	// the RPC ends, so waiting on it before returning could hang — confirm
	// before adding a Wait.
	return handler.Run(ctx)
}
// ServeChannel handles the ServeChannel RPC.
//
// The first message on the stream must carry metadata whose typed filter
// metadata entry "com.pomerium.ssh" decodes to a FilterMetadata message. The
// stream ID from that message is used to look up the stream handler, and the
// rest of the RPC is delegated to that handler's ServeChannel method.
//
// An Internal status error is returned when the first message has no
// metadata, and an InvalidArgument status error when no handler is registered
// for the stream ID or the handler is not expecting an internal channel.
// Receive and unmarshal errors are returned unchanged.
func (a *Authorize) ServeChannel(stream extensions_ssh.StreamManagement_ServeChannelServer) error {
	first, err := stream.Recv()
	if err != nil {
		return err
	}

	// first message contains metadata
	md := first.GetMetadata()
	if md == nil {
		return status.Errorf(codes.Internal, "first message was not metadata")
	}

	var filterMd extensions_ssh.FilterMetadata
	sshMd := md.GetTypedFilterMetadata()["com.pomerium.ssh"]
	if err := sshMd.UnmarshalTo(&filterMd); err != nil {
		return err
	}

	handler := a.state.Load().ssh.LookupStream(filterMd.StreamId)
	if handler == nil || !handler.IsExpectingInternalChannel() {
		return status.Errorf(codes.InvalidArgument, "stream not found")
	}
	return handler.ServeChannel(stream)
}
// EvaluateSSH builds an evaluator request from the given SSH request, runs
// policy evaluation on it, logs the authorize check, and returns the
// evaluation result.
//
// A request with an empty hostname is marked as internal; otherwise the
// policy is taken from the route that the current config returns for the SSH
// hostname. Evaluation errors are logged and returned to the caller.
func (a *Authorize) EvaluateSSH(ctx context.Context, req *ssh.Request) (*evaluator.Result, error) {
	ctx = a.withQuerierForCheckRequest(ctx)

	request := evaluator.Request{
		// Requests without a hostname are treated as internal.
		IsInternal: req.Hostname == "",
		HTTP: evaluator.RequestHTTP{
			Hostname: req.Hostname,
		},
		SSH: evaluator.RequestSSH{
			Username:  req.Username,
			PublicKey: req.PublicKey,
		},
		Session: evaluator.RequestSession{
			ID: req.SessionID,
		},
	}
	if !request.IsInternal {
		request.Policy = a.currentConfig.Load().Options.GetRouteForSSHHostname(req.Hostname)
	}

	result, err := a.state.Load().evaluator.Evaluate(ctx, &request)
	if err != nil {
		log.Ctx(ctx).Error().Err(err).Msg("error during OPA evaluation")
		return nil, err
	}

	// The session and user are looked up only to feed the authorize log
	// below, so lookup errors are ignored here.
	s, _ := a.getDataBrokerSessionOrServiceAccount(ctx, req.SessionID, 0)
	var u *user.User
	if s != nil {
		u, _ = a.getDataBrokerUser(ctx, s.GetUserId())
	}
	a.logAuthorizeCheck(ctx, &request, result, s, u)

	return result, nil
}