mirror of
https://github.com/pomerium/pomerium.git
synced 2025-06-06 21:04:39 +02:00
device: add generic methods for working with user+session devices (#3710)
This commit is contained in:
parent
6a9d6e45e1
commit
3f9dfbef76
5 changed files with 75 additions and 36 deletions
|
@ -1,32 +0,0 @@
|
|||
package webauthn
|
||||
|
||||
import "github.com/pomerium/pomerium/pkg/grpc/session"
|
||||
|
||||
func containsString(elements []string, value string) bool {
|
||||
for _, element := range elements {
|
||||
if element == value {
|
||||
return true
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
func removeString(elements []string, value string) []string {
|
||||
dup := make([]string, 0, len(elements))
|
||||
for _, element := range elements {
|
||||
if element != value {
|
||||
dup = append(dup, element)
|
||||
}
|
||||
}
|
||||
return dup
|
||||
}
|
||||
|
||||
func removeSessionDeviceCredential(elements []*session.Session_DeviceCredential, id string) []*session.Session_DeviceCredential {
|
||||
dup := make([]*session.Session_DeviceCredential, 0, len(elements))
|
||||
for _, element := range elements {
|
||||
if element.GetId() != id {
|
||||
dup = append(dup, element)
|
||||
}
|
||||
}
|
||||
return dup
|
||||
}
|
|
@ -324,7 +324,7 @@ func (h *Handler) handleRegister(w http.ResponseWriter, r *http.Request, state *
|
|||
}
|
||||
|
||||
// save the user
|
||||
u.DeviceCredentialIds = append(u.DeviceCredentialIds, deviceCredential.GetId())
|
||||
u.AddDeviceCredentialID(deviceCredential.GetId())
|
||||
_, err = databroker.Put(ctx, state.Client, u)
|
||||
if err != nil {
|
||||
return err
|
||||
|
@ -356,7 +356,7 @@ func (h *Handler) handleUnregister(w http.ResponseWriter, r *http.Request, state
|
|||
}
|
||||
|
||||
// ensure we only allow removing a device credential the user owns
|
||||
if !containsString(u.GetDeviceCredentialIds(), deviceCredentialID) {
|
||||
if !u.HasDeviceCredentialID(deviceCredentialID) {
|
||||
return errInvalidDeviceCredential
|
||||
}
|
||||
|
||||
|
@ -373,14 +373,14 @@ func (h *Handler) handleUnregister(w http.ResponseWriter, r *http.Request, state
|
|||
}
|
||||
|
||||
// remove the credential from the user
|
||||
u.DeviceCredentialIds = removeString(u.DeviceCredentialIds, deviceCredentialID)
|
||||
u.RemoveDeviceCredentialID(deviceCredentialID)
|
||||
_, err = databroker.Put(ctx, state.Client, u)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// remove the credential from the session
|
||||
state.Session.DeviceCredentials = removeSessionDeviceCredential(state.Session.DeviceCredentials, deviceCredentialID)
|
||||
state.Session.RemoveDeviceCredentialID(deviceCredentialID)
|
||||
return h.saveSessionAndRedirect(w, r, state, urlutil.GetAbsoluteURL(r).ResolveReference(&url.URL{
|
||||
Path: "/.pomerium",
|
||||
}).String())
|
||||
|
|
|
@ -12,6 +12,7 @@ import (
|
|||
"github.com/pomerium/pomerium/internal/identity"
|
||||
"github.com/pomerium/pomerium/pkg/grpc/databroker"
|
||||
"github.com/pomerium/pomerium/pkg/protoutil"
|
||||
"github.com/pomerium/pomerium/pkg/slices"
|
||||
)
|
||||
|
||||
// Delete deletes a session from the databroker.
|
||||
|
@ -78,3 +79,10 @@ func (x *Session) SetRawIDToken(rawIDToken string) {
|
|||
}
|
||||
x.IdToken.Raw = rawIDToken
|
||||
}
|
||||
|
||||
// RemoveDeviceCredentialID removes a device credential id.
|
||||
func (x *Session) RemoveDeviceCredentialID(deviceCredentialID string) {
|
||||
x.DeviceCredentials = slices.Filter(x.DeviceCredentials, func(el *Session_DeviceCredential) bool {
|
||||
return el.GetId() != deviceCredentialID
|
||||
})
|
||||
}
|
||||
|
|
|
@ -8,6 +8,7 @@ import (
|
|||
|
||||
"github.com/pomerium/pomerium/internal/identity"
|
||||
"github.com/pomerium/pomerium/pkg/grpc/databroker"
|
||||
"github.com/pomerium/pomerium/pkg/slices"
|
||||
)
|
||||
|
||||
// Get gets a user from the databroker.
|
||||
|
@ -47,3 +48,18 @@ func (x *User) GetClaim(claim string) []interface{} {
|
|||
}
|
||||
return vs
|
||||
}
|
||||
|
||||
// AddDeviceCredentialID adds a device credential id to the list of device credential ids.
|
||||
func (x *User) AddDeviceCredentialID(deviceCredentialID string) {
|
||||
x.DeviceCredentialIds = slices.Unique(append(x.DeviceCredentialIds, deviceCredentialID))
|
||||
}
|
||||
|
||||
// HasDeviceCredentialID returns true if the user has the device credential id.
|
||||
func (x *User) HasDeviceCredentialID(deviceCredentialID string) bool {
|
||||
return slices.Contains(x.DeviceCredentialIds, deviceCredentialID)
|
||||
}
|
||||
|
||||
// RemoveDeviceCredentialID removes the device credential id from the list of device credential ids.
|
||||
func (x *User) RemoveDeviceCredentialID(deviceCredentialID string) {
|
||||
x.DeviceCredentialIds = slices.Remove(x.DeviceCredentialIds, deviceCredentialID)
|
||||
}
|
||||
|
|
47
pkg/slices/slices.go
Normal file
47
pkg/slices/slices.go
Normal file
|
@ -0,0 +1,47 @@
|
|||
// Package slices contains functions for working with slices.
|
||||
package slices
|
||||
|
||||
// Contains returns true if e is in s.
|
||||
func Contains[S ~[]E, E comparable](s S, e E) bool {
|
||||
for _, el := range s {
|
||||
if el == e {
|
||||
return true
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
// Filter returns a new slice containing only those elements for which f(element) is true.
|
||||
func Filter[S ~[]E, E any](s S, f func(E) bool) S {
|
||||
var ns S
|
||||
for _, el := range s {
|
||||
if f(el) {
|
||||
ns = append(ns, el)
|
||||
}
|
||||
}
|
||||
return ns
|
||||
}
|
||||
|
||||
// Remove removes e from s.
|
||||
func Remove[S ~[]E, E comparable](s S, e E) S {
|
||||
var ns S
|
||||
for _, el := range s {
|
||||
if el != e {
|
||||
ns = append(ns, el)
|
||||
}
|
||||
}
|
||||
return ns
|
||||
}
|
||||
|
||||
// Unique returns the unique elements of s.
|
||||
func Unique[S ~[]E, E comparable](s S) S {
|
||||
var ns S
|
||||
h := map[E]struct{}{}
|
||||
for _, el := range s {
|
||||
if _, ok := h[el]; !ok {
|
||||
h[el] = struct{}{}
|
||||
ns = append(ns, el)
|
||||
}
|
||||
}
|
||||
return ns
|
||||
}
|
Loading…
Add table
Add a link
Reference in a new issue