diff --git a/authenticate/handlers.go b/authenticate/handlers.go index 4c3fba7a1..523febac1 100644 --- a/authenticate/handlers.go +++ b/authenticate/handlers.go @@ -29,7 +29,6 @@ import ( "github.com/pomerium/pomerium/internal/telemetry/trace" "github.com/pomerium/pomerium/internal/urlutil" "github.com/pomerium/pomerium/pkg/cryptutil" - "github.com/pomerium/pomerium/pkg/grpc/device" "github.com/pomerium/pomerium/pkg/grpc/directory" "github.com/pomerium/pomerium/pkg/grpc/session" "github.com/pomerium/pomerium/pkg/grpc/user" @@ -490,24 +489,40 @@ func (a *Authenticate) userInfo(w http.ResponseWriter, r *http.Request) error { return fmt.Errorf("invalid webauthn url: %w", err) } - var deviceCredentials []*device.Credential + type DeviceCredentialInfo struct { + ID string + } + var currentDeviceCredentials, otherDeviceCredentials []DeviceCredentialInfo for _, id := range pbUser.GetDeviceCredentialIds() { - deviceCredentials = append(deviceCredentials, &device.Credential{ - Id: id, - }) + selected := false + for _, c := range pbSession.GetDeviceCredentials() { + if c.GetId() == id { + selected = true + } + } + if selected { + currentDeviceCredentials = append(currentDeviceCredentials, DeviceCredentialInfo{ + ID: id, + }) + } else { + otherDeviceCredentials = append(otherDeviceCredentials, DeviceCredentialInfo{ + ID: id, + }) + } } input := map[string]interface{}{ - "IsImpersonated": isImpersonated, - "State": s, // local session state (cookie, header, etc) - "Session": pbSession, // current access, refresh, id token - "User": pbUser, // user details inferred from oidc id_token - "DeviceCredentials": deviceCredentials, - "DirectoryUser": pbDirectoryUser, // user details inferred from idp directory - "DirectoryGroups": groups, // user's groups inferred from idp directory - "csrfField": csrf.TemplateField(r), - "SignOutURL": signoutURL, - "WebAuthnURL": webAuthnURL, + "IsImpersonated": isImpersonated, + "State": s, // local session state (cookie, header, etc) + "Session": pbSession, // current access, refresh, id token + "User": pbUser, // user details inferred from oidc id_token + "CurrentDeviceCredentials": currentDeviceCredentials, + "OtherDeviceCredentials": otherDeviceCredentials, + "DirectoryUser": pbDirectoryUser, // user details inferred from idp directory + "DirectoryGroups": groups, // user's groups inferred from idp directory + "csrfField": csrf.TemplateField(r), + "SignOutURL": signoutURL, + "WebAuthnURL": webAuthnURL, } return a.templates.ExecuteTemplate(w, "userInfo.html", input) } diff --git a/authenticate/handlers/webauthn/helpers.go b/authenticate/handlers/webauthn/helpers.go new file mode 100644 index 000000000..319e8952f --- /dev/null +++ b/authenticate/handlers/webauthn/helpers.go @@ -0,0 +1,32 @@ +package webauthn + +import "github.com/pomerium/pomerium/pkg/grpc/session" + +func containsString(elements []string, value string) bool { + for _, element := range elements { + if element == value { + return true + } + } + return false +} + +func removeString(elements []string, value string) []string { + dup := make([]string, 0, len(elements)) + for _, element := range elements { + if element != value { + dup = append(dup, element) + } + } + return dup +} + +func removeSessionDeviceCredential(elements []*session.Session_DeviceCredential, id string) []*session.Session_DeviceCredential { + dup := make([]*session.Session_DeviceCredential, 0, len(elements)) + for _, element := range elements { + if element.GetId() != id { + dup = append(dup, element) + } + } + return dup +} diff --git a/authenticate/handlers/webauthn/webauthn.go b/authenticate/handlers/webauthn/webauthn.go index 389f42406..bf95f9ad2 100644 --- a/authenticate/handlers/webauthn/webauthn.go +++ b/authenticate/handlers/webauthn/webauthn.go @@ -37,8 +37,14 @@ import ( const maxAuthenticateResponses = 5 var ( - errMissingDeviceType = httputil.NewError(http.StatusBadRequest, errors.New("device_type is a required parameter")) - errMissingRedirectURI = httputil.NewError(http.StatusBadRequest, errors.New("pomerium_redirect_uri is a required parameter")) + errMissingDeviceCredentialID = httputil.NewError(http.StatusBadRequest, errors.New( + urlutil.QueryDeviceCredentialID+" is a required parameter")) + errMissingDeviceType = httputil.NewError(http.StatusBadRequest, errors.New( + urlutil.QueryDeviceType+" is a required parameter")) + errMissingRedirectURI = httputil.NewError(http.StatusBadRequest, errors.New( + urlutil.QueryRedirectURI+" is a required parameter")) + errInvalidDeviceCredential = httputil.NewError(http.StatusBadRequest, errors.New( + "invalid device credential")) ) // State is the state needed by the Handler to handle requests. @@ -91,6 +97,8 @@ func (h *Handler) handle(w http.ResponseWriter, r *http.Request) error { return h.handleAuthenticate(w, r, s) case r.FormValue("action") == "register": return h.handleRegister(w, r, s) + case r.FormValue("action") == "unregister": + return h.handleUnregister(w, r, s) } return httputil.NewError(http.StatusNotFound, errors.New(http.StatusText(http.StatusNotFound))) @@ -297,6 +305,49 @@ func (h *Handler) handleRegister(w http.ResponseWriter, r *http.Request, state * return h.saveSessionAndRedirect(w, r, state, redirectURIParam) } +func (h *Handler) handleUnregister(w http.ResponseWriter, r *http.Request, state *State) error { + ctx := r.Context() + + // get the user information + u, err := user.Get(ctx, state.Client, state.Session.GetUserId()) + if err != nil { + return err + } + + deviceCredentialID := r.FormValue(urlutil.QueryDeviceCredentialID) + if deviceCredentialID == "" { + return errMissingDeviceCredentialID + } + + // ensure we only allow removing a device credential the user owns + if !containsString(u.GetDeviceCredentialIds(), deviceCredentialID) { + return errInvalidDeviceCredential + } + + // delete the credential + deviceCredential, err := device.DeleteCredential(ctx, state.Client, deviceCredentialID) + if err != nil { + return err + } + + // delete the corresponding enrollment + _, err = device.DeleteEnrollment(ctx, state.Client, deviceCredential.GetEnrollmentId()) + if err != nil { + return err + } + + // remove the credential from the user + u.DeviceCredentialIds = removeString(u.DeviceCredentialIds, deviceCredentialID) + _, err = user.Put(ctx, state.Client, u) + if err != nil { + return err + } + + // remove the credential from the session + state.Session.DeviceCredentials = removeSessionDeviceCredential(state.Session.DeviceCredentials, deviceCredentialID) + return h.saveSessionAndRedirect(w, r, state, "/.pomerium") +} + func (h *Handler) handleView(w http.ResponseWriter, r *http.Request, state *State) error { ctx := r.Context() diff --git a/internal/frontend/assets/html/userInfo.html b/internal/frontend/assets/html/userInfo.html index d010a4066..4615cc84e 100644 --- a/internal/frontend/assets/html/userInfo.html +++ b/internal/frontend/assets/html/userInfo.html @@ -182,10 +182,10 @@
- Device Credentials + Current Session Device Credentials
- {{if .DeviceCredentials}} + {{if .CurrentDeviceCredentials}} @@ -193,9 +193,17 @@ - {{range .DeviceCredentials}} + {{range .CurrentDeviceCredentials}} - + + {{end}} @@ -204,6 +212,35 @@ No device credentials found! {{end}} + {{if .OtherDeviceCredentials}} +
+
+ Other Device Credentials +
+
{{.Id}}{{.ID}} +
+ {{$.csrfField}} + + + +
+
+ + + + + + + {{range .OtherDeviceCredentials}} + + + + + {{end}} + +
ID
{{.ID}} +
+ {{$.csrfField}} + + + +
+
+
+ {{end}} diff --git a/internal/frontend/assets/style/main.css b/internal/frontend/assets/style/main.css index 86cadb832..67d2cf22b 100644 --- a/internal/frontend/assets/style/main.css +++ b/internal/frontend/assets/style/main.css @@ -147,6 +147,9 @@ body { .box-inner { padding: 35px; } +.box-inner ~ .box-inner { + padding-top: 0; +} .white { background: white; diff --git a/internal/frontend/templates_test.go b/internal/frontend/templates_test.go index f0dce44b6..f88a2e2b0 100644 --- a/internal/frontend/templates_test.go +++ b/internal/frontend/templates_test.go @@ -16,9 +16,5 @@ func TestNewTemplates(t *testing.T) { err = tpl.ExecuteTemplate(&buf, "header.html", nil) require.NoError(t, err) - assert.Equal(t, ` - - - -`, buf.String()) + assert.Contains(t, buf.String(), ``) } diff --git a/internal/urlutil/query_params.go b/internal/urlutil/query_params.go index defdb6041..72aed1d6c 100644 --- a/internal/urlutil/query_params.go +++ b/internal/urlutil/query_params.go @@ -4,16 +4,17 @@ package urlutil // services over HTTP calls and redirects. They are typically used in // conjunction with a HMAC to ensure authenticity. const ( - QueryCallbackURI = "pomerium_callback_uri" - QueryDeviceType = "pomerium_device_type" - QueryEnrollmentToken = "pomerium_enrollment_token" //nolint - QueryIsProgrammatic = "pomerium_programmatic" - QueryForwardAuth = "pomerium_forward_auth" - QueryPomeriumJWT = "pomerium_jwt" - QuerySession = "pomerium_session" - QuerySessionEncrypted = "pomerium_session_encrypted" - QueryRedirectURI = "pomerium_redirect_uri" - QueryForwardAuthURI = "uri" + QueryCallbackURI = "pomerium_callback_uri" + QueryDeviceCredentialID = "pomerium_device_credential_id" + QueryDeviceType = "pomerium_device_type" + QueryEnrollmentToken = "pomerium_enrollment_token" //nolint + QueryIsProgrammatic = "pomerium_programmatic" + QueryForwardAuth = "pomerium_forward_auth" + QueryPomeriumJWT = "pomerium_jwt" + QuerySession = "pomerium_session" + QuerySessionEncrypted = "pomerium_session_encrypted" + QueryRedirectURI = "pomerium_redirect_uri" + QueryForwardAuthURI = "uri" ) // URL signature based query params used for verifying the authenticity of a URL. diff --git a/pkg/grpc/device/device.go b/pkg/grpc/device/device.go index 4c212803e..0c51d72e8 100644 --- a/pkg/grpc/device/device.go +++ b/pkg/grpc/device/device.go @@ -5,11 +5,65 @@ import ( "context" "fmt" + "google.golang.org/grpc/codes" + "google.golang.org/grpc/status" + "google.golang.org/protobuf/types/known/timestamppb" + "github.com/pomerium/pomerium/pkg/encoding/base58" "github.com/pomerium/pomerium/pkg/grpc/databroker" "github.com/pomerium/pomerium/pkg/protoutil" ) +// DeleteCredential deletes a credential from the databroker. +func DeleteCredential( + ctx context.Context, + client databroker.DataBrokerServiceClient, + credentialID string, +) (*Credential, error) { + credential, err := GetCredential(ctx, client, credentialID) + if status.Code(err) == codes.NotFound { + return nil, nil + } else if err != nil { + return nil, err + } + + any := protoutil.NewAny(credential) + _, err = client.Put(ctx, &databroker.PutRequest{ + Record: &databroker.Record{ + Type: any.GetTypeUrl(), + Id: credentialID, + Data: any, + DeletedAt: timestamppb.Now(), + }, + }) + return credential, err +} + +// DeleteEnrollment deletes an enrollment from the databroker. +func DeleteEnrollment( + ctx context.Context, + client databroker.DataBrokerServiceClient, + enrollmentID string, +) (*Enrollment, error) { + enrollment, err := GetEnrollment(ctx, client, enrollmentID) + if status.Code(err) == codes.NotFound { + return nil, nil + } else if err != nil { + return nil, err + } + + any := protoutil.NewAny(enrollment) + _, err = client.Put(ctx, &databroker.PutRequest{ + Record: &databroker.Record{ + Type: any.GetTypeUrl(), + Id: enrollmentID, + Data: any, + DeletedAt: timestamppb.Now(), + }, + }) + return enrollment, err +} + // GetCredential gets a credential from the databroker. func GetCredential( ctx context.Context,