mirror of
https://github.com/pomerium/pomerium.git
synced 2025-06-07 05:12:45 +02:00
device: add generic methods for working with user+session devices (#3710)
This commit is contained in:
parent
6a9d6e45e1
commit
3f9dfbef76
5 changed files with 75 additions and 36 deletions
|
@ -1,32 +0,0 @@
|
||||||
package webauthn
|
|
||||||
|
|
||||||
import "github.com/pomerium/pomerium/pkg/grpc/session"
|
|
||||||
|
|
||||||
func containsString(elements []string, value string) bool {
|
|
||||||
for _, element := range elements {
|
|
||||||
if element == value {
|
|
||||||
return true
|
|
||||||
}
|
|
||||||
}
|
|
||||||
return false
|
|
||||||
}
|
|
||||||
|
|
||||||
func removeString(elements []string, value string) []string {
|
|
||||||
dup := make([]string, 0, len(elements))
|
|
||||||
for _, element := range elements {
|
|
||||||
if element != value {
|
|
||||||
dup = append(dup, element)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
return dup
|
|
||||||
}
|
|
||||||
|
|
||||||
func removeSessionDeviceCredential(elements []*session.Session_DeviceCredential, id string) []*session.Session_DeviceCredential {
|
|
||||||
dup := make([]*session.Session_DeviceCredential, 0, len(elements))
|
|
||||||
for _, element := range elements {
|
|
||||||
if element.GetId() != id {
|
|
||||||
dup = append(dup, element)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
return dup
|
|
||||||
}
|
|
|
@ -324,7 +324,7 @@ func (h *Handler) handleRegister(w http.ResponseWriter, r *http.Request, state *
|
||||||
}
|
}
|
||||||
|
|
||||||
// save the user
|
// save the user
|
||||||
u.DeviceCredentialIds = append(u.DeviceCredentialIds, deviceCredential.GetId())
|
u.AddDeviceCredentialID(deviceCredential.GetId())
|
||||||
_, err = databroker.Put(ctx, state.Client, u)
|
_, err = databroker.Put(ctx, state.Client, u)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
|
@ -356,7 +356,7 @@ func (h *Handler) handleUnregister(w http.ResponseWriter, r *http.Request, state
|
||||||
}
|
}
|
||||||
|
|
||||||
// ensure we only allow removing a device credential the user owns
|
// ensure we only allow removing a device credential the user owns
|
||||||
if !containsString(u.GetDeviceCredentialIds(), deviceCredentialID) {
|
if !u.HasDeviceCredentialID(deviceCredentialID) {
|
||||||
return errInvalidDeviceCredential
|
return errInvalidDeviceCredential
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -373,14 +373,14 @@ func (h *Handler) handleUnregister(w http.ResponseWriter, r *http.Request, state
|
||||||
}
|
}
|
||||||
|
|
||||||
// remove the credential from the user
|
// remove the credential from the user
|
||||||
u.DeviceCredentialIds = removeString(u.DeviceCredentialIds, deviceCredentialID)
|
u.RemoveDeviceCredentialID(deviceCredentialID)
|
||||||
_, err = databroker.Put(ctx, state.Client, u)
|
_, err = databroker.Put(ctx, state.Client, u)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
// remove the credential from the session
|
// remove the credential from the session
|
||||||
state.Session.DeviceCredentials = removeSessionDeviceCredential(state.Session.DeviceCredentials, deviceCredentialID)
|
state.Session.RemoveDeviceCredentialID(deviceCredentialID)
|
||||||
return h.saveSessionAndRedirect(w, r, state, urlutil.GetAbsoluteURL(r).ResolveReference(&url.URL{
|
return h.saveSessionAndRedirect(w, r, state, urlutil.GetAbsoluteURL(r).ResolveReference(&url.URL{
|
||||||
Path: "/.pomerium",
|
Path: "/.pomerium",
|
||||||
}).String())
|
}).String())
|
||||||
|
|
|
@ -12,6 +12,7 @@ import (
|
||||||
"github.com/pomerium/pomerium/internal/identity"
|
"github.com/pomerium/pomerium/internal/identity"
|
||||||
"github.com/pomerium/pomerium/pkg/grpc/databroker"
|
"github.com/pomerium/pomerium/pkg/grpc/databroker"
|
||||||
"github.com/pomerium/pomerium/pkg/protoutil"
|
"github.com/pomerium/pomerium/pkg/protoutil"
|
||||||
|
"github.com/pomerium/pomerium/pkg/slices"
|
||||||
)
|
)
|
||||||
|
|
||||||
// Delete deletes a session from the databroker.
|
// Delete deletes a session from the databroker.
|
||||||
|
@ -78,3 +79,10 @@ func (x *Session) SetRawIDToken(rawIDToken string) {
|
||||||
}
|
}
|
||||||
x.IdToken.Raw = rawIDToken
|
x.IdToken.Raw = rawIDToken
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// RemoveDeviceCredentialID removes a device credential id.
|
||||||
|
func (x *Session) RemoveDeviceCredentialID(deviceCredentialID string) {
|
||||||
|
x.DeviceCredentials = slices.Filter(x.DeviceCredentials, func(el *Session_DeviceCredential) bool {
|
||||||
|
return el.GetId() != deviceCredentialID
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
|
@ -8,6 +8,7 @@ import (
|
||||||
|
|
||||||
"github.com/pomerium/pomerium/internal/identity"
|
"github.com/pomerium/pomerium/internal/identity"
|
||||||
"github.com/pomerium/pomerium/pkg/grpc/databroker"
|
"github.com/pomerium/pomerium/pkg/grpc/databroker"
|
||||||
|
"github.com/pomerium/pomerium/pkg/slices"
|
||||||
)
|
)
|
||||||
|
|
||||||
// Get gets a user from the databroker.
|
// Get gets a user from the databroker.
|
||||||
|
@ -47,3 +48,18 @@ func (x *User) GetClaim(claim string) []interface{} {
|
||||||
}
|
}
|
||||||
return vs
|
return vs
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// AddDeviceCredentialID adds a device credential id to the list of device credential ids.
|
||||||
|
func (x *User) AddDeviceCredentialID(deviceCredentialID string) {
|
||||||
|
x.DeviceCredentialIds = slices.Unique(append(x.DeviceCredentialIds, deviceCredentialID))
|
||||||
|
}
|
||||||
|
|
||||||
|
// HasDeviceCredentialID returns true if the user has the device credential id.
|
||||||
|
func (x *User) HasDeviceCredentialID(deviceCredentialID string) bool {
|
||||||
|
return slices.Contains(x.DeviceCredentialIds, deviceCredentialID)
|
||||||
|
}
|
||||||
|
|
||||||
|
// RemoveDeviceCredentialID removes the device credential id from the list of device credential ids.
|
||||||
|
func (x *User) RemoveDeviceCredentialID(deviceCredentialID string) {
|
||||||
|
x.DeviceCredentialIds = slices.Remove(x.DeviceCredentialIds, deviceCredentialID)
|
||||||
|
}
|
||||||
|
|
47
pkg/slices/slices.go
Normal file
47
pkg/slices/slices.go
Normal file
|
@ -0,0 +1,47 @@
|
||||||
|
// Package slices contains functions for working with slices.
|
||||||
|
package slices
|
||||||
|
|
||||||
|
// Contains returns true if e is in s.
|
||||||
|
func Contains[S ~[]E, E comparable](s S, e E) bool {
|
||||||
|
for _, el := range s {
|
||||||
|
if el == e {
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
|
||||||
|
// Filter returns a new slice containing only those elements for which f(element) is true.
|
||||||
|
func Filter[S ~[]E, E any](s S, f func(E) bool) S {
|
||||||
|
var ns S
|
||||||
|
for _, el := range s {
|
||||||
|
if f(el) {
|
||||||
|
ns = append(ns, el)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return ns
|
||||||
|
}
|
||||||
|
|
||||||
|
// Remove removes e from s.
|
||||||
|
func Remove[S ~[]E, E comparable](s S, e E) S {
|
||||||
|
var ns S
|
||||||
|
for _, el := range s {
|
||||||
|
if el != e {
|
||||||
|
ns = append(ns, el)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return ns
|
||||||
|
}
|
||||||
|
|
||||||
|
// Unique returns the unique elements of s.
|
||||||
|
func Unique[S ~[]E, E comparable](s S) S {
|
||||||
|
var ns S
|
||||||
|
h := map[E]struct{}{}
|
||||||
|
for _, el := range s {
|
||||||
|
if _, ok := h[el]; !ok {
|
||||||
|
h[el] = struct{}{}
|
||||||
|
ns = append(ns, el)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return ns
|
||||||
|
}
|
Loading…
Add table
Add a link
Reference in a new issue