From 2e2326843fe515a30deeec1117273cadd6df0d5d Mon Sep 17 00:00:00 2001 From: eikendev Date: Tue, 9 Feb 2021 00:05:16 +0100 Subject: [PATCH] Repair channels if necessary --- cmd/pushbits/main.go | 13 ++++-- internal/api/application.go | 21 ++++------ internal/api/interfaces.go | 2 +- internal/api/notification.go | 2 +- internal/api/user.go | 6 +-- internal/authentication/credentials/hibp.go | 2 +- internal/database/database.go | 45 ++++++++++++++++++++- internal/database/interfaces.go | 13 ++++++ internal/dispatcher/application.go | 33 +++++++++++++-- internal/dispatcher/dispatcher.go | 4 +- internal/dispatcher/notification.go | 2 +- internal/model/user.go | 2 +- 12 files changed, 114 insertions(+), 31 deletions(-) create mode 100644 internal/database/interfaces.go diff --git a/cmd/pushbits/main.go b/cmd/pushbits/main.go index 84a5a5d..66bcaae 100644 --- a/cmd/pushbits/main.go +++ b/cmd/pushbits/main.go @@ -32,29 +32,34 @@ func main() { c := configuration.Get() if c.Debug { - log.Printf("%+v\n", c) + log.Printf("%+v", c) } cm := credentials.CreateManager(c.Security.CheckHIBP, c.Crypto) db, err := database.Create(cm, c.Database.Dialect, c.Database.Connection) if err != nil { - panic(err) + log.Fatal(err) } defer db.Close() if err := db.Populate(c.Admin.Name, c.Admin.Password, c.Admin.MatrixID); err != nil { - panic(err) + log.Fatal(err) } dp, err := dispatcher.Create(db, c.Matrix.Homeserver, c.Matrix.Username, c.Matrix.Password) if err != nil { - panic(err) + log.Fatal(err) } defer dp.Close() setupCleanup(db, dp) + err = db.RepairChannels(dp) + if err != nil { + log.Fatal(err) + } + engine := router.Create(c.Debug, cm, db, dp) runner.Run(engine, c.HTTP.ListenAddress, c.HTTP.Port) diff --git a/internal/api/application.go b/internal/api/application.go index 1ee56a7..c598414 100644 --- a/internal/api/application.go +++ b/internal/api/application.go @@ -27,7 +27,7 @@ func (h *ApplicationHandler) generateToken(compat bool) string { } func (h *ApplicationHandler) registerApplication(ctx *gin.Context, a *model.Application, u *model.User) error { - log.Printf("Registering application %s.\n", a.Name) + log.Printf("Registering application %s.", a.Name) channelID, err := h.DP.RegisterApplication(a.ID, a.Name, a.Token, u.MatrixID) if success := successOrAbort(ctx, http.StatusInternalServerError, err); !success { @@ -41,7 +41,7 @@ func (h *ApplicationHandler) registerApplication(ctx *gin.Context, a *model.Appl } func (h *ApplicationHandler) createApplication(ctx *gin.Context, u *model.User, name string, compat bool) (*model.Application, error) { - log.Printf("Creating application %s.\n", name) + log.Printf("Creating application %s.", name) application := model.Application{} application.Name = name @@ -57,7 +57,7 @@ func (h *ApplicationHandler) createApplication(ctx *gin.Context, u *model.User, err := h.DB.DeleteApplication(&application) if success := successOrAbort(ctx, http.StatusInternalServerError, err); !success { - log.Printf("Cannot delete application with ID %d.\n", application.ID) + log.Printf("Cannot delete application with ID %d.", application.ID) } return nil, err @@ -67,7 +67,7 @@ func (h *ApplicationHandler) createApplication(ctx *gin.Context, u *model.User, } func (h *ApplicationHandler) deleteApplication(ctx *gin.Context, a *model.Application, u *model.User) error { - log.Printf("Deleting application %s (ID %d).\n", a.Name, a.ID) + log.Printf("Deleting application %s (ID %d).", a.Name, a.ID) err := h.DP.DeregisterApplication(a, u) if success := successOrAbort(ctx, http.StatusInternalServerError, err); !success { @@ -82,8 +82,8 @@ func (h *ApplicationHandler) deleteApplication(ctx *gin.Context, a *model.Applic return nil } -func (h *ApplicationHandler) updateApplication(ctx *gin.Context, a *model.Application, u *model.User, updateApplication *model.UpdateApplication) error { - log.Printf("Updating application %s (ID %d).\n", a.Name, a.ID) +func (h *ApplicationHandler) updateApplication(ctx *gin.Context, a *model.Application, updateApplication *model.UpdateApplication) error { + log.Printf("Updating application %s (ID %d).", a.Name, a.ID) if updateApplication.Name != nil { log.Printf("Updating application name to '%s'.", *updateApplication.Name) @@ -101,7 +101,7 @@ func (h *ApplicationHandler) updateApplication(ctx *gin.Context, a *model.Applic return err } - err = h.DP.UpdateApplication(a, u) + err = h.DP.UpdateApplication(a) if success := successOrAbort(ctx, http.StatusInternalServerError, err); !success { return err } @@ -200,12 +200,7 @@ func (h *ApplicationHandler) UpdateApplication(ctx *gin.Context) { return } - user := authentication.GetUser(ctx) - if user == nil { - return - } - - if err := h.updateApplication(ctx, application, user, &updateApplication); err != nil { + if err := h.updateApplication(ctx, application, &updateApplication); err != nil { return } diff --git a/internal/api/interfaces.go b/internal/api/interfaces.go index 73486c0..4bc72b1 100644 --- a/internal/api/interfaces.go +++ b/internal/api/interfaces.go @@ -28,7 +28,7 @@ type Database interface { type Dispatcher interface { RegisterApplication(id uint, name, token, user string) (string, error) DeregisterApplication(a *model.Application, u *model.User) error - UpdateApplication(a *model.Application, u *model.User) error + UpdateApplication(a *model.Application) error } // The CredentialsManager interface for updating credentials. diff --git a/internal/api/notification.go b/internal/api/notification.go index 553bf1c..bfb8af0 100644 --- a/internal/api/notification.go +++ b/internal/api/notification.go @@ -36,7 +36,7 @@ func (h *NotificationHandler) CreateNotification(ctx *gin.Context) { } application := authentication.GetApplication(ctx) - log.Printf("Sending notification for application %s.\n", application.Name) + log.Printf("Sending notification for application %s.", application.Name) notification.ID = 0 notification.ApplicationID = application.ID diff --git a/internal/api/user.go b/internal/api/user.go index ea06ff2..65e5159 100644 --- a/internal/api/user.go +++ b/internal/api/user.go @@ -84,7 +84,7 @@ func (h *UserHandler) updateUser(ctx *gin.Context, u *model.User, updateUser mod } } - log.Printf("Updating user %s.\n", u.Name) + log.Printf("Updating user %s.", u.Name) if updateUser.Name != nil { u.Name = *updateUser.Name @@ -126,7 +126,7 @@ func (h *UserHandler) CreateUser(ctx *gin.Context) { return } - log.Printf("Creating user %s.\n", createUser.Name) + log.Printf("Creating user %s.", createUser.Name) user, err := h.DB.CreateUser(createUser) @@ -181,7 +181,7 @@ func (h *UserHandler) DeleteUser(ctx *gin.Context) { } } - log.Printf("Deleting user %s.\n", user.Name) + log.Printf("Deleting user %s.", user.Name) if err := h.deleteApplications(ctx, user); err != nil { return diff --git a/internal/authentication/credentials/hibp.go b/internal/authentication/credentials/hibp.go index 5836126..81680a8 100644 --- a/internal/authentication/credentials/hibp.go +++ b/internal/authentication/credentials/hibp.go @@ -26,7 +26,7 @@ func IsPasswordPwned(password string) (bool, error) { lookup := hashStr[0:5] match := hashStr[5:] - log.Printf("Checking HIBP for hashes starting with '%s'.\n", lookup) + log.Printf("Checking HIBP for hashes starting with '%s'.", lookup) resp, err := http.Get(pwnedHashesURL + lookup) if err != nil { diff --git a/internal/database/database.go b/internal/database/database.go index b4bc3d8..c7f52ff 100644 --- a/internal/database/database.go +++ b/internal/database/database.go @@ -79,6 +79,8 @@ func (d *Database) Close() { // Populate fills the database with initial information like the admin user. func (d *Database) Populate(name, password, matrixID string) error { + log.Print("Populating database.") + var user model.User query := d.gormdb.Where("name = ?", name).First(&user) @@ -86,14 +88,53 @@ func (d *Database) Populate(name, password, matrixID string) error { if errors.Is(query.Error, gorm.ErrRecordNotFound) { user, err := model.NewUser(d.credentialsManager, name, password, true, matrixID) if err != nil { - log.Fatal(err) + return err } if err := d.gormdb.Create(&user).Error; err != nil { return errors.New("user cannot be created") } } else { - log.Printf("Admin user %s already exists.\n", name) + log.Printf("Priviledged user %s already exists.", name) + } + + return nil +} + +// RepairChannels resets channels that have been modified by a user. +func (d *Database) RepairChannels(dp Dispatcher) error { + log.Print("Repairing application channels.") + + users, err := d.GetUsers() + if err != nil { + return err + } + + for _, user := range users { + applications, err := d.GetApplications(&user) + if err != nil { + return err + } + + for _, application := range applications { + if err := dp.UpdateApplication(&application); err != nil { + return err + } + + orphan, err := dp.IsOrphan(&application, &user) + if err != nil { + return err + } + + if orphan { + log.Printf("Found orphan channel for application %s (ID %d)", application.Name, application.ID) + + if err = dp.RepairApplication(&application, &user); err != nil { + log.Printf("Unable to repair application %s (ID %d).", application.Name, application.ID) + log.Println(err) + } + } + } } return nil diff --git a/internal/database/interfaces.go b/internal/database/interfaces.go new file mode 100644 index 0000000..7d4a70c --- /dev/null +++ b/internal/database/interfaces.go @@ -0,0 +1,13 @@ +package database + +import ( + "github.com/pushbits/server/internal/model" +) + +// The Dispatcher interface for constructing and destructing channels. +type Dispatcher interface { + DeregisterApplication(a *model.Application, u *model.User) error + UpdateApplication(a *model.Application) error + IsOrphan(a *model.Application, u *model.User) (bool, error) + RepairApplication(a *model.Application, u *model.User) error +} diff --git a/internal/dispatcher/application.go b/internal/dispatcher/application.go index d6a49da..8728a7d 100644 --- a/internal/dispatcher/application.go +++ b/internal/dispatcher/application.go @@ -25,7 +25,6 @@ func (d *Dispatcher) RegisterApplication(id uint, name, token, user string) (str Topic: buildRoomTopic(id), Visibility: "private", }) - if err != nil { log.Print(err) return "", err @@ -45,9 +44,9 @@ func (d *Dispatcher) DeregisterApplication(a *model.Application, u *model.User) UserID: u.MatrixID, } + // The user might have left the channel, but we can still try to remove them. if _, err := d.client.KickUser(a.MatrixID, kickUser); err != nil { log.Print(err) - return err } if _, err := d.client.LeaveRoom(a.MatrixID); err != nil { @@ -73,7 +72,7 @@ func (d *Dispatcher) sendRoomEvent(roomID, eventType string, content interface{} } // UpdateApplication updates a channel for an application. -func (d *Dispatcher) UpdateApplication(a *model.Application, u *model.User) error { +func (d *Dispatcher) UpdateApplication(a *model.Application) error { log.Printf("Updating application %s (ID %d) with Matrix ID %s.\n", a.Name, a.ID, a.MatrixID) content := map[string]interface{}{ @@ -94,3 +93,31 @@ func (d *Dispatcher) UpdateApplication(a *model.Application, u *model.User) erro return nil } + +// IsOrphan checks if the user is still connected to the channel. +func (d *Dispatcher) IsOrphan(a *model.Application, u *model.User) (bool, error) { + resp, err := d.client.JoinedMembers(a.MatrixID) + if err != nil { + return false, err + } + + found := false + + for userID := range resp.Joined { + found = found || (userID == u.MatrixID) + } + + return !found, nil +} + +// RepairApplication re-invites the user to the channel. +func (d *Dispatcher) RepairApplication(a *model.Application, u *model.User) error { + _, err := d.client.InviteUser(a.MatrixID, &gomatrix.ReqInviteUser{ + UserID: u.MatrixID, + }) + if err != nil { + return err + } + + return nil +} diff --git a/internal/dispatcher/dispatcher.go b/internal/dispatcher/dispatcher.go index baf324e..631a5f6 100644 --- a/internal/dispatcher/dispatcher.go +++ b/internal/dispatcher/dispatcher.go @@ -45,8 +45,10 @@ func Create(db Database, homeserver, username, password string) (*Dispatcher, er // Close closes the dispatcher connection. func (d *Dispatcher) Close() { - log.Printf("Logging out.\n") + log.Printf("Logging out.") d.client.Logout() d.client.ClearCredentials() + + log.Printf("Successfully logged out.") } diff --git a/internal/dispatcher/notification.go b/internal/dispatcher/notification.go index 63f8b04..fd26991 100644 --- a/internal/dispatcher/notification.go +++ b/internal/dispatcher/notification.go @@ -11,7 +11,7 @@ import ( // SendNotification sends a notification to the specified user. func (d *Dispatcher) SendNotification(a *model.Application, n *model.Notification) error { - log.Printf("Sending notification to room %s.\n", a.MatrixID) + log.Printf("Sending notification to room %s.", a.MatrixID) plainTitle := strings.TrimSpace(n.Title) plainMessage := strings.TrimSpace(n.Message) diff --git a/internal/model/user.go b/internal/model/user.go index 7c7cbfc..16d6a08 100644 --- a/internal/model/user.go +++ b/internal/model/user.go @@ -37,7 +37,7 @@ type CreateUser struct { // NewUser creates a new user. func NewUser(cm *credentials.Manager, name, password string, isAdmin bool, matrixID string) (*User, error) { - log.Printf("Creating user %s.\n", name) + log.Printf("Creating user %s.", name) passwordHash, err := cm.CreatePasswordHash(password) if err != nil {