diff --git a/authenticate/handlers/webauthn/helpers.go b/authenticate/handlers/webauthn/helpers.go deleted file mode 100644 index 319e8952f..000000000 --- a/authenticate/handlers/webauthn/helpers.go +++ /dev/null @@ -1,32 +0,0 @@ -package webauthn - -import "github.com/pomerium/pomerium/pkg/grpc/session" - -func containsString(elements []string, value string) bool { - for _, element := range elements { - if element == value { - return true - } - } - return false -} - -func removeString(elements []string, value string) []string { - dup := make([]string, 0, len(elements)) - for _, element := range elements { - if element != value { - dup = append(dup, element) - } - } - return dup -} - -func removeSessionDeviceCredential(elements []*session.Session_DeviceCredential, id string) []*session.Session_DeviceCredential { - dup := make([]*session.Session_DeviceCredential, 0, len(elements)) - for _, element := range elements { - if element.GetId() != id { - dup = append(dup, element) - } - } - return dup -} diff --git a/authenticate/handlers/webauthn/webauthn.go b/authenticate/handlers/webauthn/webauthn.go index d29b69ddc..84e792cf2 100644 --- a/authenticate/handlers/webauthn/webauthn.go +++ b/authenticate/handlers/webauthn/webauthn.go @@ -324,7 +324,7 @@ func (h *Handler) handleRegister(w http.ResponseWriter, r *http.Request, state * } // save the user - u.DeviceCredentialIds = append(u.DeviceCredentialIds, deviceCredential.GetId()) + u.AddDeviceCredentialID(deviceCredential.GetId()) _, err = databroker.Put(ctx, state.Client, u) if err != nil { return err @@ -356,7 +356,7 @@ func (h *Handler) handleUnregister(w http.ResponseWriter, r *http.Request, state } // ensure we only allow removing a device credential the user owns - if !containsString(u.GetDeviceCredentialIds(), deviceCredentialID) { + if !u.HasDeviceCredentialID(deviceCredentialID) { return errInvalidDeviceCredential } @@ -373,14 +373,14 @@ func (h *Handler) handleUnregister(w http.ResponseWriter, r *http.Request, state } // remove the credential from the user - u.DeviceCredentialIds = removeString(u.DeviceCredentialIds, deviceCredentialID) + u.RemoveDeviceCredentialID(deviceCredentialID) _, err = databroker.Put(ctx, state.Client, u) if err != nil { return err } // remove the credential from the session - state.Session.DeviceCredentials = removeSessionDeviceCredential(state.Session.DeviceCredentials, deviceCredentialID) + state.Session.RemoveDeviceCredentialID(deviceCredentialID) return h.saveSessionAndRedirect(w, r, state, urlutil.GetAbsoluteURL(r).ResolveReference(&url.URL{ Path: "/.pomerium", }).String()) diff --git a/pkg/grpc/session/session.go b/pkg/grpc/session/session.go index 009278d22..8dcb68e1c 100644 --- a/pkg/grpc/session/session.go +++ b/pkg/grpc/session/session.go @@ -12,6 +12,7 @@ import ( "github.com/pomerium/pomerium/internal/identity" "github.com/pomerium/pomerium/pkg/grpc/databroker" "github.com/pomerium/pomerium/pkg/protoutil" + "github.com/pomerium/pomerium/pkg/slices" ) // Delete deletes a session from the databroker. @@ -78,3 +79,10 @@ func (x *Session) SetRawIDToken(rawIDToken string) { } x.IdToken.Raw = rawIDToken } + +// RemoveDeviceCredentialID removes a device credential id. +func (x *Session) RemoveDeviceCredentialID(deviceCredentialID string) { + x.DeviceCredentials = slices.Filter(x.DeviceCredentials, func(el *Session_DeviceCredential) bool { + return el.GetId() != deviceCredentialID + }) +} diff --git a/pkg/grpc/user/user.go b/pkg/grpc/user/user.go index e10969175..b3fb600fe 100644 --- a/pkg/grpc/user/user.go +++ b/pkg/grpc/user/user.go @@ -8,6 +8,7 @@ import ( "github.com/pomerium/pomerium/internal/identity" "github.com/pomerium/pomerium/pkg/grpc/databroker" + "github.com/pomerium/pomerium/pkg/slices" ) // Get gets a user from the databroker. @@ -47,3 +48,18 @@ func (x *User) GetClaim(claim string) []interface{} { } return vs } + +// AddDeviceCredentialID adds a device credential id to the list of device credential ids. +func (x *User) AddDeviceCredentialID(deviceCredentialID string) { + x.DeviceCredentialIds = slices.Unique(append(x.DeviceCredentialIds, deviceCredentialID)) +} + +// HasDeviceCredentialID returns true if the user has the device credential id. +func (x *User) HasDeviceCredentialID(deviceCredentialID string) bool { + return slices.Contains(x.DeviceCredentialIds, deviceCredentialID) +} + +// RemoveDeviceCredentialID removes the device credential id from the list of device credential ids. +func (x *User) RemoveDeviceCredentialID(deviceCredentialID string) { + x.DeviceCredentialIds = slices.Remove(x.DeviceCredentialIds, deviceCredentialID) +} diff --git a/pkg/slices/slices.go b/pkg/slices/slices.go new file mode 100644 index 000000000..3a3317505 --- /dev/null +++ b/pkg/slices/slices.go @@ -0,0 +1,47 @@ +// Package slices contains functions for working with slices. +package slices + +// Contains returns true if e is in s. +func Contains[S ~[]E, E comparable](s S, e E) bool { + for _, el := range s { + if el == e { + return true + } + } + return false +} + +// Filter returns a new slice containing only those elements for which f(element) is true. +func Filter[S ~[]E, E any](s S, f func(E) bool) S { + var ns S + for _, el := range s { + if f(el) { + ns = append(ns, el) + } + } + return ns +} + +// Remove removes e from s. +func Remove[S ~[]E, E comparable](s S, e E) S { + var ns S + for _, el := range s { + if el != e { + ns = append(ns, el) + } + } + return ns +} + +// Unique returns the unique elements of s. +func Unique[S ~[]E, E comparable](s S) S { + var ns S + h := map[E]struct{}{} + for _, el := range s { + if _, ok := h[el]; !ok { + h[el] = struct{}{} + ns = append(ns, el) + } + } + return ns +}