mirror of
https://github.com/pushbits/server.git
synced 2025-04-29 02:07:38 +02:00
155 lines
3.6 KiB
Go
155 lines
3.6 KiB
Go
package database
|
|
|
|
import (
|
|
"database/sql"
|
|
"errors"
|
|
"os"
|
|
"path/filepath"
|
|
"time"
|
|
|
|
"github.com/pushbits/server/internal/authentication/credentials"
|
|
"github.com/pushbits/server/internal/configuration"
|
|
"github.com/pushbits/server/internal/log"
|
|
"github.com/pushbits/server/internal/model"
|
|
|
|
"gorm.io/driver/mysql"
|
|
"gorm.io/driver/sqlite"
|
|
"gorm.io/gorm"
|
|
)
|
|
|
|
// Database holds information for the database connection.
|
|
type Database struct {
|
|
gormdb *gorm.DB
|
|
sqldb *sql.DB
|
|
credentialsManager *credentials.Manager
|
|
}
|
|
|
|
func createFileDir(file string) {
|
|
dir := filepath.Dir(file)
|
|
|
|
if _, err := os.Stat(dir); os.IsNotExist(err) {
|
|
// nosemgrep: tests.semgrep-rules.go.lang.correctness.permissions.incorrect-default-permission
|
|
if err := os.MkdirAll(dir, 0o750); err != nil {
|
|
panic(err)
|
|
}
|
|
}
|
|
}
|
|
|
|
// Create instanciates a database connection.
|
|
func Create(cm *credentials.Manager, dialect, connection string) (*Database, error) {
|
|
log.L.Println("Setting up database connection.")
|
|
|
|
maxOpenConns := 5
|
|
|
|
var db *gorm.DB
|
|
var err error
|
|
|
|
switch dialect {
|
|
case "sqlite3":
|
|
createFileDir(connection)
|
|
maxOpenConns = 1
|
|
db, err = gorm.Open(sqlite.Open(connection), &gorm.Config{})
|
|
case "mysql":
|
|
db, err = gorm.Open(mysql.Open(connection), &gorm.Config{})
|
|
default:
|
|
message := "Database dialect is not supported"
|
|
return nil, errors.New(message)
|
|
}
|
|
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
sql, err := db.DB()
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
sql.SetMaxOpenConns(maxOpenConns)
|
|
|
|
if dialect == "mysql" {
|
|
sql.SetConnMaxLifetime(9 * time.Minute)
|
|
}
|
|
|
|
err = db.AutoMigrate(&model.User{}, &model.Application{})
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
return &Database{gormdb: db, sqldb: sql, credentialsManager: cm}, nil
|
|
}
|
|
|
|
// Close closes the database connection.
|
|
func (d *Database) Close() {
|
|
err := d.sqldb.Close()
|
|
if err != nil {
|
|
log.L.Printf("Error while closing database: %s", err)
|
|
}
|
|
}
|
|
|
|
// Populate fills the database with initial information like the admin user.
|
|
func (d *Database) Populate(name, password, matrixID string) error {
|
|
log.L.Print("Populating database.")
|
|
|
|
var user model.User
|
|
|
|
query := d.gormdb.Where("name = ?", name).First(&user)
|
|
|
|
if errors.Is(query.Error, gorm.ErrRecordNotFound) {
|
|
user, err := model.NewUser(d.credentialsManager, name, password, true, matrixID)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
|
|
if err := d.gormdb.Create(&user).Error; err != nil {
|
|
return errors.New("user cannot be created")
|
|
}
|
|
} else {
|
|
log.L.Printf("Priviledged user %s already exists.", name)
|
|
}
|
|
|
|
return nil
|
|
}
|
|
|
|
// RepairChannels resets channels that have been modified by a user.
|
|
func (d *Database) RepairChannels(dp Dispatcher, behavior *configuration.RepairBehavior) error {
|
|
log.L.Print("Repairing application channels.")
|
|
|
|
users, err := d.GetUsers()
|
|
if err != nil {
|
|
return err
|
|
}
|
|
|
|
for _, user := range users {
|
|
user := user // See https://stackoverflow.com/a/68247837
|
|
|
|
applications, err := d.GetApplications(&user)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
|
|
for _, application := range applications {
|
|
application := application // See https://stackoverflow.com/a/68247837
|
|
|
|
if err := dp.UpdateApplication(&application, behavior); err != nil {
|
|
return err
|
|
}
|
|
|
|
orphan, err := dp.IsOrphan(&application, &user)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
|
|
if orphan {
|
|
log.L.Printf("Found orphan channel for application %s (ID %d)", application.Name, application.ID)
|
|
|
|
if err = dp.RepairApplication(&application, &user); err != nil {
|
|
log.L.Printf("Unable to repair application %s (ID %d).", application.Name, application.ID)
|
|
log.L.Println(err)
|
|
}
|
|
}
|
|
}
|
|
}
|
|
|
|
return nil
|
|
}
|