test: add upgrade tests (#1126)

This commit is contained in:
Jonas Plum
2025-02-02 13:40:33 +01:00
committed by GitHub
parent b31f90c3ea
commit 7de89a752c
9 changed files with 411 additions and 44 deletions

3
.gitignore vendored
View File

@@ -35,4 +35,7 @@ pb_data
catalyst
catalyst_data
# ignore changes, needs to be disabled when adding new upgrade tests
upgradetest
coverage.out

View File

@@ -71,6 +71,11 @@ dev-10000:
go run . fake-data --users 100 --tickets 10000
go run . serve --app-url http://localhost:8090 --flags dev
.PHONY: default-data
default-data:
rm -rf catalyst_data
go run . default-data
.PHONY: serve-ui
serve-ui:
cd ui && bun dev --port 3000

View File

@@ -33,6 +33,7 @@ func App(dir string, test bool) (*pocketbase.PocketBase, error) {
_ = app.RootCmd.ParseFlags(os.Args[1:])
app.RootCmd.AddCommand(fakeDataCmd(app))
app.RootCmd.AddCommand(defaultDataCmd(app))
webhook.BindHooks(app)
reaction.BindHooks(app, test)

View File

@@ -23,3 +23,14 @@ func fakeDataCmd(app core.App) *cobra.Command {
return cmd
}
func defaultDataCmd(app core.App) *cobra.Command {
cmd := &cobra.Command{
Use: "default-data",
RunE: func(_ *cobra.Command, _ []string) error {
return fakedata.GenerateDefaultData(app)
},
}
return cmd
}

235
fakedata/default.go Normal file
View File

@@ -0,0 +1,235 @@
package fakedata
import (
"encoding/json"
"errors"
"fmt"
"reflect"
"time"
"github.com/pocketbase/pocketbase/core"
"github.com/pocketbase/pocketbase/models"
"github.com/pocketbase/pocketbase/tools/types"
"github.com/SecurityBrewery/catalyst/migrations"
)
func defaultData() map[string]map[string]map[string]any {
var (
ticketCreated = time.Date(2025, 2, 1, 11, 29, 35, 0, time.UTC)
ticketUpdated = ticketCreated.Add(time.Minute * 5)
commentCreated = ticketCreated.Add(time.Minute * 10)
commentUpdated = commentCreated.Add(time.Minute * 5)
timelineCreated = ticketCreated.Add(time.Minute * 15)
timelineUpdated = timelineCreated.Add(time.Minute * 5)
taskCreated = ticketCreated.Add(time.Minute * 20)
taskUpdated = taskCreated.Add(time.Minute * 5)
linkCreated = ticketCreated.Add(time.Minute * 25)
linkUpdated = linkCreated.Add(time.Minute * 5)
reactionCreated = time.Date(2025, 2, 1, 11, 30, 0, 0, time.UTC)
reactionUpdated = reactionCreated.Add(time.Minute * 5)
)
createTicketActionData := `{"requirements":"pocketbase","script":"import sys\nimport json\nimport random\nimport os\n\nfrom pocketbase import PocketBase\n\n# Connect to the PocketBase server\nclient = PocketBase(os.environ[\"CATALYST_APP_URL\"])\nclient.auth_store.save(token=os.environ[\"CATALYST_TOKEN\"])\n\nnewtickets = client.collection(\"tickets\").get_list(1, 200, {\"filter\": 'name = \"New Ticket\"'})\nfor ticket in newtickets.items:\n\tclient.collection(\"tickets\").delete(ticket.id)\n\n# Create a new ticket\nclient.collection(\"tickets\").create({\n\t\"name\": \"New Ticket\",\n\t\"type\": \"alert\",\n\t\"open\": True,\n})"}`
return map[string]map[string]map[string]any{
migrations.TicketCollectionName: {
"t_0": {
"created": dateTime(ticketCreated),
"updated": dateTime(ticketUpdated),
"name": "phishing-123",
"type": "alert",
"description": "Phishing email reported by several employees.",
"open": true,
"schema": types.JsonRaw(`{"type":"object","properties":{"tlp":{"title":"TLP","type":"string"}}}`),
"state": types.JsonRaw(`{"severity":"Medium"}`),
"owner": "u_test",
},
},
migrations.CommentCollectionName: {
"c_0": {
"created": dateTime(commentCreated),
"updated": dateTime(commentUpdated),
"ticket": "t_0",
"author": "u_test",
"message": "This is a test comment.",
},
},
migrations.TimelineCollectionName: {
"tl_0": {
"created": dateTime(timelineCreated),
"updated": dateTime(timelineUpdated),
"ticket": "t_0",
"time": dateTime(timelineCreated),
"message": "This is a test timeline message.",
},
},
migrations.TaskCollectionName: {
"ts_0": {
"created": dateTime(taskCreated),
"updated": dateTime(taskUpdated),
"ticket": "t_0",
"name": "This is a test task.",
"open": true,
"owner": "u_test",
},
},
migrations.LinkCollectionName: {
"l_0": {
"created": dateTime(linkCreated),
"updated": dateTime(linkUpdated),
"ticket": "t_0",
"url": "https://www.example.com",
"name": "This is a test link.",
},
},
migrations.ReactionCollectionName: {
"w_0": {
"created": dateTime(reactionCreated),
"updated": dateTime(reactionUpdated),
"name": "Create New Ticket",
"trigger": "schedule",
"triggerdata": types.JsonRaw(triggerSchedule),
"action": "python",
"actiondata": types.JsonRaw(createTicketActionData),
},
},
}
}
func GenerateDefaultData(app core.App) error {
var records []*models.Record
// users
userRecord, err := testUser(app.Dao())
if err != nil {
return err
}
records = append(records, userRecord)
// records
for collectionName, collectionRecords := range defaultData() {
collection, err := app.Dao().FindCollectionByNameOrId(collectionName)
if err != nil {
return err
}
for id, fields := range collectionRecords {
record := models.NewRecord(collection)
record.SetId(id)
for key, value := range fields {
record.Set(key, value)
}
records = append(records, record)
}
}
for _, record := range records {
if err := app.Dao().SaveRecord(record); err != nil {
return err
}
}
return nil
}
func ValidateDefaultData(app core.App) error { //nolint:cyclop,gocognit
// users
userRecord, err := app.Dao().FindRecordById(migrations.UserCollectionName, "u_test")
if err != nil {
return fmt.Errorf("failed to find user record: %w", err)
}
if userRecord == nil {
return errors.New("user not found")
}
if userRecord.Username() != "u_test" {
return fmt.Errorf(`username does not match: got %q, want "u_test"`, userRecord.Username())
}
if !userRecord.ValidatePassword("1234567890") {
return errors.New("password does not match")
}
if userRecord.Get("name") != "Test User" {
return fmt.Errorf(`name does not match: got %q, want "Test User"`, userRecord.Get("name"))
}
if userRecord.Get("email") != "user@catalyst-soar.com" {
return fmt.Errorf(`email does not match: got %q, want "user@catalyst-soar.com"`, userRecord.Get("email"))
}
if !userRecord.Verified() {
return errors.New("user is not verified")
}
// records
for collectionName, collectionRecords := range defaultData() {
for id, fields := range collectionRecords {
record, err := app.Dao().FindRecordById(collectionName, id)
if err != nil {
return fmt.Errorf("failed to find record %s: %w", id, err)
}
if record == nil {
return errors.New("record not found")
}
for key, value := range fields {
got := record.Get(key)
if wantJSON, ok := value.(types.JsonRaw); ok {
if err := compareJSON(got, wantJSON); err != nil {
return fmt.Errorf("record field %q does not match: %w", key, err)
}
continue
}
if got != value {
return fmt.Errorf("record field %s does not match: got %v (%T), want %v (%T)", key, got, got, value, value)
}
}
}
}
return nil
}
func compareJSON(got any, wantJSON types.JsonRaw) error {
gotJSON, ok := got.(types.JsonRaw)
if !ok {
return fmt.Errorf("got %T, want %T", got, wantJSON)
}
if !jsonEqual(gotJSON.String(), wantJSON.String()) {
return fmt.Errorf("got %v, want %v", gotJSON, wantJSON)
}
return nil
}
func jsonEqual(a, b string) bool {
var objA, objB interface{}
if err := json.Unmarshal([]byte(a), &objA); err != nil {
return false
}
if err := json.Unmarshal([]byte(b), &objB); err != nil {
return false
}
return reflect.DeepEqual(objA, objB)
}
func dateTime(t time.Time) types.DateTime {
dt := types.DateTime{}
_ = dt.Scan(t)
return dt
}

21
fakedata/default_test.go Normal file
View File

@@ -0,0 +1,21 @@
package fakedata_test
import (
"testing"
"github.com/stretchr/testify/require"
"github.com/SecurityBrewery/catalyst/fakedata"
catalystTesting "github.com/SecurityBrewery/catalyst/testing"
)
func TestDefaultData(t *testing.T) {
t.Parallel()
app, _, cleanup := catalystTesting.App(t)
defer cleanup()
require.NoError(t, fakedata.GenerateDefaultData(app))
require.NoError(t, fakedata.ValidateDefaultData(app))
}

View File

@@ -49,38 +49,45 @@ func Records(app core.App, userCount int, ticketCount int) ([]*models.Record, er
return nil, err
}
users := userRecords(app.Dao(), userCount)
tickets := ticketRecords(app.Dao(), users, types, ticketCount)
reactions := reactionRecords(app.Dao())
users, err := userRecords(app.Dao(), userCount)
if err != nil {
return nil, err
}
tickets, err := ticketRecords(app.Dao(), users, types, ticketCount)
if err != nil {
return nil, err
}
reactions, err := reactionRecords(app.Dao())
if err != nil {
return nil, err
}
var records []*models.Record
records = append(records, users...)
records = append(records, types...)
records = append(records, tickets...)
records = append(records, reactions...)
return records, nil
}
func userRecords(dao *daos.Dao, count int) []*models.Record {
collection, err := dao.FindCollectionByNameOrId(migrations.UserCollectionName)
if err != nil {
panic(err)
}
func userRecords(dao *daos.Dao, count int) ([]*models.Record, error) {
records := make([]*models.Record, 0, count)
// create the test user
if _, err := dao.FindRecordById(migrations.UserCollectionName, "u_test"); err != nil {
record := models.NewRecord(collection)
record.SetId("u_test")
_ = record.SetUsername("u_test")
_ = record.SetPassword("1234567890")
record.Set("name", gofakeit.Name())
record.Set("email", "user@catalyst-soar.com")
_ = record.SetVerified(true)
testUser, err := testUser(dao)
if err != nil {
return nil, err
}
records = append(records, record)
records = append(records, testUser)
}
collection, err := dao.FindCollectionByNameOrId(migrations.UserCollectionName)
if err != nil {
return nil, err
}
for range count - 1 {
@@ -95,13 +102,30 @@ func userRecords(dao *daos.Dao, count int) []*models.Record {
records = append(records, record)
}
return records
return records, nil
}
func ticketRecords(dao *daos.Dao, users, types []*models.Record, count int) []*models.Record {
func testUser(dao *daos.Dao) (*models.Record, error) {
collection, err := dao.FindCollectionByNameOrId(migrations.UserCollectionName)
if err != nil {
return nil, err
}
record := models.NewRecord(collection)
record.SetId("u_test")
_ = record.SetUsername("u_test")
_ = record.SetPassword("1234567890")
record.Set("name", "Test User")
record.Set("email", "user@catalyst-soar.com")
_ = record.SetVerified(true)
return record, nil
}
func ticketRecords(dao *daos.Dao, users, types []*models.Record, count int) ([]*models.Record, error) {
collection, err := dao.FindCollectionByNameOrId(migrations.TicketCollectionName)
if err != nil {
panic(err)
return nil, err
}
records := make([]*models.Record, 0, count)
@@ -134,19 +158,42 @@ func ticketRecords(dao *daos.Dao, users, types []*models.Record, count int) []*m
records = append(records, record)
// Add comments
records = append(records, commentRecords(dao, users, created, record)...)
records = append(records, timelineRecords(dao, created, record)...)
records = append(records, taskRecords(dao, users, created, record)...)
records = append(records, linkRecords(dao, created, record)...)
comments, err := commentRecords(dao, users, created, record)
if err != nil {
return nil, err
}
return records
records = append(records, comments...)
timelines, err := timelineRecords(dao, created, record)
if err != nil {
return nil, err
}
func commentRecords(dao *daos.Dao, users []*models.Record, created time.Time, record *models.Record) []*models.Record {
records = append(records, timelines...)
tasks, err := taskRecords(dao, users, created, record)
if err != nil {
return nil, err
}
records = append(records, tasks...)
links, err := linkRecords(dao, created, record)
if err != nil {
return nil, err
}
records = append(records, links...)
}
return records, nil
}
func commentRecords(dao *daos.Dao, users []*models.Record, created time.Time, record *models.Record) ([]*models.Record, error) {
commentCollection, err := dao.FindCollectionByNameOrId(migrations.CommentCollectionName)
if err != nil {
panic(err)
return nil, err
}
records := make([]*models.Record, 0, 5)
@@ -166,13 +213,13 @@ func commentRecords(dao *daos.Dao, users []*models.Record, created time.Time, re
records = append(records, commentRecord)
}
return records
return records, nil
}
func timelineRecords(dao *daos.Dao, created time.Time, record *models.Record) []*models.Record {
func timelineRecords(dao *daos.Dao, created time.Time, record *models.Record) ([]*models.Record, error) {
timelineCollection, err := dao.FindCollectionByNameOrId(migrations.TimelineCollectionName)
if err != nil {
panic(err)
return nil, err
}
records := make([]*models.Record, 0, 5)
@@ -192,13 +239,13 @@ func timelineRecords(dao *daos.Dao, created time.Time, record *models.Record) []
records = append(records, timelineRecord)
}
return records
return records, nil
}
func taskRecords(dao *daos.Dao, users []*models.Record, created time.Time, record *models.Record) []*models.Record {
func taskRecords(dao *daos.Dao, users []*models.Record, created time.Time, record *models.Record) ([]*models.Record, error) {
taskCollection, err := dao.FindCollectionByNameOrId(migrations.TaskCollectionName)
if err != nil {
panic(err)
return nil, err
}
records := make([]*models.Record, 0, 5)
@@ -219,13 +266,13 @@ func taskRecords(dao *daos.Dao, users []*models.Record, created time.Time, recor
records = append(records, taskRecord)
}
return records
return records, nil
}
func linkRecords(dao *daos.Dao, created time.Time, record *models.Record) []*models.Record {
func linkRecords(dao *daos.Dao, created time.Time, record *models.Record) ([]*models.Record, error) {
linkCollection, err := dao.FindCollectionByNameOrId(migrations.LinkCollectionName)
if err != nil {
panic(err)
return nil, err
}
records := make([]*models.Record, 0, 5)
@@ -245,7 +292,7 @@ func linkRecords(dao *daos.Dao, created time.Time, record *models.Record) []*mod
records = append(records, linkRecord)
}
return records
return records, nil
}
const createTicketPy = `import sys
@@ -321,12 +368,12 @@ const (
triggerHook = `{"collections":["tickets"],"events":["create"]}`
)
func reactionRecords(dao *daos.Dao) []*models.Record {
func reactionRecords(dao *daos.Dao) ([]*models.Record, error) {
var records []*models.Record
collection, err := dao.FindCollectionByNameOrId(migrations.ReactionCollectionName)
if err != nil {
panic(err)
return nil, err
}
createTicketActionData, err := json.Marshal(map[string]interface{}{
@@ -334,7 +381,7 @@ func reactionRecords(dao *daos.Dao) []*models.Record {
"script": createTicketPy,
})
if err != nil {
panic(err)
return nil, err
}
record := models.NewRecord(collection)
@@ -352,7 +399,7 @@ func reactionRecords(dao *daos.Dao) []*models.Record {
"script": alertIngestPy,
})
if err != nil {
panic(err)
return nil, err
}
record = models.NewRecord(collection)
@@ -370,7 +417,7 @@ func reactionRecords(dao *daos.Dao) []*models.Record {
"script": assignTicketsPy,
})
if err != nil {
panic(err)
return nil, err
}
record = models.NewRecord(collection)
@@ -383,5 +430,5 @@ func reactionRecords(dao *daos.Dao) []*models.Record {
records = append(records, record)
return records
return records, nil
}

Binary file not shown.

View File

@@ -0,0 +1,44 @@
package upgradetest
import (
"fmt"
"log"
"os"
"path/filepath"
"testing"
"github.com/SecurityBrewery/catalyst/app"
"github.com/SecurityBrewery/catalyst/fakedata"
)
func TestUpgrades(t *testing.T) {
t.Parallel()
dirEntries, err := os.ReadDir("data")
if err != nil {
t.Fatal(err)
}
for _, entry := range dirEntries {
if !entry.IsDir() {
continue
}
t.Run(entry.Name(), func(t *testing.T) {
t.Parallel()
pb, err := app.App(filepath.Join("data", entry.Name()), true)
if err != nil {
log.Fatal(err)
}
if err := pb.Bootstrap(); err != nil {
t.Fatal(fmt.Errorf("failed to bootstrap: %w", err))
}
if err := fakedata.ValidateDefaultData(pb); err != nil {
log.Fatal(err)
}
})
}
}