device: add generic methods for working with user+session devices (#3710)

This commit is contained in:
Caleb Doxsey 2022-10-28 08:41:12 -06:00 committed by GitHub
parent 6a9d6e45e1
commit 3f9dfbef76
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
5 changed files with 75 additions and 36 deletions

View file

@ -1,32 +0,0 @@
package webauthn
import "github.com/pomerium/pomerium/pkg/grpc/session"
func containsString(elements []string, value string) bool {
for _, element := range elements {
if element == value {
return true
}
}
return false
}
func removeString(elements []string, value string) []string {
dup := make([]string, 0, len(elements))
for _, element := range elements {
if element != value {
dup = append(dup, element)
}
}
return dup
}
func removeSessionDeviceCredential(elements []*session.Session_DeviceCredential, id string) []*session.Session_DeviceCredential {
dup := make([]*session.Session_DeviceCredential, 0, len(elements))
for _, element := range elements {
if element.GetId() != id {
dup = append(dup, element)
}
}
return dup
}

View file

@ -324,7 +324,7 @@ func (h *Handler) handleRegister(w http.ResponseWriter, r *http.Request, state *
}
// save the user
u.DeviceCredentialIds = append(u.DeviceCredentialIds, deviceCredential.GetId())
u.AddDeviceCredentialID(deviceCredential.GetId())
_, err = databroker.Put(ctx, state.Client, u)
if err != nil {
return err
@ -356,7 +356,7 @@ func (h *Handler) handleUnregister(w http.ResponseWriter, r *http.Request, state
}
// ensure we only allow removing a device credential the user owns
if !containsString(u.GetDeviceCredentialIds(), deviceCredentialID) {
if !u.HasDeviceCredentialID(deviceCredentialID) {
return errInvalidDeviceCredential
}
@ -373,14 +373,14 @@ func (h *Handler) handleUnregister(w http.ResponseWriter, r *http.Request, state
}
// remove the credential from the user
u.DeviceCredentialIds = removeString(u.DeviceCredentialIds, deviceCredentialID)
u.RemoveDeviceCredentialID(deviceCredentialID)
_, err = databroker.Put(ctx, state.Client, u)
if err != nil {
return err
}
// remove the credential from the session
state.Session.DeviceCredentials = removeSessionDeviceCredential(state.Session.DeviceCredentials, deviceCredentialID)
state.Session.RemoveDeviceCredentialID(deviceCredentialID)
return h.saveSessionAndRedirect(w, r, state, urlutil.GetAbsoluteURL(r).ResolveReference(&url.URL{
Path: "/.pomerium",
}).String())

View file

@ -12,6 +12,7 @@ import (
"github.com/pomerium/pomerium/internal/identity"
"github.com/pomerium/pomerium/pkg/grpc/databroker"
"github.com/pomerium/pomerium/pkg/protoutil"
"github.com/pomerium/pomerium/pkg/slices"
)
// Delete deletes a session from the databroker.
@ -78,3 +79,10 @@ func (x *Session) SetRawIDToken(rawIDToken string) {
}
x.IdToken.Raw = rawIDToken
}
// RemoveDeviceCredentialID removes a device credential id.
func (x *Session) RemoveDeviceCredentialID(deviceCredentialID string) {
x.DeviceCredentials = slices.Filter(x.DeviceCredentials, func(el *Session_DeviceCredential) bool {
return el.GetId() != deviceCredentialID
})
}

View file

@ -8,6 +8,7 @@ import (
"github.com/pomerium/pomerium/internal/identity"
"github.com/pomerium/pomerium/pkg/grpc/databroker"
"github.com/pomerium/pomerium/pkg/slices"
)
// Get gets a user from the databroker.
@ -47,3 +48,18 @@ func (x *User) GetClaim(claim string) []interface{} {
}
return vs
}
// AddDeviceCredentialID adds a device credential id to the list of device credential ids.
func (x *User) AddDeviceCredentialID(deviceCredentialID string) {
x.DeviceCredentialIds = slices.Unique(append(x.DeviceCredentialIds, deviceCredentialID))
}
// HasDeviceCredentialID returns true if the user has the device credential id.
func (x *User) HasDeviceCredentialID(deviceCredentialID string) bool {
return slices.Contains(x.DeviceCredentialIds, deviceCredentialID)
}
// RemoveDeviceCredentialID removes the device credential id from the list of device credential ids.
func (x *User) RemoveDeviceCredentialID(deviceCredentialID string) {
x.DeviceCredentialIds = slices.Remove(x.DeviceCredentialIds, deviceCredentialID)
}

47
pkg/slices/slices.go Normal file
View file

@ -0,0 +1,47 @@
// Package slices contains functions for working with slices.
package slices
// Contains returns true if e is in s.
func Contains[S ~[]E, E comparable](s S, e E) bool {
for _, el := range s {
if el == e {
return true
}
}
return false
}
// Filter returns a new slice containing only those elements for which f(element) is true.
func Filter[S ~[]E, E any](s S, f func(E) bool) S {
var ns S
for _, el := range s {
if f(el) {
ns = append(ns, el)
}
}
return ns
}
// Remove removes e from s.
func Remove[S ~[]E, E comparable](s S, e E) S {
var ns S
for _, el := range s {
if el != e {
ns = append(ns, el)
}
}
return ns
}
// Unique returns the unique elements of s.
func Unique[S ~[]E, E comparable](s S) S {
var ns S
h := map[E]struct{}{}
for _, el := range s {
if _, ok := h[el]; !ok {
h[el] = struct{}{}
ns = append(ns, el)
}
}
return ns
}