device: add type id and credential id to enrollment for easier referencing (#2749)

This commit is contained in:
Caleb Doxsey 2021-11-05 09:48:45 -06:00 committed by GitHub
parent 4cb3281af7
commit 85bb396555
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
3 changed files with 87 additions and 55 deletions

View file

@ -252,14 +252,16 @@ func (h *Handler) handleRegister(w http.ResponseWriter, r *http.Request, state *
return httputil.NewError(http.StatusBadRequest, fmt.Errorf("error verifying registration: %w", err))
}
deviceEnrollment, err := getOrCreateDeviceEnrollment(ctx, r, state, u)
deviceCredentialID := webauthnutil.GetDeviceCredentialID(serverCredential.ID)
deviceEnrollment, err := getOrCreateDeviceEnrollment(ctx, r, state, deviceType.GetId(), deviceCredentialID, u)
if err != nil {
return err
}
// save the credential
deviceCredential := &device.Credential{
Id: webauthnutil.GetDeviceCredentialID(serverCredential.ID),
Id: deviceCredentialID,
TypeId: deviceType.GetId(),
EnrollmentId: deviceEnrollment.GetId(),
UserId: u.GetId(),
@ -413,6 +415,8 @@ func getOrCreateDeviceEnrollment(
ctx context.Context,
r *http.Request,
state *State,
deviceTypeID string,
deviceCredentialID string,
u *user.User,
) (*device.Enrollment, error) {
var deviceEnrollment *device.Enrollment
@ -422,6 +426,7 @@ func getOrCreateDeviceEnrollment(
// create a new enrollment
deviceEnrollment = &device.Enrollment{
Id: uuid.New().String(),
TypeId: deviceTypeID,
UserId: u.GetId(),
}
} else {
@ -436,6 +441,10 @@ func getOrCreateDeviceEnrollment(
return nil, err
}
if deviceEnrollment.GetTypeId() != deviceTypeID {
return nil, httputil.NewError(http.StatusForbidden, fmt.Errorf("invalid enrollment token: wrong device type"))
}
if deviceEnrollment.GetUserId() != u.GetId() {
return nil, httputil.NewError(http.StatusForbidden, fmt.Errorf("invalid enrollment token: wrong user id"))
}
@ -445,6 +454,7 @@ func getOrCreateDeviceEnrollment(
}
}
deviceEnrollment.CredentialId = deviceCredentialID
deviceEnrollment.EnrolledAt = timestamppb.Now()
deviceEnrollment.UserAgent = r.UserAgent()
if ip, _, err := net.SplitHostPort(r.RemoteAddr); err == nil {