Add support for configuration of TLS

This commit is contained in:
eikendev 2023-07-15 23:25:34 +02:00
parent 833e666c37
commit 61d5e04ecf
No known key found for this signature in database
GPG key ID: A1BDB1B28C8EF694
8 changed files with 65 additions and 12 deletions

View file

@ -88,7 +88,7 @@ func main() {
log.L.Fatal(err) log.L.Fatal(err)
} }
err = runner.Run(engine, c.HTTP.ListenAddress, c.HTTP.Port) err = runner.Run(engine, c)
if err != nil { if err != nil {
log.L.Fatal(err) log.L.Fatal(err)
} }

View file

@ -16,6 +16,12 @@ http:
# What proxies to trust. # What proxies to trust.
trustedproxies: [] trustedproxies: []
# Filename of the TLS certificate.
certfile: ''
# Filename of the TLS private key.
keyfile: ''
database: database:
# Currently sqlite3, mysql, and postgres are supported. # Currently sqlite3, mysql, and postgres are supported.
dialect: 'sqlite3' dialect: 'sqlite3'

View file

@ -15,7 +15,7 @@ func SuccessOrAbort(ctx *gin.Context, code int, err error) bool {
if err != nil { if err != nil {
// If we know the error force error code // If we know the error force error code
switch err { switch err {
case pberrors.ErrorMessageNotFound: case pberrors.ErrMessageNotFound:
ctx.AbortWithError(http.StatusNotFound, err) ctx.AbortWithError(http.StatusNotFound, err)
default: default:
ctx.AbortWithError(code, err) ctx.AbortWithError(code, err)

View file

@ -3,6 +3,8 @@ package configuration
import ( import (
"github.com/jinzhu/configor" "github.com/jinzhu/configor"
"github.com/pushbits/server/internal/log"
"github.com/pushbits/server/internal/pberrors"
) )
// testMode indicates if the package is run in test mode // testMode indicates if the package is run in test mode
@ -53,6 +55,8 @@ type Configuration struct {
ListenAddress string `default:""` ListenAddress string `default:""`
Port int `default:"8080"` Port int `default:"8080"`
TrustedProxies []string `default:"[]"` TrustedProxies []string `default:"[]"`
CertFile string `default:""`
KeyFile string `default:""`
} }
Database struct { Database struct {
Dialect string `default:"sqlite3"` Dialect string `default:"sqlite3"`
@ -80,6 +84,21 @@ func configFiles() []string {
return []string{"config.yml"} return []string{"config.yml"}
} }
func validateHTTPConfiguration(c *Configuration) error {
certAndKeyEmpty := (c.HTTP.CertFile == "" && c.HTTP.KeyFile == "")
certAndKeyPopulated := (c.HTTP.CertFile != "" && c.HTTP.KeyFile != "")
if !certAndKeyEmpty && !certAndKeyPopulated {
return pberrors.ErrConfigTLSFilesInconsistent
}
return nil
}
func validateConfiguration(c *Configuration) error {
return validateHTTPConfiguration(c)
}
// Get returns the configuration extracted from env variables or config file. // Get returns the configuration extracted from env variables or config file.
func Get() *Configuration { func Get() *Configuration {
config := &Configuration{} config := &Configuration{}
@ -93,5 +112,9 @@ func Get() *Configuration {
panic(err) panic(err)
} }
if err := validateConfiguration(config); err != nil {
log.L.Fatal(err)
}
return config return config
} }

View file

@ -8,6 +8,7 @@ import (
"github.com/jinzhu/configor" "github.com/jinzhu/configor"
"github.com/pushbits/server/internal/log" "github.com/pushbits/server/internal/log"
"github.com/pushbits/server/internal/pberrors"
"github.com/stretchr/testify/assert" "github.com/stretchr/testify/assert"
"gopkg.in/yaml.v2" "gopkg.in/yaml.v2"
) )
@ -231,3 +232,18 @@ func cleanUp() {
log.L.Warnln("Cannot remove config file: ", err) log.L.Warnln("Cannot remove config file: ", err)
} }
} }
func TestConfigurationValidation_ConfigTLSFilesInconsistent(t *testing.T) {
assert := assert.New(t)
c := Configuration{}
c.Admin.MatrixID = "000000"
c.Matrix.Username = "default-username"
c.Matrix.Password = "default-password"
c.HTTP.CertFile = "populated"
c.HTTP.KeyFile = ""
is := validateConfiguration(&c)
should := pberrors.ErrConfigTLSFilesInconsistent
assert.Equal(is, should, "validateConfiguration() should return ConfigTLSFilesInconsistent")
}

View file

@ -90,7 +90,7 @@ func (d *Dispatcher) DeleteNotification(a *model.Application, n *model.DeleteNot
deleteMessage, err := d.getMessage(a, n.ID) deleteMessage, err := d.getMessage(a, n.ID)
if err != nil { if err != nil {
log.L.Println(err) log.L.Println(err)
return pberrors.ErrorMessageNotFound return pberrors.ErrMessageNotFound
} }
oldBody, oldFormattedBody, err = bodiesFromMessage(deleteMessage) oldBody, oldFormattedBody, err = bodiesFromMessage(deleteMessage)
@ -199,7 +199,7 @@ func (d *Dispatcher) getMessage(a *model.Application, id string) (*event.Event,
start = messages.End start = messages.End
} }
return nil, pberrors.ErrorMessageNotFound return nil, pberrors.ErrMessageNotFound
} }
// Replaces the content of a matrix message // Replaces the content of a matrix message
@ -273,7 +273,7 @@ func (d *Dispatcher) respondToMessage(a *model.Application, body, formattedBody
func bodiesFromMessage(message *event.Event) (body, formattedBody string, err error) { func bodiesFromMessage(message *event.Event) (body, formattedBody string, err error) {
msgContent := message.Content.AsMessage() msgContent := message.Content.AsMessage()
if msgContent == nil { if msgContent == nil {
return "", "", pberrors.ErrorMessageNotFound return "", "", pberrors.ErrMessageNotFound
} }
formattedBody = msgContent.Body formattedBody = msgContent.Body

View file

@ -3,5 +3,8 @@ package pberrors
import "errors" import "errors"
// ErrorMessageNotFound indicates that a message does not exist // ErrMessageNotFound indicates that a message does not exist
var ErrorMessageNotFound = errors.New("message not found") var ErrMessageNotFound = errors.New("message not found")
// ErrConfigTLSFilesInconsistent indicates that either just a certfile or a keyfile was provided
var ErrConfigTLSFilesInconsistent = errors.New("TLS certfile and keyfile must either both be provided or omitted")

View file

@ -5,14 +5,19 @@ import (
"fmt" "fmt"
"github.com/gin-gonic/gin" "github.com/gin-gonic/gin"
"github.com/pushbits/server/internal/configuration"
) )
// Run starts the Gin engine. // Run starts the Gin engine.
func Run(engine *gin.Engine, address string, port int) error { func Run(engine *gin.Engine, c *configuration.Configuration) error {
err := engine.Run(fmt.Sprintf("%s:%d", address, port)) var err error
if err != nil { address := fmt.Sprintf("%s:%d", c.HTTP.ListenAddress, c.HTTP.Port)
return err
if c.HTTP.CertFile != "" && c.HTTP.KeyFile != "" {
err = engine.RunTLS(address, c.HTTP.CertFile, c.HTTP.KeyFile)
} else {
err = engine.Run(address)
} }
return nil return err
} }