pomerium/authorize/ssh_grpc.go
Joe Kralicky b216b7a135
ssh: stream management api (#5670)
## Summary

This implements the StreamManagement API defined at
https://github.com/pomerium/envoy-custom/blob/main/api/extensions/filters/network/ssh/ssh.proto#L46-L60.
Policy evaluation and authorization logic is stubbed out here and
implemented in https://github.com/pomerium/pomerium/pull/5665.

## Related issues

<!-- For example...
- #159
-->

## User Explanation

<!-- How would you explain this change to the user? If this
change doesn't create any user-facing changes, you can leave
this blank. If filled out, add the `docs` label -->

## Checklist

- [ ] reference any related issues
- [ ] updated unit tests
- [ ] add appropriate label (`enhancement`, `bug`, `breaking`,
`dependencies`, `ci`)
- [ ] ready for review
2025-07-01 13:57:19 -04:00

89 lines
2.2 KiB
Go

package authorize
import (
"errors"
"io"
"golang.org/x/sync/errgroup"
"google.golang.org/grpc/codes"
"google.golang.org/grpc/status"
extensions_ssh "github.com/pomerium/envoy-custom/api/extensions/filters/network/ssh"
"github.com/pomerium/pomerium/pkg/storage"
)
// ManageStream implements the StreamManagement ManageStream RPC.
//
// The first message received on the stream must be a downstream-connected
// event; it is used to create the SSH stream handler for this connection.
// Afterwards two goroutines pump messages in both directions: messages
// received from the stream are forwarded to the handler's read channel, and
// messages taken from the handler's write channel are sent on the stream.
// The handler runs with a context derived from the stream's context that also
// carries a caching databroker querier, and its result is returned.
//
// NOTE(review): the errgroup is never waited on (presumably because a blocked
// stream.Recv cannot be interrupted before the RPC returns), so errors from
// the pump goroutines only cancel the handler's context; they are not
// returned themselves.
func (a *Authorize) ManageStream(stream extensions_ssh.StreamManagement_ManageStreamServer) error {
	event, err := stream.Recv()
	if err != nil {
		return err
	}
	// first message should be a downstream connected event
	downstream := event.GetEvent().GetDownstreamConnected()
	if downstream == nil {
		return status.Errorf(codes.Internal, "first message was not a downstream connected event")
	}

	// Load the state once so the handler and the querier are built from the
	// same snapshot rather than from two separate loads.
	state := a.state.Load()
	handler := state.ssh.NewStreamHandler(a.currentConfig.Load(), downstream)
	defer handler.Close()

	eg, ctx := errgroup.WithContext(stream.Context())
	querier := storage.NewCachingQuerier(
		storage.NewQuerier(state.dataBrokerClient),
		storage.GlobalCache,
	)
	ctx = storage.WithQuerier(ctx, querier)

	// stream -> handler
	eg.Go(func() error {
		for {
			req, err := stream.Recv()
			if err != nil {
				if errors.Is(err, io.EOF) {
					return nil
				}
				return err
			}
			// A bare send would block this goroutine forever once the handler
			// stops reading from ReadC; also watch ctx so it can exit.
			select {
			case handler.ReadC() <- req:
			case <-ctx.Done():
				return nil
			}
		}
	})

	// handler -> stream
	eg.Go(func() error {
		for {
			select {
			case <-ctx.Done():
				return nil
			case msg := <-handler.WriteC():
				if err := stream.Send(msg); err != nil {
					if errors.Is(err, io.EOF) {
						return nil
					}
					return err
				}
			}
		}
	})

	return handler.Run(ctx)
}
// ServeChannel implements the StreamManagement ServeChannel RPC.
//
// The first message received on the stream must carry metadata whose
// "com.pomerium.ssh" typed filter metadata holds an
// extensions_ssh.FilterMetadata. Its StreamId is looked up in the ssh state;
// if a matching handler exists and is expecting an internal channel, serving
// the channel is delegated to that handler. Otherwise an InvalidArgument
// status is returned.
func (a *Authorize) ServeChannel(stream extensions_ssh.StreamManagement_ServeChannelServer) error {
	first, err := stream.Recv()
	if err != nil {
		return err
	}

	// first message contains metadata
	md := first.GetMetadata()
	if md == nil {
		return status.Errorf(codes.Internal, "first message was not metadata")
	}

	// Report a missing entry explicitly instead of relying on UnmarshalTo's
	// handling of a nil source message, and return gRPC status errors so the
	// caller does not just see codes.Unknown.
	anyMd, ok := md.GetTypedFilterMetadata()["com.pomerium.ssh"]
	if !ok || anyMd == nil {
		return status.Errorf(codes.Internal, "first message is missing com.pomerium.ssh filter metadata")
	}
	var typedMd extensions_ssh.FilterMetadata
	if err := anyMd.UnmarshalTo(&typedMd); err != nil {
		return status.Errorf(codes.Internal, "invalid com.pomerium.ssh filter metadata: %v", err)
	}

	handler := a.state.Load().ssh.LookupStream(typedMd.StreamId)
	if handler == nil || !handler.IsExpectingInternalChannel() {
		return status.Errorf(codes.InvalidArgument, "stream not found")
	}
	return handler.ServeChannel(stream)
}