diff --git a/internal/zero/api/api.go b/internal/zero/api/api.go index 9f62e0936..0db0a8778 100644 --- a/internal/zero/api/api.go +++ b/internal/zero/api/api.go @@ -119,3 +119,16 @@ func (api *API) GetClusterResourceBundles(ctx context.Context) (*cluster_api.Get func (api *API) GetTelemetryConn() *grpc.ClientConn { return api.telemetryConn } + +func (api *API) ReportUsage(ctx context.Context, req cluster_api.ReportUsageRequest) error { + res, err := api.cluster.ReportUsageWithResponse(ctx, req) + if err != nil { + return err + } + + if res.StatusCode()/100 != 2 { + return fmt.Errorf("unexpected response from ReportUsage: %d", res.StatusCode()) + } + + return nil +} diff --git a/internal/zero/controller/controller.go b/internal/zero/controller/controller.go index 5fc4f1250..b1e495a01 100644 --- a/internal/zero/controller/controller.go +++ b/internal/zero/controller/controller.go @@ -159,6 +159,7 @@ func (c *controller) runZeroControlLoop(ctx context.Context) error { c.runSessionAnalyticsLeased, c.runPeriodicHealthChecksLeased, leaseStatus.MonitorLease, + c.runUsageReporter, ), ) }) @@ -196,6 +197,13 @@ func (c *controller) runPeriodicHealthChecksLeased(ctx context.Context, client d }) } +func (c *controller) runUsageReporter(ctx context.Context, client databroker.DataBrokerServiceClient) error { + r := newUsageReporter(c.api) + return retry.WithBackoff(ctx, "zero-usage-reporter", func(ctx context.Context) error { + return r.run(ctx, client) + }) +} + func (c *controller) getEnvoyScrapeURL() string { return (&url.URL{ Scheme: "http", diff --git a/internal/zero/controller/usage_reporter.go b/internal/zero/controller/usage_reporter.go new file mode 100644 index 000000000..ed7ee02a6 --- /dev/null +++ b/internal/zero/controller/usage_reporter.go @@ -0,0 +1,162 @@ +package controller + +import ( + "context" + "sync" + "time" + + "github.com/pomerium/pomerium/internal/log" + sdk "github.com/pomerium/pomerium/internal/zero/api" + "github.com/pomerium/pomerium/pkg/grpc/databroker" + "github.com/pomerium/pomerium/pkg/grpc/session" + "github.com/pomerium/pomerium/pkg/grpc/user" + "github.com/pomerium/pomerium/pkg/zero/cluster" +) + +type usageReporterRecord struct { + userID string + userDisplayName string + userEmail string + accessedAt time.Time +} + +type usageReporter struct { + api *sdk.API + + mu sync.Mutex + byUserID map[string]usageReporterRecord +} + +func newUsageReporter(api *sdk.API) *usageReporter { + return &usageReporter{ + api: api, + byUserID: make(map[string]usageReporterRecord), + } +} + +func (ur *usageReporter) report(ctx context.Context, records []usageReporterRecord) error { + req := cluster.ReportUsageRequest{} + for _, record := range records { + req.Users = append(req.Users, cluster.ReportUsageUser{ + AccessedAt: record.accessedAt, + DisplayName: record.userDisplayName, + Email: record.userEmail, + Id: record.userID, + }) + } + + log.Info(ctx).Int("users", len(req.Users)).Msg("reporting usage") + + // if there were no updates there's nothing to do + if len(req.Users) == 0 { + return nil + } + + return ur.api.ReportUsage(ctx, req) +} + +func (ur *usageReporter) run(ctx context.Context, client databroker.DataBrokerServiceClient) error { + timer := time.NewTicker(time.Hour) + defer timer.Stop() + + for { + err := ur.runOnce(ctx, client) + if err != nil { + log.Error(ctx).Err(err).Msg("failed to report usage") + } + + select { + case <-ctx.Done(): + return ctx.Err() + case <-timer.C: + } + } +} + +func (ur *usageReporter) runOnce(ctx context.Context, client databroker.DataBrokerServiceClient) error { + updated, err := ur.update(ctx, client) + if err != nil { + return err + } + + err = ur.report(ctx, updated) + if err != nil { + return err + } + + return nil +} + +func (ur *usageReporter) update(ctx context.Context, client databroker.DataBrokerServiceClient) ([]usageReporterRecord, error) { + updatedUserIDs := map[string]struct{}{} + + ur.mu.Lock() + defer ur.mu.Unlock() + + // delete old records + now := time.Now() + for userID, r := range ur.byUserID { + if r.accessedAt.Add(24 * time.Hour).Before(now) { + delete(ur.byUserID, userID) + } + } + + // create records for all the sessions + for s, err := range databroker.IterateAll[session.Session](ctx, client) { + if err != nil { + return nil, err + } + + userID := s.Object.GetUserId() + if userID == "" { + continue + } + + r := ur.byUserID[userID] + nr := r + nr.accessedAt = latest(nr.accessedAt, s.Object.GetIssuedAt().AsTime()) + nr.userID = userID + if r != nr { + updatedUserIDs[userID] = struct{}{} + ur.byUserID[userID] = nr + } + } + + // fill in user names and emails + for u, err := range databroker.IterateAll[user.User](ctx, client) { + if err != nil { + return nil, err + } + + userID := u.GetId() + if userID == "" { + continue + } + + r, ok := ur.byUserID[userID] + if !ok { + // ignore sessionless users + continue + } + nr := r + nr.userDisplayName = u.Object.GetName() + nr.userEmail = u.Object.GetEmail() + if r != nr { + updatedUserIDs[userID] = struct{}{} + ur.byUserID[userID] = nr + } + } + + var updated []usageReporterRecord + for key := range updatedUserIDs { + updated = append(updated, ur.byUserID[key]) + } + return updated, nil +} + +func latest(t1, t2 time.Time) time.Time { + if t2.After(t1) { + return t2 + } + return t1 +} diff --git a/pkg/grpc/databroker/databroker.go b/pkg/grpc/databroker/databroker.go index 26841ca7f..d651e868e 100644 --- a/pkg/grpc/databroker/databroker.go +++ b/pkg/grpc/databroker/databroker.go @@ -7,6 +7,7 @@ import ( "errors" "fmt" "io" + "iter" "google.golang.org/grpc/codes" status "google.golang.org/grpc/status" @@ -14,6 +15,7 @@ import ( "google.golang.org/protobuf/proto" structpb "google.golang.org/protobuf/types/known/structpb" + "github.com/pomerium/pomerium/internal/log" "github.com/pomerium/pomerium/pkg/grpcutil" "github.com/pomerium/pomerium/pkg/protoutil" ) @@ -139,6 +141,62 @@ loop: return records, recordVersion, serverVersion, nil } +// IterateAll iterates through all the records using a SyncLatest call. +func IterateAll[T any, TMessage interface { + *T + proto.Message +}]( + ctx context.Context, + client DataBrokerServiceClient, +) iter.Seq2[GenericRecord[TMessage], error] { + var zero GenericRecord[TMessage] + return func(yield func(GenericRecord[TMessage], error) bool) { + ctx, cancel := context.WithCancel(ctx) + defer cancel() + + var msg any = new(T) + stream, err := client.SyncLatest(ctx, &SyncLatestRequest{ + Type: protoutil.GetTypeURL(msg.(TMessage)), + }) + if err != nil { + _ = yield(zero, err) + return + } + + for { + res, err := stream.Recv() + switch { + case errors.Is(err, io.EOF): + // all done + return + case err != nil: + _ = yield(zero, err) + return + } + + switch res := res.GetResponse().(type) { + case *SyncLatestResponse_Versions: + // ignore + case *SyncLatestResponse_Record: + fmt.Println("RECORD", res.Record) + gr := GenericRecord[TMessage]{ + Record: res.Record, + } + var msg any = new(T) + gr.Object = msg.(TMessage) + err = res.Record.GetData().UnmarshalTo(gr.Object) + if err != nil { + log.Error(ctx).Err(err).Msg("databroker: unexpected object found in databroker record") + } else if !yield(gr, nil) { + return + } + default: + panic(fmt.Sprintf("unexpected response: %T", res)) + } + } + } +} + // GetRecord gets the first record, or nil if there are none. func (x *PutRequest) GetRecord() *Record { records := x.GetRecords() diff --git a/pkg/grpc/databroker/generic.go b/pkg/grpc/databroker/generic.go new file mode 100644 index 000000000..1d6068459 --- /dev/null +++ b/pkg/grpc/databroker/generic.go @@ -0,0 +1,10 @@ +package databroker + +import ( + "google.golang.org/protobuf/proto" +) + +type GenericRecord[T proto.Message] struct { + *Record + Object T +} diff --git a/pkg/zero/cluster/client.gen.go b/pkg/zero/cluster/client.gen.go index 0f179fd97..44c5afc62 100644 --- a/pkg/zero/cluster/client.gen.go +++ b/pkg/zero/cluster/client.gen.go @@ -107,6 +107,11 @@ type ClientInterface interface { ExchangeClusterIdentityTokenWithBody(ctx context.Context, contentType string, body io.Reader, reqEditors ...RequestEditorFn) (*http.Response, error) ExchangeClusterIdentityToken(ctx context.Context, body ExchangeClusterIdentityTokenJSONRequestBody, reqEditors ...RequestEditorFn) (*http.Response, error) + + // ReportUsageWithBody request with any body + ReportUsageWithBody(ctx context.Context, contentType string, body io.Reader, reqEditors ...RequestEditorFn) (*http.Response, error) + + ReportUsage(ctx context.Context, body ReportUsageJSONRequestBody, reqEditors ...RequestEditorFn) (*http.Response, error) } func (c *Client) GetClusterBootstrapConfig(ctx context.Context, reqEditors ...RequestEditorFn) (*http.Response, error) { @@ -193,6 +198,30 @@ func (c *Client) ExchangeClusterIdentityToken(ctx context.Context, body Exchange return c.Client.Do(req) } +func (c *Client) ReportUsageWithBody(ctx context.Context, contentType string, body io.Reader, reqEditors ...RequestEditorFn) (*http.Response, error) { + req, err := NewReportUsageRequestWithBody(c.Server, contentType, body) + if err != nil { + return nil, err + } + req = req.WithContext(ctx) + if err := c.applyEditors(ctx, req, reqEditors); err != nil { + return nil, err + } + return c.Client.Do(req) +} + +func (c *Client) ReportUsage(ctx context.Context, body ReportUsageJSONRequestBody, reqEditors ...RequestEditorFn) (*http.Response, error) { + req, err := NewReportUsageRequest(c.Server, body) + if err != nil { + return nil, err + } + req = req.WithContext(ctx) + if err := c.applyEditors(ctx, req, reqEditors); err != nil { + return nil, err + } + return c.Client.Do(req) +} + // NewGetClusterBootstrapConfigRequest generates requests for GetClusterBootstrapConfig func NewGetClusterBootstrapConfigRequest(server string) (*http.Request, error) { var err error @@ -368,6 +397,46 @@ func NewExchangeClusterIdentityTokenRequestWithBody(server string, contentType s return req, nil } +// NewReportUsageRequest calls the generic ReportUsage builder with application/json body +func NewReportUsageRequest(server string, body ReportUsageJSONRequestBody) (*http.Request, error) { + var bodyReader io.Reader + buf, err := json.Marshal(body) + if err != nil { + return nil, err + } + bodyReader = bytes.NewReader(buf) + return NewReportUsageRequestWithBody(server, "application/json", bodyReader) +} + +// NewReportUsageRequestWithBody generates requests for ReportUsage with any type of body +func NewReportUsageRequestWithBody(server string, contentType string, body io.Reader) (*http.Request, error) { + var err error + + serverURL, err := url.Parse(server) + if err != nil { + return nil, err + } + + operationPath := fmt.Sprintf("/reportUsage") + if operationPath[0] == '/' { + operationPath = "." + operationPath + } + + queryURL, err := serverURL.Parse(operationPath) + if err != nil { + return nil, err + } + + req, err := http.NewRequest("POST", queryURL.String(), body) + if err != nil { + return nil, err + } + + req.Header.Add("Content-Type", contentType) + + return req, nil +} + func (c *Client) applyEditors(ctx context.Context, req *http.Request, additionalEditors []RequestEditorFn) error { for _, r := range c.RequestEditors { if err := r(ctx, req); err != nil { @@ -429,6 +498,11 @@ type ClientWithResponsesInterface interface { ExchangeClusterIdentityTokenWithBodyWithResponse(ctx context.Context, contentType string, body io.Reader, reqEditors ...RequestEditorFn) (*ExchangeClusterIdentityTokenResp, error) ExchangeClusterIdentityTokenWithResponse(ctx context.Context, body ExchangeClusterIdentityTokenJSONRequestBody, reqEditors ...RequestEditorFn) (*ExchangeClusterIdentityTokenResp, error) + + // ReportUsageWithBodyWithResponse request with any body + ReportUsageWithBodyWithResponse(ctx context.Context, contentType string, body io.Reader, reqEditors ...RequestEditorFn) (*ReportUsageResp, error) + + ReportUsageWithResponse(ctx context.Context, body ReportUsageJSONRequestBody, reqEditors ...RequestEditorFn) (*ReportUsageResp, error) } type GetClusterBootstrapConfigResp struct { @@ -551,6 +625,27 @@ func (r ExchangeClusterIdentityTokenResp) StatusCode() int { return 0 } +type ReportUsageResp struct { + Body []byte + HTTPResponse *http.Response +} + +// Status returns HTTPResponse.Status +func (r ReportUsageResp) Status() string { + if r.HTTPResponse != nil { + return r.HTTPResponse.Status + } + return http.StatusText(0) +} + +// StatusCode returns HTTPResponse.StatusCode +func (r ReportUsageResp) StatusCode() int { + if r.HTTPResponse != nil { + return r.HTTPResponse.StatusCode + } + return 0 +} + // GetClusterBootstrapConfigWithResponse request returning *GetClusterBootstrapConfigResp func (c *ClientWithResponses) GetClusterBootstrapConfigWithResponse(ctx context.Context, reqEditors ...RequestEditorFn) (*GetClusterBootstrapConfigResp, error) { rsp, err := c.GetClusterBootstrapConfig(ctx, reqEditors...) @@ -612,6 +707,23 @@ func (c *ClientWithResponses) ExchangeClusterIdentityTokenWithResponse(ctx conte return ParseExchangeClusterIdentityTokenResp(rsp) } +// ReportUsageWithBodyWithResponse request with arbitrary body returning *ReportUsageResp +func (c *ClientWithResponses) ReportUsageWithBodyWithResponse(ctx context.Context, contentType string, body io.Reader, reqEditors ...RequestEditorFn) (*ReportUsageResp, error) { + rsp, err := c.ReportUsageWithBody(ctx, contentType, body, reqEditors...) + if err != nil { + return nil, err + } + return ParseReportUsageResp(rsp) +} + +func (c *ClientWithResponses) ReportUsageWithResponse(ctx context.Context, body ReportUsageJSONRequestBody, reqEditors ...RequestEditorFn) (*ReportUsageResp, error) { + rsp, err := c.ReportUsage(ctx, body, reqEditors...) + if err != nil { + return nil, err + } + return ParseReportUsageResp(rsp) +} + // ParseGetClusterBootstrapConfigResp parses an HTTP response from a GetClusterBootstrapConfigWithResponse call func ParseGetClusterBootstrapConfigResp(rsp *http.Response) (*GetClusterBootstrapConfigResp, error) { bodyBytes, err := io.ReadAll(rsp.Body) @@ -811,3 +923,19 @@ func ParseExchangeClusterIdentityTokenResp(rsp *http.Response) (*ExchangeCluster return response, nil } + +// ParseReportUsageResp parses an HTTP response from a ReportUsageWithResponse call +func ParseReportUsageResp(rsp *http.Response) (*ReportUsageResp, error) { + bodyBytes, err := io.ReadAll(rsp.Body) + defer func() { _ = rsp.Body.Close() }() + if err != nil { + return nil, err + } + + response := &ReportUsageResp{ + Body: bodyBytes, + HTTPResponse: rsp, + } + + return response, nil +} diff --git a/pkg/zero/cluster/models.gen.go b/pkg/zero/cluster/models.gen.go index b3716332c..fb8413431 100644 --- a/pkg/zero/cluster/models.gen.go +++ b/pkg/zero/cluster/models.gen.go @@ -3,6 +3,10 @@ // Code generated by github.com/oapi-codegen/oapi-codegen/v2 version v2.3.0 DO NOT EDIT. package cluster +import ( + "time" +) + const ( BearerAuthScopes = "bearerAuth.Scopes" ) @@ -92,6 +96,19 @@ type GetBundlesResponse struct { Bundles []Bundle `json:"bundles"` } +// ReportUsageRequest defines model for ReportUsageRequest. +type ReportUsageRequest struct { + Users []ReportUsageUser `json:"users"` +} + +// ReportUsageUser defines model for ReportUsageUser. +type ReportUsageUser struct { + AccessedAt time.Time `json:"accessedAt"` + DisplayName string `json:"displayName"` + Email string `json:"email"` + Id string `json:"id"` +} + // BundleId defines model for bundleId. type BundleId = string @@ -100,3 +117,6 @@ type ReportClusterResourceBundleStatusJSONRequestBody = BundleStatus // ExchangeClusterIdentityTokenJSONRequestBody defines body for ExchangeClusterIdentityToken for application/json ContentType. type ExchangeClusterIdentityTokenJSONRequestBody = ExchangeTokenRequest + +// ReportUsageJSONRequestBody defines body for ReportUsage for application/json ContentType. +type ReportUsageJSONRequestBody = ReportUsageRequest diff --git a/pkg/zero/cluster/openapi.yaml b/pkg/zero/cluster/openapi.yaml index e9c43bf71..e8cdf719c 100644 --- a/pkg/zero/cluster/openapi.yaml +++ b/pkg/zero/cluster/openapi.yaml @@ -148,6 +148,20 @@ paths: application/json: schema: $ref: "#/components/schemas/ErrorResponse" + /reportUsage: + post: + description: Report usage for the cluster + operationId: reportUsage + tags: [usage] + requestBody: + required: true + content: + application/json: + schema: + $ref: "#/components/schemas/ReportUsageRequest" + responses: + "204": + description: OK components: parameters: @@ -273,3 +287,29 @@ components: $ref: "#/components/schemas/Bundle" required: - bundles + ReportUsageRequest: + type: object + properties: + users: + type: array + items: + $ref: "#/components/schemas/ReportUsageUser" + required: + - users + ReportUsageUser: + type: object + properties: + accessedAt: + type: string + format: "date-time" + displayName: + type: string + email: + type: string + id: + type: string + required: + - accessedAt + - displayName + - email + - id diff --git a/pkg/zero/cluster/server.gen.go b/pkg/zero/cluster/server.gen.go index d12303d69..d8cd60eec 100644 --- a/pkg/zero/cluster/server.gen.go +++ b/pkg/zero/cluster/server.gen.go @@ -31,6 +31,9 @@ type ServerInterface interface { // (POST /exchangeToken) ExchangeClusterIdentityToken(w http.ResponseWriter, r *http.Request) + + // (POST /reportUsage) + ReportUsage(w http.ResponseWriter, r *http.Request) } // Unimplemented server implementation that returns http.StatusNotImplemented for each endpoint. @@ -62,6 +65,11 @@ func (_ Unimplemented) ExchangeClusterIdentityToken(w http.ResponseWriter, r *ht w.WriteHeader(http.StatusNotImplemented) } +// (POST /reportUsage) +func (_ Unimplemented) ReportUsage(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(http.StatusNotImplemented) +} + // ServerInterfaceWrapper converts contexts to parameters. type ServerInterfaceWrapper struct { Handler ServerInterface @@ -176,6 +184,23 @@ func (siw *ServerInterfaceWrapper) ExchangeClusterIdentityToken(w http.ResponseW handler.ServeHTTP(w, r.WithContext(ctx)) } +// ReportUsage operation middleware +func (siw *ServerInterfaceWrapper) ReportUsage(w http.ResponseWriter, r *http.Request) { + ctx := r.Context() + + ctx = context.WithValue(ctx, BearerAuthScopes, []string{}) + + handler := http.Handler(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + siw.Handler.ReportUsage(w, r) + })) + + for i := len(siw.HandlerMiddlewares) - 1; i >= 0; i-- { + handler = siw.HandlerMiddlewares[i](handler) + } + + handler.ServeHTTP(w, r.WithContext(ctx)) +} + type UnescapedCookieParamError struct { ParamName string Err error @@ -304,6 +329,9 @@ func HandlerWithOptions(si ServerInterface, options ChiServerOptions) http.Handl r.Group(func(r chi.Router) { r.Post(options.BaseURL+"/exchangeToken", wrapper.ExchangeClusterIdentityToken) }) + r.Group(func(r chi.Router) { + r.Post(options.BaseURL+"/reportUsage", wrapper.ReportUsage) + }) return r } @@ -490,6 +518,22 @@ func (response ExchangeClusterIdentityToken500JSONResponse) VisitExchangeCluster return json.NewEncoder(w).Encode(response) } +type ReportUsageRequestObject struct { + Body *ReportUsageJSONRequestBody +} + +type ReportUsageResponseObject interface { + VisitReportUsageResponse(w http.ResponseWriter) error +} + +type ReportUsage204Response struct { +} + +func (response ReportUsage204Response) VisitReportUsageResponse(w http.ResponseWriter) error { + w.WriteHeader(204) + return nil +} + // StrictServerInterface represents all server handlers. type StrictServerInterface interface { @@ -507,6 +551,9 @@ type StrictServerInterface interface { // (POST /exchangeToken) ExchangeClusterIdentityToken(ctx context.Context, request ExchangeClusterIdentityTokenRequestObject) (ExchangeClusterIdentityTokenResponseObject, error) + + // (POST /reportUsage) + ReportUsage(ctx context.Context, request ReportUsageRequestObject) (ReportUsageResponseObject, error) } type StrictHandlerFunc = strictnethttp.StrictHTTPHandlerFunc @@ -675,3 +722,34 @@ func (sh *strictHandler) ExchangeClusterIdentityToken(w http.ResponseWriter, r * sh.options.ResponseErrorHandlerFunc(w, r, fmt.Errorf("unexpected response type: %T", response)) } } + +// ReportUsage operation middleware +func (sh *strictHandler) ReportUsage(w http.ResponseWriter, r *http.Request) { + var request ReportUsageRequestObject + + var body ReportUsageJSONRequestBody + if err := json.NewDecoder(r.Body).Decode(&body); err != nil { + sh.options.RequestErrorHandlerFunc(w, r, fmt.Errorf("can't decode JSON body: %w", err)) + return + } + request.Body = &body + + handler := func(ctx context.Context, w http.ResponseWriter, r *http.Request, request interface{}) (interface{}, error) { + return sh.ssi.ReportUsage(ctx, request.(ReportUsageRequestObject)) + } + for _, middleware := range sh.middlewares { + handler = middleware(handler, "ReportUsage") + } + + response, err := handler(r.Context(), w, r, request) + + if err != nil { + sh.options.ResponseErrorHandlerFunc(w, r, err) + } else if validResponse, ok := response.(ReportUsageResponseObject); ok { + if err := validResponse.VisitReportUsageResponse(w); err != nil { + sh.options.ResponseErrorHandlerFunc(w, r, err) + } + } else if response != nil { + sh.options.ResponseErrorHandlerFunc(w, r, fmt.Errorf("unexpected response type: %T", response)) + } +}