diff --git a/.gitignore b/.gitignore index deb0cd8..7c8bff3 100644 --- a/.gitignore +++ b/.gitignore @@ -35,4 +35,7 @@ pb_data catalyst catalyst_data +# ignore changes, needs to be disabled when adding new upgrade tests +upgradetest + coverage.out diff --git a/Makefile b/Makefile index 739c095..36247b7 100644 --- a/Makefile +++ b/Makefile @@ -71,6 +71,11 @@ dev-10000: go run . fake-data --users 100 --tickets 10000 go run . serve --app-url http://localhost:8090 --flags dev +.PHONY: default-data +default-data: + rm -rf catalyst_data + go run . default-data + .PHONY: serve-ui serve-ui: cd ui && bun dev --port 3000 diff --git a/app/app.go b/app/app.go index cb75b82..76a04ce 100644 --- a/app/app.go +++ b/app/app.go @@ -33,6 +33,7 @@ func App(dir string, test bool) (*pocketbase.PocketBase, error) { _ = app.RootCmd.ParseFlags(os.Args[1:]) app.RootCmd.AddCommand(fakeDataCmd(app)) + app.RootCmd.AddCommand(defaultDataCmd(app)) webhook.BindHooks(app) reaction.BindHooks(app, test) diff --git a/app/fakedata.go b/app/fakedata.go index e26540c..c689789 100644 --- a/app/fakedata.go +++ b/app/fakedata.go @@ -23,3 +23,14 @@ func fakeDataCmd(app core.App) *cobra.Command { return cmd } + +func defaultDataCmd(app core.App) *cobra.Command { + cmd := &cobra.Command{ + Use: "default-data", + RunE: func(_ *cobra.Command, _ []string) error { + return fakedata.GenerateDefaultData(app) + }, + } + + return cmd +} diff --git a/fakedata/default.go b/fakedata/default.go new file mode 100644 index 0000000..679aa31 --- /dev/null +++ b/fakedata/default.go @@ -0,0 +1,235 @@ +package fakedata + +import ( + "encoding/json" + "errors" + "fmt" + "reflect" + "time" + + "github.com/pocketbase/pocketbase/core" + "github.com/pocketbase/pocketbase/models" + "github.com/pocketbase/pocketbase/tools/types" + + "github.com/SecurityBrewery/catalyst/migrations" +) + +func defaultData() map[string]map[string]map[string]any { + var ( + ticketCreated = time.Date(2025, 2, 1, 11, 29, 35, 0, time.UTC) + ticketUpdated = ticketCreated.Add(time.Minute * 5) + commentCreated = ticketCreated.Add(time.Minute * 10) + commentUpdated = commentCreated.Add(time.Minute * 5) + timelineCreated = ticketCreated.Add(time.Minute * 15) + timelineUpdated = timelineCreated.Add(time.Minute * 5) + taskCreated = ticketCreated.Add(time.Minute * 20) + taskUpdated = taskCreated.Add(time.Minute * 5) + linkCreated = ticketCreated.Add(time.Minute * 25) + linkUpdated = linkCreated.Add(time.Minute * 5) + reactionCreated = time.Date(2025, 2, 1, 11, 30, 0, 0, time.UTC) + reactionUpdated = reactionCreated.Add(time.Minute * 5) + ) + + createTicketActionData := `{"requirements":"pocketbase","script":"import sys\nimport json\nimport random\nimport os\n\nfrom pocketbase import PocketBase\n\n# Connect to the PocketBase server\nclient = PocketBase(os.environ[\"CATALYST_APP_URL\"])\nclient.auth_store.save(token=os.environ[\"CATALYST_TOKEN\"])\n\nnewtickets = client.collection(\"tickets\").get_list(1, 200, {\"filter\": 'name = \"New Ticket\"'})\nfor ticket in newtickets.items:\n\tclient.collection(\"tickets\").delete(ticket.id)\n\n# Create a new ticket\nclient.collection(\"tickets\").create({\n\t\"name\": \"New Ticket\",\n\t\"type\": \"alert\",\n\t\"open\": True,\n})"}` + + return map[string]map[string]map[string]any{ + migrations.TicketCollectionName: { + "t_0": { + "created": dateTime(ticketCreated), + "updated": dateTime(ticketUpdated), + "name": "phishing-123", + "type": "alert", + "description": "Phishing email reported by several employees.", + "open": true, + "schema": types.JsonRaw(`{"type":"object","properties":{"tlp":{"title":"TLP","type":"string"}}}`), + "state": types.JsonRaw(`{"severity":"Medium"}`), + "owner": "u_test", + }, + }, + migrations.CommentCollectionName: { + "c_0": { + "created": dateTime(commentCreated), + "updated": dateTime(commentUpdated), + "ticket": "t_0", + "author": "u_test", + "message": "This is a test comment.", + }, + }, + migrations.TimelineCollectionName: { + "tl_0": { + "created": dateTime(timelineCreated), + "updated": dateTime(timelineUpdated), + "ticket": "t_0", + "time": dateTime(timelineCreated), + "message": "This is a test timeline message.", + }, + }, + migrations.TaskCollectionName: { + "ts_0": { + "created": dateTime(taskCreated), + "updated": dateTime(taskUpdated), + "ticket": "t_0", + "name": "This is a test task.", + "open": true, + "owner": "u_test", + }, + }, + migrations.LinkCollectionName: { + "l_0": { + "created": dateTime(linkCreated), + "updated": dateTime(linkUpdated), + "ticket": "t_0", + "url": "https://www.example.com", + "name": "This is a test link.", + }, + }, + migrations.ReactionCollectionName: { + "w_0": { + "created": dateTime(reactionCreated), + "updated": dateTime(reactionUpdated), + "name": "Create New Ticket", + "trigger": "schedule", + "triggerdata": types.JsonRaw(triggerSchedule), + "action": "python", + "actiondata": types.JsonRaw(createTicketActionData), + }, + }, + } +} + +func GenerateDefaultData(app core.App) error { + var records []*models.Record + + // users + userRecord, err := testUser(app.Dao()) + if err != nil { + return err + } + + records = append(records, userRecord) + + // records + for collectionName, collectionRecords := range defaultData() { + collection, err := app.Dao().FindCollectionByNameOrId(collectionName) + if err != nil { + return err + } + + for id, fields := range collectionRecords { + record := models.NewRecord(collection) + record.SetId(id) + + for key, value := range fields { + record.Set(key, value) + } + + records = append(records, record) + } + } + + for _, record := range records { + if err := app.Dao().SaveRecord(record); err != nil { + return err + } + } + + return nil +} + +func ValidateDefaultData(app core.App) error { //nolint:cyclop,gocognit + // users + userRecord, err := app.Dao().FindRecordById(migrations.UserCollectionName, "u_test") + if err != nil { + return fmt.Errorf("failed to find user record: %w", err) + } + + if userRecord == nil { + return errors.New("user not found") + } + + if userRecord.Username() != "u_test" { + return fmt.Errorf(`username does not match: got %q, want "u_test"`, userRecord.Username()) + } + + if !userRecord.ValidatePassword("1234567890") { + return errors.New("password does not match") + } + + if userRecord.Get("name") != "Test User" { + return fmt.Errorf(`name does not match: got %q, want "Test User"`, userRecord.Get("name")) + } + + if userRecord.Get("email") != "user@catalyst-soar.com" { + return fmt.Errorf(`email does not match: got %q, want "user@catalyst-soar.com"`, userRecord.Get("email")) + } + + if !userRecord.Verified() { + return errors.New("user is not verified") + } + + // records + for collectionName, collectionRecords := range defaultData() { + for id, fields := range collectionRecords { + record, err := app.Dao().FindRecordById(collectionName, id) + if err != nil { + return fmt.Errorf("failed to find record %s: %w", id, err) + } + + if record == nil { + return errors.New("record not found") + } + + for key, value := range fields { + got := record.Get(key) + + if wantJSON, ok := value.(types.JsonRaw); ok { + if err := compareJSON(got, wantJSON); err != nil { + return fmt.Errorf("record field %q does not match: %w", key, err) + } + + continue + } + + if got != value { + return fmt.Errorf("record field %s does not match: got %v (%T), want %v (%T)", key, got, got, value, value) + } + } + } + } + + return nil +} + +func compareJSON(got any, wantJSON types.JsonRaw) error { + gotJSON, ok := got.(types.JsonRaw) + if !ok { + return fmt.Errorf("got %T, want %T", got, wantJSON) + } + + if !jsonEqual(gotJSON.String(), wantJSON.String()) { + return fmt.Errorf("got %v, want %v", gotJSON, wantJSON) + } + + return nil +} + +func jsonEqual(a, b string) bool { + var objA, objB interface{} + + if err := json.Unmarshal([]byte(a), &objA); err != nil { + return false + } + + if err := json.Unmarshal([]byte(b), &objB); err != nil { + return false + } + + return reflect.DeepEqual(objA, objB) +} + +func dateTime(t time.Time) types.DateTime { + dt := types.DateTime{} + _ = dt.Scan(t) + + return dt +} diff --git a/fakedata/default_test.go b/fakedata/default_test.go new file mode 100644 index 0000000..c8bd718 --- /dev/null +++ b/fakedata/default_test.go @@ -0,0 +1,21 @@ +package fakedata_test + +import ( + "testing" + + "github.com/stretchr/testify/require" + + "github.com/SecurityBrewery/catalyst/fakedata" + catalystTesting "github.com/SecurityBrewery/catalyst/testing" +) + +func TestDefaultData(t *testing.T) { + t.Parallel() + + app, _, cleanup := catalystTesting.App(t) + defer cleanup() + + require.NoError(t, fakedata.GenerateDefaultData(app)) + + require.NoError(t, fakedata.ValidateDefaultData(app)) +} diff --git a/fakedata/records.go b/fakedata/records.go index f0381e2..7562994 100644 --- a/fakedata/records.go +++ b/fakedata/records.go @@ -49,38 +49,45 @@ func Records(app core.App, userCount int, ticketCount int) ([]*models.Record, er return nil, err } - users := userRecords(app.Dao(), userCount) - tickets := ticketRecords(app.Dao(), users, types, ticketCount) - reactions := reactionRecords(app.Dao()) + users, err := userRecords(app.Dao(), userCount) + if err != nil { + return nil, err + } + + tickets, err := ticketRecords(app.Dao(), users, types, ticketCount) + if err != nil { + return nil, err + } + + reactions, err := reactionRecords(app.Dao()) + if err != nil { + return nil, err + } var records []*models.Record records = append(records, users...) - records = append(records, types...) records = append(records, tickets...) records = append(records, reactions...) return records, nil } -func userRecords(dao *daos.Dao, count int) []*models.Record { - collection, err := dao.FindCollectionByNameOrId(migrations.UserCollectionName) - if err != nil { - panic(err) - } - +func userRecords(dao *daos.Dao, count int) ([]*models.Record, error) { records := make([]*models.Record, 0, count) // create the test user if _, err := dao.FindRecordById(migrations.UserCollectionName, "u_test"); err != nil { - record := models.NewRecord(collection) - record.SetId("u_test") - _ = record.SetUsername("u_test") - _ = record.SetPassword("1234567890") - record.Set("name", gofakeit.Name()) - record.Set("email", "user@catalyst-soar.com") - _ = record.SetVerified(true) + testUser, err := testUser(dao) + if err != nil { + return nil, err + } - records = append(records, record) + records = append(records, testUser) + } + + collection, err := dao.FindCollectionByNameOrId(migrations.UserCollectionName) + if err != nil { + return nil, err } for range count - 1 { @@ -95,13 +102,30 @@ func userRecords(dao *daos.Dao, count int) []*models.Record { records = append(records, record) } - return records + return records, nil } -func ticketRecords(dao *daos.Dao, users, types []*models.Record, count int) []*models.Record { +func testUser(dao *daos.Dao) (*models.Record, error) { + collection, err := dao.FindCollectionByNameOrId(migrations.UserCollectionName) + if err != nil { + return nil, err + } + + record := models.NewRecord(collection) + record.SetId("u_test") + _ = record.SetUsername("u_test") + _ = record.SetPassword("1234567890") + record.Set("name", "Test User") + record.Set("email", "user@catalyst-soar.com") + _ = record.SetVerified(true) + + return record, nil +} + +func ticketRecords(dao *daos.Dao, users, types []*models.Record, count int) ([]*models.Record, error) { collection, err := dao.FindCollectionByNameOrId(migrations.TicketCollectionName) if err != nil { - panic(err) + return nil, err } records := make([]*models.Record, 0, count) @@ -134,19 +158,42 @@ func ticketRecords(dao *daos.Dao, users, types []*models.Record, count int) []*m records = append(records, record) // Add comments - records = append(records, commentRecords(dao, users, created, record)...) - records = append(records, timelineRecords(dao, created, record)...) - records = append(records, taskRecords(dao, users, created, record)...) - records = append(records, linkRecords(dao, created, record)...) + comments, err := commentRecords(dao, users, created, record) + if err != nil { + return nil, err + } + + records = append(records, comments...) + + timelines, err := timelineRecords(dao, created, record) + if err != nil { + return nil, err + } + + records = append(records, timelines...) + + tasks, err := taskRecords(dao, users, created, record) + if err != nil { + return nil, err + } + + records = append(records, tasks...) + + links, err := linkRecords(dao, created, record) + if err != nil { + return nil, err + } + + records = append(records, links...) } - return records + return records, nil } -func commentRecords(dao *daos.Dao, users []*models.Record, created time.Time, record *models.Record) []*models.Record { +func commentRecords(dao *daos.Dao, users []*models.Record, created time.Time, record *models.Record) ([]*models.Record, error) { commentCollection, err := dao.FindCollectionByNameOrId(migrations.CommentCollectionName) if err != nil { - panic(err) + return nil, err } records := make([]*models.Record, 0, 5) @@ -166,13 +213,13 @@ func commentRecords(dao *daos.Dao, users []*models.Record, created time.Time, re records = append(records, commentRecord) } - return records + return records, nil } -func timelineRecords(dao *daos.Dao, created time.Time, record *models.Record) []*models.Record { +func timelineRecords(dao *daos.Dao, created time.Time, record *models.Record) ([]*models.Record, error) { timelineCollection, err := dao.FindCollectionByNameOrId(migrations.TimelineCollectionName) if err != nil { - panic(err) + return nil, err } records := make([]*models.Record, 0, 5) @@ -192,13 +239,13 @@ func timelineRecords(dao *daos.Dao, created time.Time, record *models.Record) [] records = append(records, timelineRecord) } - return records + return records, nil } -func taskRecords(dao *daos.Dao, users []*models.Record, created time.Time, record *models.Record) []*models.Record { +func taskRecords(dao *daos.Dao, users []*models.Record, created time.Time, record *models.Record) ([]*models.Record, error) { taskCollection, err := dao.FindCollectionByNameOrId(migrations.TaskCollectionName) if err != nil { - panic(err) + return nil, err } records := make([]*models.Record, 0, 5) @@ -219,13 +266,13 @@ func taskRecords(dao *daos.Dao, users []*models.Record, created time.Time, recor records = append(records, taskRecord) } - return records + return records, nil } -func linkRecords(dao *daos.Dao, created time.Time, record *models.Record) []*models.Record { +func linkRecords(dao *daos.Dao, created time.Time, record *models.Record) ([]*models.Record, error) { linkCollection, err := dao.FindCollectionByNameOrId(migrations.LinkCollectionName) if err != nil { - panic(err) + return nil, err } records := make([]*models.Record, 0, 5) @@ -245,7 +292,7 @@ func linkRecords(dao *daos.Dao, created time.Time, record *models.Record) []*mod records = append(records, linkRecord) } - return records + return records, nil } const createTicketPy = `import sys @@ -321,12 +368,12 @@ const ( triggerHook = `{"collections":["tickets"],"events":["create"]}` ) -func reactionRecords(dao *daos.Dao) []*models.Record { +func reactionRecords(dao *daos.Dao) ([]*models.Record, error) { var records []*models.Record collection, err := dao.FindCollectionByNameOrId(migrations.ReactionCollectionName) if err != nil { - panic(err) + return nil, err } createTicketActionData, err := json.Marshal(map[string]interface{}{ @@ -334,7 +381,7 @@ func reactionRecords(dao *daos.Dao) []*models.Record { "script": createTicketPy, }) if err != nil { - panic(err) + return nil, err } record := models.NewRecord(collection) @@ -352,7 +399,7 @@ func reactionRecords(dao *daos.Dao) []*models.Record { "script": alertIngestPy, }) if err != nil { - panic(err) + return nil, err } record = models.NewRecord(collection) @@ -370,7 +417,7 @@ func reactionRecords(dao *daos.Dao) []*models.Record { "script": assignTicketsPy, }) if err != nil { - panic(err) + return nil, err } record = models.NewRecord(collection) @@ -383,5 +430,5 @@ func reactionRecords(dao *daos.Dao) []*models.Record { records = append(records, record) - return records + return records, nil } diff --git a/upgradetest/data/v0.14.1/data.db b/upgradetest/data/v0.14.1/data.db new file mode 100644 index 0000000..97c7c78 Binary files /dev/null and b/upgradetest/data/v0.14.1/data.db differ diff --git a/upgradetest/upgrade_test.go b/upgradetest/upgrade_test.go new file mode 100644 index 0000000..9f5ac23 --- /dev/null +++ b/upgradetest/upgrade_test.go @@ -0,0 +1,44 @@ +package upgradetest + +import ( + "fmt" + "log" + "os" + "path/filepath" + "testing" + + "github.com/SecurityBrewery/catalyst/app" + "github.com/SecurityBrewery/catalyst/fakedata" +) + +func TestUpgrades(t *testing.T) { + t.Parallel() + + dirEntries, err := os.ReadDir("data") + if err != nil { + t.Fatal(err) + } + + for _, entry := range dirEntries { + if !entry.IsDir() { + continue + } + + t.Run(entry.Name(), func(t *testing.T) { + t.Parallel() + + pb, err := app.App(filepath.Join("data", entry.Name()), true) + if err != nil { + log.Fatal(err) + } + + if err := pb.Bootstrap(); err != nil { + t.Fatal(fmt.Errorf("failed to bootstrap: %w", err)) + } + + if err := fakedata.ValidateDefaultData(pb); err != nil { + log.Fatal(err) + } + }) + } +}