mirror of
https://github.com/pomerium/pomerium.git
synced 2025-04-29 10:26:29 +02:00
210 lines
4.7 KiB
Go
210 lines
4.7 KiB
Go
package postgres
|
|
|
|
import (
|
|
"context"
|
|
"errors"
|
|
|
|
"github.com/jackc/pgx/v5"
|
|
|
|
"github.com/pomerium/pomerium/pkg/cryptutil"
|
|
)
|
|
|
|
var migrations = []func(context.Context, pgx.Tx) error{
|
|
1: func(ctx context.Context, tx pgx.Tx) error {
|
|
_, err := tx.Exec(ctx, `
|
|
CREATE TABLE `+schemaName+`.`+recordsTableName+` (
|
|
type TEXT NOT NULL,
|
|
id TEXT NOT NULL,
|
|
version BIGINT NOT NULL,
|
|
data JSONB NOT NULL,
|
|
modified_at TIMESTAMPTZ NOT NULL DEFAULT(NOW()),
|
|
|
|
index_cidr INET NULL,
|
|
|
|
PRIMARY KEY (type, id)
|
|
)
|
|
`)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
|
|
_, err = tx.Exec(ctx, `
|
|
CREATE INDEX ON `+schemaName+`.`+recordsTableName+`
|
|
USING gist (index_cidr inet_ops);
|
|
`)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
|
|
_, err = tx.Exec(ctx, `
|
|
CREATE TABLE `+schemaName+`.`+recordChangesTableName+` (
|
|
type TEXT NOT NULL,
|
|
id TEXT NOT NULL,
|
|
version BIGINT NOT NULL,
|
|
data JSONB NOT NULL,
|
|
modified_at TIMESTAMPTZ NOT NULL,
|
|
deleted_at TIMESTAMPTZ NULL,
|
|
|
|
PRIMARY KEY (version)
|
|
)
|
|
`)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
|
|
_, err = tx.Exec(ctx, `
|
|
CREATE TABLE `+schemaName+`.`+recordOptionsTableName+` (
|
|
type TEXT NOT NULL,
|
|
capacity BIGINT NULL,
|
|
|
|
PRIMARY KEY (type)
|
|
)
|
|
`)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
|
|
_, err = tx.Exec(ctx, `
|
|
CREATE TABLE `+schemaName+`.`+leasesTableName+` (
|
|
name TEXT NOT NULL,
|
|
id TEXT NOT NULL,
|
|
expires_at TIMESTAMPTZ NOT NULL,
|
|
|
|
PRIMARY KEY (name)
|
|
)
|
|
`)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
|
|
return nil
|
|
},
|
|
2: func(ctx context.Context, tx pgx.Tx) error {
|
|
serverVersion := uint64(cryptutil.NewRandomUInt32())
|
|
_, err := tx.Exec(ctx, `
|
|
UPDATE `+schemaName+`.`+migrationInfoTableName+`
|
|
SET server_version = $1
|
|
`, serverVersion)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
|
|
_, err = tx.Exec(ctx, `
|
|
DELETE FROM `+schemaName+`.`+recordChangesTableName+`
|
|
`)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
|
|
_, err = tx.Exec(ctx, `
|
|
DELETE FROM `+schemaName+`.`+recordsTableName+`
|
|
`)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
|
|
_, err = tx.Exec(ctx, `
|
|
ALTER TABLE `+schemaName+`.`+recordChangesTableName+`
|
|
ALTER COLUMN version
|
|
ADD GENERATED BY DEFAULT AS IDENTITY
|
|
`)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
|
|
return nil
|
|
},
|
|
3: func(ctx context.Context, tx pgx.Tx) error {
|
|
_, err := tx.Exec(ctx, `
|
|
CREATE TABLE `+schemaName+`.`+servicesTableName+` (
|
|
kind TEXT NOT NULL,
|
|
endpoint TEXT NOT NULL,
|
|
expires_at TIMESTAMPTZ NOT NULL,
|
|
|
|
PRIMARY KEY (kind, endpoint)
|
|
)
|
|
`)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
|
|
return nil
|
|
},
|
|
4: func(ctx context.Context, tx pgx.Tx) error {
|
|
_, err := tx.Exec(ctx, `
|
|
ALTER TABLE `+schemaName+`.`+recordChangesTableName+`
|
|
ALTER data DROP NOT NULL
|
|
`)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
|
|
return nil
|
|
},
|
|
5: func(ctx context.Context, tx pgx.Tx) error {
|
|
for _, q := range []string{
|
|
`CREATE INDEX ON ` + schemaName + `.` + recordsTableName + ` (type)`,
|
|
`CREATE INDEX ON ` + schemaName + `.` + recordsTableName + ` (type, version)`,
|
|
`CREATE INDEX ON ` + schemaName + `.` + recordChangesTableName + ` (modified_at)`,
|
|
`CREATE INDEX ON ` + schemaName + `.` + recordChangesTableName + ` (version)`,
|
|
`CREATE INDEX ON ` + schemaName + `.` + recordChangesTableName + ` (type)`,
|
|
`CREATE INDEX ON ` + schemaName + `.` + recordChangesTableName + ` (type, version)`,
|
|
} {
|
|
_, err := tx.Exec(ctx, q)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
}
|
|
|
|
return nil
|
|
},
|
|
}
|
|
|
|
func migrate(ctx context.Context, tx pgx.Tx) (serverVersion uint64, err error) {
|
|
_, err = tx.Exec(ctx, `CREATE SCHEMA IF NOT EXISTS `+schemaName)
|
|
if err != nil {
|
|
return serverVersion, err
|
|
}
|
|
|
|
_, err = tx.Exec(ctx, `
|
|
CREATE TABLE IF NOT EXISTS `+schemaName+`.`+migrationInfoTableName+` (
|
|
server_version BIGINT NOT NULL,
|
|
migration_version SMALLINT NOT NULL
|
|
)
|
|
`)
|
|
if err != nil {
|
|
return serverVersion, err
|
|
}
|
|
|
|
var migrationVersion uint64
|
|
err = tx.QueryRow(ctx, `
|
|
SELECT server_version, migration_version
|
|
FROM `+schemaName+`.migration_info
|
|
`).Scan(&serverVersion, &migrationVersion)
|
|
if errors.Is(err, pgx.ErrNoRows) {
|
|
serverVersion = uint64(cryptutil.NewRandomUInt32()) // we can't actually store a uint64, just an int64, so just generate a uint32
|
|
_, err = tx.Exec(ctx, `
|
|
INSERT INTO `+schemaName+`.`+migrationInfoTableName+` (server_version, migration_version)
|
|
VALUES ($1, $2)
|
|
`, serverVersion, 0)
|
|
}
|
|
if err != nil {
|
|
return serverVersion, err
|
|
}
|
|
|
|
for version := migrationVersion + 1; version < uint64(len(migrations)); version++ {
|
|
err = migrations[version](ctx, tx)
|
|
if err != nil {
|
|
return serverVersion, err
|
|
}
|
|
_, err = tx.Exec(ctx, `
|
|
UPDATE `+schemaName+`.`+migrationInfoTableName+`
|
|
SET migration_version = $1
|
|
`, version)
|
|
if err != nil {
|
|
return serverVersion, err
|
|
}
|
|
}
|
|
|
|
return serverVersion, nil
|
|
}
|