From b0699da1e9693a20fa818f05df17236460bc9fa5 Mon Sep 17 00:00:00 2001 From: eikendev Date: Sun, 16 Feb 2025 00:15:50 +0100 Subject: [PATCH] Add nilaway and rework test setup --- Makefile | 2 + cmd/pushbits/main.go | 27 +++- internal/api/alertmanager/handler.go | 4 + internal/api/api_test.go | 100 ++++++++++++ internal/api/application.go | 22 ++- internal/api/application_test.go | 160 +++++--------------- internal/api/context.go | 5 + internal/api/context_test.go | 41 ++--- internal/api/health_test.go | 4 +- internal/api/notification.go | 8 + internal/api/notification_test.go | 11 +- internal/api/user.go | 55 +++++-- internal/api/user_test.go | 45 +++--- internal/api/util_test.go | 6 +- internal/authentication/credentials/hibp.go | 3 + 15 files changed, 309 insertions(+), 184 deletions(-) create mode 100644 internal/api/api_test.go diff --git a/Makefile b/Makefile index 1acd97a..6a3a5f6 100644 --- a/Makefile +++ b/Makefile @@ -29,6 +29,7 @@ test: errcheck -exclude errcheck_excludes.txt ./... gocritic check -disable='#experimental,#opinionated' -@ifElseChain.minThreshold 3 ./... revive -set_exit_status -exclude ./docs ./... + nilaway ./... go test -v -cover ./... gosec -exclude-dir=tests ./... govulncheck ./... @@ -43,6 +44,7 @@ setup: go install github.com/mgechev/revive@latest go install github.com/securego/gosec/v2/cmd/gosec@latest go install github.com/swaggo/swag/cmd/swag@latest + go install go.uber.org/nilaway/cmd/nilaway@latest go install golang.org/x/vuln/cmd/govulncheck@latest go install honnef.co/go/tools/cmd/staticcheck@latest go install mvdan.cc/gofumpt@latest diff --git a/cmd/pushbits/main.go b/cmd/pushbits/main.go index 290ffa5..db313b8 100644 --- a/cmd/pushbits/main.go +++ b/cmd/pushbits/main.go @@ -29,6 +29,14 @@ func setupCleanup(db *database.Database, dp *dispatcher.Dispatcher) { }() } +func printStarupMessage() { + if len(version) == 0 { + log.L.Panic("Version not set") + } else { + log.L.Printf("Starting PushBits %s", version) + } +} + // @title PushBits Server API Documentation // @version 0.10.5 // @description Documentation for the PushBits server API. @@ -45,11 +53,7 @@ func setupCleanup(db *database.Database, dp *dispatcher.Dispatcher) { // @securityDefinitions.basic BasicAuth func main() { - if len(version) == 0 { - log.L.Panic("Version not set") - } else { - log.L.Printf("Starting PushBits %s", version) - } + printStarupMessage() c := configuration.Get() @@ -63,6 +67,11 @@ func main() { db, err := database.Create(cm, c.Database.Dialect, c.Database.Connection) if err != nil { log.L.Fatal(err) + return + } + if db == nil { + log.L.Fatal("db is nil but error was nil") + return } defer db.Close() @@ -73,6 +82,11 @@ func main() { dp, err := dispatcher.Create(c.Matrix.Homeserver, c.Matrix.Username, c.Matrix.Password, c.Formatting) if err != nil { log.L.Fatal(err) + return + } + if dp == nil { + log.L.Fatal("dp is nil but error was nil") + return } defer dp.Close() @@ -81,15 +95,18 @@ func main() { err = db.RepairChannels(dp, &c.RepairBehavior) if err != nil { log.L.Fatal(err) + return } engine, err := router.Create(c.Debug, c.HTTP.TrustedProxies, cm, db, dp, &c.Alertmanager) if err != nil { log.L.Fatal(err) + return } err = runner.Run(engine, c) if err != nil { log.L.Fatal(err) + return } } diff --git a/internal/api/alertmanager/handler.go b/internal/api/alertmanager/handler.go index bfc45f0..337a5ad 100644 --- a/internal/api/alertmanager/handler.go +++ b/internal/api/alertmanager/handler.go @@ -38,6 +38,10 @@ type HandlerSettings struct { // @Router /alert [post] func (h *Handler) CreateAlert(ctx *gin.Context) { application := authentication.GetApplication(ctx) + if application == nil { + return + } + log.L.Printf("Sending alert notification for application %s.", application.Name) var hook model.AlertmanagerWebhook diff --git a/internal/api/api_test.go b/internal/api/api_test.go new file mode 100644 index 0000000..7d8d6fe --- /dev/null +++ b/internal/api/api_test.go @@ -0,0 +1,100 @@ +package api + +import ( + "fmt" + "os" + "testing" + + "github.com/gin-gonic/gin" + "github.com/pushbits/server/internal/authentication/credentials" + "github.com/pushbits/server/internal/configuration" + "github.com/pushbits/server/internal/database" + "github.com/pushbits/server/internal/log" + "github.com/pushbits/server/internal/model" + "github.com/pushbits/server/tests/mockups" +) + +// TestContext holds all test-related objects +type TestContext struct { + ApplicationHandler *ApplicationHandler + Users []*model.User + Database *database.Database + NotificationHandler *NotificationHandler + UserHandler *UserHandler + Config *configuration.Configuration +} + +var GlobalTestContext *TestContext + +func cleanup() { + err := os.Remove("pushbits-test.db") + if err != nil { + log.L.Warnln("Cannot delete test database: ", err) + } +} + +func TestMain(m *testing.M) { + cleanup() + + gin.SetMode(gin.TestMode) + + GlobalTestContext = CreateTestContext(nil) + + m.Run() + + cleanup() +} + +// GetTestContext initializes and verifies all required test components +func GetTestContext(_ *testing.T) *TestContext { + if GlobalTestContext == nil { + GlobalTestContext = CreateTestContext(nil) + } + + return GlobalTestContext +} + +// CreateTestContext initializes and verifies all required test components +func CreateTestContext(_ *testing.T) *TestContext { + ctx := &TestContext{} + + config := configuration.Configuration{} + config.Database.Connection = "pushbits-test.db" + config.Database.Dialect = "sqlite3" + config.Crypto.Argon2.Iterations = 4 + config.Crypto.Argon2.Parallelism = 4 + config.Crypto.Argon2.Memory = 131072 + config.Crypto.Argon2.SaltLength = 16 + config.Crypto.Argon2.KeyLength = 32 + config.Admin.Name = "user" + config.Admin.Password = "pushbits" + ctx.Config = &config + + db, err := mockups.GetEmptyDatabase(ctx.Config.Crypto) + if err != nil { + cleanup() + panic(fmt.Errorf("cannot set up database: %w", err)) + } + ctx.Database = db + + ctx.ApplicationHandler = &ApplicationHandler{ + DB: ctx.Database, + DP: &mockups.MockDispatcher{}, + } + + ctx.Users = mockups.GetUsers(ctx.Config) + + ctx.NotificationHandler = &NotificationHandler{ + DB: ctx.Database, + DP: &mockups.MockDispatcher{}, + } + + ctx.UserHandler = &UserHandler{ + AH: ctx.ApplicationHandler, + CM: credentials.CreateManager(false, ctx.Config.Crypto), + DB: ctx.Database, + DP: &mockups.MockDispatcher{}, + } + + return ctx +} diff --git a/internal/api/application.go b/internal/api/application.go index 774bb01..959b941 100644 --- a/internal/api/application.go +++ b/internal/api/application.go @@ -28,6 +28,10 @@ func (h *ApplicationHandler) generateToken(compat bool) string { } func (h *ApplicationHandler) registerApplication(ctx *gin.Context, a *model.Application, u *model.User) error { + if a == nil || u == nil { + return errors.New("nil parameters provided") + } + log.L.Printf("Registering application %s.", a.Name) channelID, err := h.DP.RegisterApplication(a.ID, a.Name, u.MatrixID) @@ -46,6 +50,10 @@ func (h *ApplicationHandler) registerApplication(ctx *gin.Context, a *model.Appl } func (h *ApplicationHandler) createApplication(ctx *gin.Context, u *model.User, name string, compat bool) (*model.Application, error) { + if u == nil { + return nil, errors.New("nil parameters provided") + } + log.L.Printf("Creating application %s.", name) application := model.Application{} @@ -71,6 +79,10 @@ func (h *ApplicationHandler) createApplication(ctx *gin.Context, u *model.User, } func (h *ApplicationHandler) deleteApplication(ctx *gin.Context, a *model.Application, u *model.User) error { + if a == nil || u == nil { + return errors.New("nil parameters provided") + } + log.L.Printf("Deleting application %s (ID %d).", a.Name, a.ID) err := h.DP.DeregisterApplication(a, u) @@ -87,6 +99,10 @@ func (h *ApplicationHandler) deleteApplication(ctx *gin.Context, a *model.Applic } func (h *ApplicationHandler) updateApplication(ctx *gin.Context, a *model.Application, updateApplication *model.UpdateApplication) error { + if a == nil || updateApplication == nil { + return errors.New("nil parameters provided") + } + log.L.Printf("Updating application %s (ID %d).", a.Name, a.ID) if updateApplication.Name != nil { @@ -186,7 +202,7 @@ func (h *ApplicationHandler) GetApplications(ctx *gin.Context) { // @Router /application/{id} [get] func (h *ApplicationHandler) GetApplication(ctx *gin.Context) { application, err := getApplication(ctx, h.DB) - if err != nil { + if err != nil || application == nil { return } @@ -218,7 +234,7 @@ func (h *ApplicationHandler) GetApplication(ctx *gin.Context) { // @Router /application/{id} [delete] func (h *ApplicationHandler) DeleteApplication(ctx *gin.Context) { application, err := getApplication(ctx, h.DB) - if err != nil { + if err != nil || application == nil { return } @@ -250,7 +266,7 @@ func (h *ApplicationHandler) DeleteApplication(ctx *gin.Context) { // @Router /application/{id} [put] func (h *ApplicationHandler) UpdateApplication(ctx *gin.Context) { application, err := getApplication(ctx, h.DB) - if err != nil { + if err != nil || application == nil { return } diff --git a/internal/api/application_test.go b/internal/api/application_test.go index 600b6e8..1dd101c 100644 --- a/internal/api/application_test.go +++ b/internal/api/application_test.go @@ -4,91 +4,21 @@ import ( "encoding/json" "fmt" "io" - "os" "testing" - "github.com/gin-gonic/gin" - "github.com/pushbits/server/internal/authentication/credentials" - "github.com/pushbits/server/internal/configuration" - "github.com/pushbits/server/internal/database" - "github.com/pushbits/server/internal/log" "github.com/pushbits/server/internal/model" "github.com/pushbits/server/tests" - "github.com/pushbits/server/tests/mockups" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" ) -var ( - TestApplicationHandler *ApplicationHandler - TestUsers []*model.User - TestDatabase *database.Database - TestNotificationHandler *NotificationHandler - TestUserHandler *UserHandler - TestConfig *configuration.Configuration -) - // Collect all created applications to check & delete them later -var SuccessAplications map[uint][]model.Application - -func TestMain(m *testing.M) { - cleanUp() - // Get main config and adapt - config := &configuration.Configuration{} - - config.Database.Connection = "pushbits-test.db" - config.Database.Dialect = "sqlite3" - config.Crypto.Argon2.Iterations = 4 - config.Crypto.Argon2.Parallelism = 4 - config.Crypto.Argon2.Memory = 131072 - config.Crypto.Argon2.SaltLength = 16 - config.Crypto.Argon2.KeyLength = 32 - config.Admin.Name = "user" - config.Admin.Password = "pushbits" - - TestConfig = config - - // Set up test environment - db, err := mockups.GetEmptyDatabase(config.Crypto) - if err != nil { - cleanUp() - log.L.Println("Cannot set up database: ", err) - os.Exit(1) - } - TestDatabase = db - - appHandler, err := getApplicationHandler(config) - if err != nil { - cleanUp() - log.L.Println("Cannot set up application handler: ", err) - os.Exit(1) - } - - TestApplicationHandler = appHandler - TestUsers = mockups.GetUsers(config) - SuccessAplications = make(map[uint][]model.Application) - - TestNotificationHandler = &NotificationHandler{ - DB: TestDatabase, - DP: &mockups.MockDispatcher{}, - } - - TestUserHandler = &UserHandler{ - AH: TestApplicationHandler, - CM: credentials.CreateManager(false, config.Crypto), - DB: TestDatabase, - DP: &mockups.MockDispatcher{}, - } - - // Run - m.Run() - log.L.Println("Clean up after Test") - cleanUp() -} +var SuccessApplications = make(map[uint][]model.Application) func TestApi_RegisterApplicationWithoutUser(t *testing.T) { + ctx := GetTestContext(t) + assert := assert.New(t) - gin.SetMode(gin.TestMode) reqWoUser := tests.Request{Name: "Invalid JSON Data", Method: "POST", Endpoint: "/application", Data: `{"name": "test1", "strict_compatibility": true}`, Headers: map[string]string{"Content-Type": "application/json"}} _, c, err := reqWoUser.GetRequest() @@ -96,21 +26,22 @@ func TestApi_RegisterApplicationWithoutUser(t *testing.T) { t.Fatal(err.Error()) } - assert.Panicsf(func() { TestApplicationHandler.CreateApplication(c) }, "CreateApplication did not panic although user is not in context") + assert.Panicsf(func() { ctx.ApplicationHandler.CreateApplication(c) }, "CreateApplication did not panic although user is not in context") } func TestApi_RegisterApplication(t *testing.T) { + ctx := GetTestContext(t) + assert := assert.New(t) require := require.New(t) - gin.SetMode(gin.TestMode) testCases := make([]tests.Request, 0) testCases = append(testCases, tests.Request{Name: "Invalid Form Data", Method: "POST", Endpoint: "/application", Data: "k=1&v=abc", ShouldStatus: 400}) testCases = append(testCases, tests.Request{Name: "Invalid JSON Data", Method: "POST", Endpoint: "/application", Data: `{"name": "test1", "strict_compatibility": "oh yes"}`, Headers: map[string]string{"Content-Type": "application/json"}, ShouldStatus: 400}) testCases = append(testCases, tests.Request{Name: "Valid JSON Data", Method: "POST", Endpoint: "/application", Data: `{"name": "test2", "strict_compatibility": true}`, Headers: map[string]string{"Content-Type": "application/json"}, ShouldStatus: 200}) - for _, user := range TestUsers { - SuccessAplications[user.ID] = make([]model.Application, 0) + for _, user := range ctx.Users { + SuccessApplications[user.ID] = make([]model.Application, 0) for _, req := range testCases { var application model.Application w, c, err := req.GetRequest() @@ -119,7 +50,7 @@ func TestApi_RegisterApplication(t *testing.T) { } c.Set("user", user) - TestApplicationHandler.CreateApplication(c) + ctx.ApplicationHandler.CreateApplication(c) // Parse body only for successful requests if req.ShouldStatus >= 200 && req.ShouldStatus < 300 { @@ -128,7 +59,7 @@ func TestApi_RegisterApplication(t *testing.T) { err = json.Unmarshal(body, &application) require.NoErrorf(err, "Cannot unmarshal request body") - SuccessAplications[user.ID] = append(SuccessAplications[user.ID], application) + SuccessApplications[user.ID] = append(SuccessApplications[user.ID], application) } assert.Equalf(w.Code, req.ShouldStatus, "CreateApplication (Test case: \"%s\") Expected status code %v but received %v.", req.Name, req.ShouldStatus, w.Code) @@ -137,16 +68,17 @@ func TestApi_RegisterApplication(t *testing.T) { } func TestApi_GetApplications(t *testing.T) { + ctx := GetTestContext(t) + var applications []model.Application assert := assert.New(t) require := require.New(t) - gin.SetMode(gin.TestMode) testCases := make([]tests.Request, 0) testCases = append(testCases, tests.Request{Name: "Valid Request", Method: "GET", Endpoint: "/application", ShouldStatus: 200}) - for _, user := range TestUsers { + for _, user := range ctx.Users { for _, req := range testCases { w, c, err := req.GetRequest() if err != nil { @@ -154,7 +86,7 @@ func TestApi_GetApplications(t *testing.T) { } c.Set("user", user) - TestApplicationHandler.GetApplications(c) + ctx.ApplicationHandler.GetApplications(c) // Parse body only for successful requests if req.ShouldStatus >= 200 && req.ShouldStatus < 300 { @@ -167,7 +99,7 @@ func TestApi_GetApplications(t *testing.T) { } assert.Truef(validateAllApplications(user, applications), "Did not find application created previously") - assert.Equalf(len(applications), len(SuccessAplications[user.ID]), "Created %d application(s) but got %d back", len(SuccessAplications[user.ID]), len(applications)) + assert.Equalf(len(applications), len(SuccessApplications[user.ID]), "Created %d application(s) but got %d back", len(SuccessApplications[user.ID]), len(applications)) } assert.Equalf(w.Code, req.ShouldStatus, "GetApplications (Test case: \"%s\") Expected status code %v but received %v.", req.Name, req.ShouldStatus, w.Code) @@ -176,8 +108,9 @@ func TestApi_GetApplications(t *testing.T) { } func TestApi_GetApplicationsWithoutUser(t *testing.T) { + ctx := GetTestContext(t) + assert := assert.New(t) - gin.SetMode(gin.TestMode) testCase := tests.Request{Name: "Valid Request", Method: "GET", Endpoint: "/application"} @@ -186,12 +119,13 @@ func TestApi_GetApplicationsWithoutUser(t *testing.T) { t.Fatal(err.Error()) } - assert.Panicsf(func() { TestApplicationHandler.GetApplications(c) }, "GetApplications did not panic although user is not in context") + assert.Panicsf(func() { ctx.ApplicationHandler.GetApplications(c) }, "GetApplications did not panic although user is not in context") } func TestApi_GetApplicationErrors(t *testing.T) { + ctx := GetTestContext(t) + assert := assert.New(t) - gin.SetMode(gin.TestMode) // Arbitrary test cases testCases := make(map[uint]tests.Request) @@ -199,7 +133,7 @@ func TestApi_GetApplicationErrors(t *testing.T) { testCases[5555] = tests.Request{Name: "Requesting unknown application 5555", Method: "GET", Endpoint: "/application/5555", ShouldStatus: 404} testCases[99999999999999999] = tests.Request{Name: "Requesting unknown application 99999999999999999", Method: "GET", Endpoint: "/application/99999999999999999", ShouldStatus: 404} - for _, user := range TestUsers { + for _, user := range ctx.Users { for id, req := range testCases { w, c, err := req.GetRequest() if err != nil { @@ -208,7 +142,7 @@ func TestApi_GetApplicationErrors(t *testing.T) { c.Set("user", user) c.Set("id", id) - TestApplicationHandler.GetApplication(c) + ctx.ApplicationHandler.GetApplication(c) assert.Equalf(w.Code, req.ShouldStatus, "GetApplication (Test case: \"%s\") Expected status code %v but have %v.", req.Name, req.ShouldStatus, w.Code) } @@ -216,15 +150,16 @@ func TestApi_GetApplicationErrors(t *testing.T) { } func TestApi_GetApplication(t *testing.T) { + ctx := GetTestContext(t) + var application model.Application assert := assert.New(t) require := require.New(t) - gin.SetMode(gin.TestMode) // Previously generated applications - for _, user := range TestUsers { - for _, app := range SuccessAplications[user.ID] { + for _, user := range ctx.Users { + for _, app := range SuccessApplications[user.ID] { req := tests.Request{Name: fmt.Sprintf("Requesting application %s (%d)", app.Name, app.ID), Method: "GET", Endpoint: fmt.Sprintf("/application/%d", app.ID), ShouldStatus: 200} w, c, err := req.GetRequest() @@ -234,7 +169,7 @@ func TestApi_GetApplication(t *testing.T) { c.Set("user", user) c.Set("id", app.ID) - TestApplicationHandler.GetApplication(c) + ctx.ApplicationHandler.GetApplication(c) // Parse body only for successful requests if req.ShouldStatus >= 200 && req.ShouldStatus < 300 { @@ -255,14 +190,15 @@ func TestApi_GetApplication(t *testing.T) { } func TestApi_UpdateApplication(t *testing.T) { + ctx := GetTestContext(t) + assert := assert.New(t) require := require.New(t) - gin.SetMode(gin.TestMode) - for _, user := range TestUsers { + for _, user := range ctx.Users { testCases := make(map[uint]tests.Request) // Previously generated applications - for _, app := range SuccessAplications[user.ID] { + for _, app := range SuccessApplications[user.ID] { newName := app.Name + "-new_name" updateApp := model.UpdateApplication{ Name: &newName, @@ -287,7 +223,7 @@ func TestApi_UpdateApplication(t *testing.T) { c.Set("user", user) c.Set("id", id) - TestApplicationHandler.UpdateApplication(c) + ctx.ApplicationHandler.UpdateApplication(c) assert.Equalf(w.Code, req.ShouldStatus, "UpdateApplication (Test case: \"%s\") Expected status code %v but have %v.", req.Name, req.ShouldStatus, w.Code) } @@ -295,13 +231,14 @@ func TestApi_UpdateApplication(t *testing.T) { } func TestApi_DeleteApplication(t *testing.T) { - assert := assert.New(t) - gin.SetMode(gin.TestMode) + ctx := GetTestContext(t) - for _, user := range TestUsers { + assert := assert.New(t) + + for _, user := range ctx.Users { testCases := make(map[uint]tests.Request) // Previously generated applications - for _, app := range SuccessAplications[user.ID] { + for _, app := range SuccessApplications[user.ID] { testCases[app.ID] = tests.Request{Name: fmt.Sprintf("Delete application %s (%d)", app.Name, app.ID), Method: "DELETE", Endpoint: fmt.Sprintf("/application/%d", app.ID), ShouldStatus: 200} } // Arbitrary test cases @@ -315,30 +252,20 @@ func TestApi_DeleteApplication(t *testing.T) { c.Set("user", user) c.Set("id", id) - TestApplicationHandler.DeleteApplication(c) + ctx.ApplicationHandler.DeleteApplication(c) assert.Equalf(w.Code, req.ShouldStatus, "DeleteApplication (Test case: \"%s\") Expected status code %v but have %v.", req.Name, req.ShouldStatus, w.Code) } } } -// GetApplicationHandler creates and returns an application handler -func getApplicationHandler(_ *configuration.Configuration) (*ApplicationHandler, error) { - dispatcher := &mockups.MockDispatcher{} - - return &ApplicationHandler{ - DB: TestDatabase, - DP: dispatcher, - }, nil -} - // True if all created applications are in list func validateAllApplications(user *model.User, apps []model.Application) bool { - if _, ok := SuccessAplications[user.ID]; !ok { + if _, ok := SuccessApplications[user.ID]; !ok { return len(apps) == 0 } - for _, successApp := range SuccessAplications[user.ID] { + for _, successApp := range SuccessApplications[user.ID] { foundApp := false for _, app := range apps { if app.ID == successApp.ID { @@ -354,10 +281,3 @@ func validateAllApplications(user *model.User, apps []model.Application) bool { return true } - -func cleanUp() { - err := os.Remove("pushbits-test.db") - if err != nil { - log.L.Warnln("Cannot delete test database: ", err) - } -} diff --git a/internal/api/context.go b/internal/api/context.go index 34fe836..59d42ce 100644 --- a/internal/api/context.go +++ b/internal/api/context.go @@ -41,6 +41,11 @@ func getApplication(ctx *gin.Context, db Database) (*model.Application, error) { if success := SuccessOrAbort(ctx, http.StatusNotFound, err); !success { return nil, err } + if application == nil { + err := errors.New("application not found") + ctx.AbortWithError(http.StatusNotFound, err) + return nil, err + } return application, nil } diff --git a/internal/api/context_test.go b/internal/api/context_test.go index 7b31007..a7a077a 100644 --- a/internal/api/context_test.go +++ b/internal/api/context_test.go @@ -3,7 +3,6 @@ package api import ( "testing" - "github.com/gin-gonic/gin" "github.com/pushbits/server/internal/log" "github.com/pushbits/server/internal/model" "github.com/pushbits/server/tests" @@ -13,9 +12,10 @@ import ( ) func TestApi_getID(t *testing.T) { + GetTestContext(t) + assert := assert.New(t) require := require.New(t) - gin.SetMode(gin.TestMode) testValue := uint(1337) testCases := make(map[interface{}]tests.Request) @@ -54,13 +54,14 @@ func TestApi_getID(t *testing.T) { } func TestApi_getApplication(t *testing.T) { + ctx := GetTestContext(t) + assert := assert.New(t) require := require.New(t) - gin.SetMode(gin.TestMode) applications := mockups.GetAllApplications() - err := mockups.AddApplicationsToDb(TestDatabase, applications) + err := mockups.AddApplicationsToDb(ctx.Database, applications) if err != nil { log.L.Fatalln("Cannot add mock applications to database: ", err) } @@ -78,26 +79,28 @@ func TestApi_getApplication(t *testing.T) { } c.Set("id", id) - app, err := getApplication(c, TestDatabase) + app, err := getApplication(c, ctx.Database) if req.ShouldStatus >= 200 && req.ShouldStatus < 300 { - require.NoErrorf(err, "getApplication with id %v (%t) returned an error although it should not: %v", id, id, err) - assert.Equalf(app.ID, id, "getApplication id was set to %d but resulting app id is %d", id, app.ID) + require.NoErrorf(err, "getApplication with id %v returned an unexpected error: %v", id, err) + require.NotNilf(app, "Expected a valid app for id %v, but got nil", id) + assert.Equalf(app.ID, id, "Expected app ID %d, but got %d", id, app.ID) } else { - assert.Errorf(err, "getApplication with id %v (%t) returned no error although it should", id, id) + require.Errorf(err, "Expected an error for id %v, but got none", id) + assert.Nilf(app, "Expected app to be nil for id %v, but got %+v", id, app) } - assert.Equalf(w.Code, req.ShouldStatus, "getApplication id was set to %v (%T) and should result in status code %d but code is %d", id, id, req.ShouldStatus, w.Code) - + assert.Equalf(w.Code, req.ShouldStatus, "Expected status code %d for id %v, but got %d", req.ShouldStatus, id, w.Code) } } func TestApi_getUser(t *testing.T) { + ctx := GetTestContext(t) + assert := assert.New(t) require := require.New(t) - gin.SetMode(gin.TestMode) - _, err := mockups.AddUsersToDb(TestDatabase, TestUsers) + _, err := mockups.AddUsersToDb(ctx.Database, ctx.Users) assert.NoErrorf(err, "Adding users to database failed: %v", err) // No testing of invalid ids as that is tested in TestApi_getID already @@ -113,16 +116,18 @@ func TestApi_getUser(t *testing.T) { } c.Set("id", id) - user, err := getUser(c, TestDatabase) + user, err := getUser(c, ctx.Database) if req.ShouldStatus >= 200 && req.ShouldStatus < 300 { - require.NoErrorf(err, "getUser with id %v (%t) returned an error although it should not: %v", id, id, err) - assert.Equalf(user.ID, id, "getUser id was set to %d but resulting app id is %d", id, user.ID) - + require.NoErrorf(err, "getUser with id %v returned an unexpected error: %v", id, err) + require.NotNilf(user, "Expected a valid user for id %v, but got nil", id) + assert.Equalf(user.ID, id, "Expected user ID %d, but got %d", id, user.ID) } else { - assert.Errorf(err, "getUser with id %v (%t) returned no error although it should", id, id) + require.Errorf(err, "Expected an error for id %v, but got none", id) + assert.Nilf(user, "Expected user to be nil for id %v, but got %+v", id, user) } - assert.Equalf(w.Code, req.ShouldStatus, "getUser id was set to %v (%T) and should result in status code %d but code is %d", id, id, req.ShouldStatus, w.Code) + assert.Equalf(w.Code, req.ShouldStatus, "Expected status code %d for id %v, but got %d", req.ShouldStatus, id, w.Code) + } } diff --git a/internal/api/health_test.go b/internal/api/health_test.go index 3a6e4f0..1b147d7 100644 --- a/internal/api/health_test.go +++ b/internal/api/health_test.go @@ -8,9 +8,11 @@ import ( ) func TestApi_Health(t *testing.T) { + ctx := GetTestContext(t) + assert := assert.New(t) handler := HealthHandler{ - DB: TestDatabase, + DB: ctx.Database, } testCases := make([]tests.Request, 0) diff --git a/internal/api/notification.go b/internal/api/notification.go index 53ca239..d8ba341 100644 --- a/internal/api/notification.go +++ b/internal/api/notification.go @@ -44,6 +44,10 @@ type NotificationHandler struct { // @Router /message [post] func (h *NotificationHandler) CreateNotification(ctx *gin.Context) { application := authentication.GetApplication(ctx) + if application == nil { + return + } + log.L.Printf("Sending notification for application %s.", application.Name) var notification model.Notification @@ -78,6 +82,10 @@ func (h *NotificationHandler) CreateNotification(ctx *gin.Context) { // @Router /message/{message_id} [DELETE] func (h *NotificationHandler) DeleteNotification(ctx *gin.Context) { application := authentication.GetApplication(ctx) + if application == nil { + return + } + log.L.Printf("Deleting notification for application %s.", application.Name) id, err := getMessageID(ctx) diff --git a/internal/api/notification_test.go b/internal/api/notification_test.go index 06e29b1..71f4182 100644 --- a/internal/api/notification_test.go +++ b/internal/api/notification_test.go @@ -5,7 +5,6 @@ import ( "io" "testing" - "github.com/gin-gonic/gin" "github.com/pushbits/server/internal/model" "github.com/pushbits/server/tests" "github.com/stretchr/testify/assert" @@ -13,9 +12,10 @@ import ( ) func TestApi_CreateNotification(t *testing.T) { + ctx := GetTestContext(t) + assert := assert.New(t) require := require.New(t) - gin.SetMode(gin.TestMode) testApplication := model.Application{ ID: 1, @@ -40,7 +40,7 @@ func TestApi_CreateNotification(t *testing.T) { } c.Set("app", &testApplication) - TestNotificationHandler.CreateNotification(c) + ctx.NotificationHandler.CreateNotification(c) // Parse body only for successful requests if req.ShouldStatus >= 200 && req.ShouldStatus < 300 { @@ -64,8 +64,9 @@ func TestApi_CreateNotification(t *testing.T) { } func TestApi_DeleteNotification(t *testing.T) { + ctx := GetTestContext(t) + assert := assert.New(t) - gin.SetMode(gin.TestMode) testApplication := model.Application{ ID: 1, @@ -88,7 +89,7 @@ func TestApi_DeleteNotification(t *testing.T) { c.Set("app", &testApplication) c.Set("messageid", id) - TestNotificationHandler.DeleteNotification(c) + ctx.NotificationHandler.DeleteNotification(c) assert.Equalf(w.Code, req.ShouldStatus, "(Test case: \"%s\") Expected status code %v but have %v.", req.Name, req.ShouldStatus, w.Code) } diff --git a/internal/api/user.go b/internal/api/user.go index 009a4f4..b9e9a7a 100644 --- a/internal/api/user.go +++ b/internal/api/user.go @@ -38,6 +38,10 @@ func (h *UserHandler) requireMultipleAdmins(ctx *gin.Context) error { } func (h *UserHandler) deleteApplications(ctx *gin.Context, u *model.User) error { + if ctx == nil || u == nil { + return errors.New("nil parameters provided") + } + applications, err := h.DB.GetApplications(u) if success := SuccessOrAbort(ctx, http.StatusInternalServerError, err); !success { return err @@ -55,6 +59,10 @@ func (h *UserHandler) deleteApplications(ctx *gin.Context, u *model.User) error } func (h *UserHandler) updateChannels(ctx *gin.Context, u *model.User, matrixID string) error { + if ctx == nil || u == nil { + return errors.New("nil parameters provided") + } + applications, err := h.DB.GetApplications(u) if success := SuccessOrAbort(ctx, http.StatusInternalServerError, err); !success { return err @@ -83,15 +91,7 @@ func (h *UserHandler) updateChannels(ctx *gin.Context, u *model.User, matrixID s return nil } -func (h *UserHandler) updateUser(ctx *gin.Context, u *model.User, updateUser model.UpdateUser) error { - if updateUser.MatrixID != nil && u.MatrixID != *updateUser.MatrixID { - if err := h.updateChannels(ctx, u, *updateUser.MatrixID); err != nil { - return err - } - } - - log.L.Printf("Updating user %s.", u.Name) - +func (h *UserHandler) updateUserFields(ctx *gin.Context, u *model.User, updateUser model.UpdateUser) error { if updateUser.Name != nil { u.Name = *updateUser.Name } @@ -100,7 +100,6 @@ func (h *UserHandler) updateUser(ctx *gin.Context, u *model.User, updateUser mod if success := SuccessOrAbort(ctx, http.StatusBadRequest, err); !success { return err } - u.PasswordHash = hash } if updateUser.MatrixID != nil { @@ -109,6 +108,25 @@ func (h *UserHandler) updateUser(ctx *gin.Context, u *model.User, updateUser mod if updateUser.IsAdmin != nil { u.IsAdmin = *updateUser.IsAdmin } + return nil +} + +func (h *UserHandler) updateUser(ctx *gin.Context, u *model.User, updateUser model.UpdateUser) error { + if u == nil { + return errors.New("nil parameters provided") + } + + if updateUser.MatrixID != nil && u.MatrixID != *updateUser.MatrixID { + if err := h.updateChannels(ctx, u, *updateUser.MatrixID); err != nil { + return err + } + } + + log.L.Printf("Updating user %s.", u.Name) + + if err := h.updateUserFields(ctx, u, updateUser); err != nil { + return err + } err := h.DB.UpdateUser(u) if success := SuccessOrAbort(ctx, http.StatusInternalServerError, err); !success { @@ -149,10 +167,12 @@ func (h *UserHandler) CreateUser(ctx *gin.Context) { log.L.Printf("Creating user %s.", createUser.Name) user, err := h.DB.CreateUser(createUser) - if success := SuccessOrAbort(ctx, http.StatusInternalServerError, err); !success { return } + if user == nil { + return + } ctx.JSON(http.StatusOK, user.IntoExternalUser()) } @@ -170,6 +190,10 @@ func (h *UserHandler) CreateUser(ctx *gin.Context) { // @Security BasicAuth // @Router /user [get] func (h *UserHandler) GetUsers(ctx *gin.Context) { + if ctx == nil { + return + } + users, err := h.DB.GetUsers() if success := SuccessOrAbort(ctx, http.StatusInternalServerError, err); !success { return @@ -199,7 +223,7 @@ func (h *UserHandler) GetUsers(ctx *gin.Context) { // @Router /user/{id} [get] func (h *UserHandler) GetUser(ctx *gin.Context) { user, err := getUser(ctx, h.DB) - if err != nil { + if err != nil || user == nil { return } @@ -221,7 +245,7 @@ func (h *UserHandler) GetUser(ctx *gin.Context) { // @Router /user/{id} [delete] func (h *UserHandler) DeleteUser(ctx *gin.Context) { user, err := getUser(ctx, h.DB) - if err != nil { + if err != nil || user == nil { return } @@ -265,7 +289,7 @@ func (h *UserHandler) DeleteUser(ctx *gin.Context) { // @Router /user/{id} [put] func (h *UserHandler) UpdateUser(ctx *gin.Context) { user, err := getUser(ctx, h.DB) - if err != nil { + if err != nil || user == nil { return } @@ -275,6 +299,9 @@ func (h *UserHandler) UpdateUser(ctx *gin.Context) { } requestingUser := authentication.GetUser(ctx) + if requestingUser == nil { + return + } // Last privileged user must not be taken privileges. Assumes that the current user has privileges. if user.ID == requestingUser.ID && updateUser.IsAdmin != nil && !(*updateUser.IsAdmin) { diff --git a/internal/api/user_test.go b/internal/api/user_test.go index 4f06738..9d65690 100644 --- a/internal/api/user_test.go +++ b/internal/api/user_test.go @@ -13,17 +13,19 @@ import ( ) func TestApi_CreateUser(t *testing.T) { + ctx := GetTestContext(t) + assert := assert.New(t) testCases := make([]tests.Request, 0) // Add all test users - for _, user := range TestUsers { + for _, user := range ctx.Users { createUser := &model.CreateUser{} createUser.ExternalUser.Name = user.Name createUser.ExternalUser.MatrixID = "@" + user.Name + ":matrix.org" createUser.ExternalUser.IsAdmin = user.IsAdmin - createUser.UserCredentials.Password = TestConfig.Admin.Password + createUser.UserCredentials.Password = ctx.Config.Admin.Password testCase := tests.Request{ Name: "Already existing user " + user.Name, @@ -48,13 +50,15 @@ func TestApi_CreateUser(t *testing.T) { t.Fatal(err.Error()) } - TestUserHandler.CreateUser(c) + ctx.UserHandler.CreateUser(c) assert.Equalf(w.Code, req.ShouldStatus, "(Test case: \"%s\") Expected status code %v but have %v.", req.Name, req.ShouldStatus, w.Code) } } func TestApi_GetUsers(t *testing.T) { + ctx := GetTestContext(t) + assert := assert.New(t) require := require.New(t) @@ -65,10 +69,10 @@ func TestApi_GetUsers(t *testing.T) { w, c, err := request.GetRequest() if err != nil { - t.Fatalf((err.Error())) + t.Fatalf("error getting request: %v", err) } - TestUserHandler.GetUsers(c) + ctx.UserHandler.GetUsers(c) assert.Equalf(w.Code, 200, "Response code should be 200 but is %d", w.Code) // Get users from body @@ -79,7 +83,7 @@ func TestApi_GetUsers(t *testing.T) { require.NoErrorf(err, "Can not unmarshal users") // Check existence of all known users - for _, user := range TestUsers { + for _, user := range ctx.Users { found := false for _, userExt := range users { if user.ID == userExt.ID && user.Name == userExt.Name { @@ -92,13 +96,15 @@ func TestApi_GetUsers(t *testing.T) { } func TestApi_UpdateUser(t *testing.T) { - assert := assert.New(t) - admin := getAdmin() + ctx := GetTestContext(t) + assert := assert.New(t) + + admin := getAdmin(ctx) testCases := make(map[uint]tests.Request) // Add all test users - for _, user := range TestUsers { + for _, user := range ctx.Users { updateUser := &model.UpdateUser{} user.Name += "+1" user.IsAdmin = !user.IsAdmin @@ -125,7 +131,7 @@ func TestApi_UpdateUser(t *testing.T) { c.Set("id", id) c.Set("user", admin) - TestUserHandler.UpdateUser(c) + ctx.UserHandler.UpdateUser(c) assert.Equalf(w.Code, req.ShouldStatus, "(Test case: \"%s\") Expected status code %v but have %v.", req.Name, req.ShouldStatus, w.Code) } @@ -138,11 +144,13 @@ func TestApi_UpdateUser(t *testing.T) { } c.Set("id", id) - assert.Panicsf(func() { TestUserHandler.UpdateUser(c) }, "User not set should panic but did not") + assert.Panicsf(func() { ctx.UserHandler.UpdateUser(c) }, "User not set should panic but did not") } } func TestApi_GetUser(t *testing.T) { + ctx := GetTestContext(t) + assert := assert.New(t) require := require.New(t) @@ -151,7 +159,7 @@ func TestApi_GetUser(t *testing.T) { testCases[uint(9999999)] = tests.Request{Name: "Unknown id", Method: "GET", Endpoint: "/user/99999999", ShouldStatus: 404} // Check if we can get all existing users - for _, user := range TestUsers { + for _, user := range ctx.Users { testCases[user.ID] = tests.Request{ Name: "Valid user " + user.Name, Method: "GET", @@ -166,7 +174,7 @@ func TestApi_GetUser(t *testing.T) { require.NoErrorf(err, "(Test case %s) Could not make request", testCase.Name) c.Set("id", id) - TestUserHandler.GetUser(c) + ctx.UserHandler.GetUser(c) assert.Equalf(testCase.ShouldStatus, w.Code, "(Test case %s) Expected status code %d but have %d", testCase.Name, testCase.ShouldStatus, w.Code) @@ -191,13 +199,16 @@ func TestApi_GetUser(t *testing.T) { } func TestApi_DeleteUser(t *testing.T) { + ctx := GetTestContext(t) + assert := assert.New(t) require := require.New(t) + testCases := make(map[interface{}]tests.Request) testCases["abcde"] = tests.Request{Name: "Invalid user - string", Method: "DELETE", Endpoint: "/user/abcde", ShouldStatus: 500} testCases[uint(999999)] = tests.Request{Name: "Unknown user", Method: "DELETE", Endpoint: "/user/999999", ShouldStatus: 404} - for _, user := range TestUsers { + for _, user := range ctx.Users { shouldStatus := 200 testCases[user.ID] = tests.Request{ Name: "Valid user " + user.Name, @@ -212,14 +223,14 @@ func TestApi_DeleteUser(t *testing.T) { require.NoErrorf(err, "(Test case %s) Could not make request", testCase.Name) c.Set("id", id) - TestUserHandler.DeleteUser(c) + ctx.UserHandler.DeleteUser(c) assert.Equalf(testCase.ShouldStatus, w.Code, "(Test case %s) Expected status code %d but have %d", testCase.Name, testCase.ShouldStatus, w.Code) } } -func getAdmin() *model.User { - for _, user := range TestUsers { +func getAdmin(ctx *TestContext) *model.User { + for _, user := range ctx.Users { if user.IsAdmin { return user } diff --git a/internal/api/util_test.go b/internal/api/util_test.go index 2870575..079c0da 100644 --- a/internal/api/util_test.go +++ b/internal/api/util_test.go @@ -11,6 +11,8 @@ import ( ) func TestApi_SuccessOrAbort(t *testing.T) { + GetTestContext(t) + assert := assert.New(t) require := require.New(t) @@ -37,10 +39,12 @@ func TestApi_SuccessOrAbort(t *testing.T) { } func TestApi_IsCurrentUser(t *testing.T) { + ctx := GetTestContext(t) + assert := assert.New(t) require := require.New(t) - for _, user := range TestUsers { + for _, user := range ctx.Users { testCases := make(map[uint]tests.Request) testCases[user.ID] = tests.Request{Name: fmt.Sprintf("User %s - success", user.Name), Endpoint: "/", ShouldStatus: 200} diff --git a/internal/authentication/credentials/hibp.go b/internal/authentication/credentials/hibp.go index 6e05a8c..c7f03d7 100644 --- a/internal/authentication/credentials/hibp.go +++ b/internal/authentication/credentials/hibp.go @@ -33,6 +33,9 @@ func IsPasswordPwned(password string) (bool, error) { if err != nil { return false, err } + if resp == nil { + return false, fmt.Errorf("received nil response from http request") + } if resp.StatusCode != http.StatusOK { log.L.Fatalf("Request failed with HTTP %s.", resp.Status)