diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml index 698e770..854c886 100644 --- a/.github/workflows/ci.yml +++ b/.github/workflows/ci.yml @@ -14,7 +14,11 @@ jobs: with: { go-version: '1.22' } - uses: oven-sh/setup-bun@v1 - - run: make build-ui + - run: | + bun install + mkdir -p dist + touch dist/index.html + working-directory: ui - run: make install - run: make fmt @@ -28,13 +32,25 @@ jobs: - uses: actions/checkout@v4 - uses: actions/setup-go@v5 with: { go-version: '1.22' } - - uses: oven-sh/setup-bun@v1 - - run: make build-ui + - run: | + mkdir -p ui/dist + touch ui/dist/index.html - uses: golangci/golangci-lint-action@v6 with: { version: 'v1.59' } + build: + name: Build + runs-on: ubuntu-latest + steps: + - uses: actions/checkout@v4 + - uses: actions/setup-go@v5 + with: { go-version: '1.22' } + - uses: oven-sh/setup-bun@v1 + + - run: make build-ui + test: name: Test runs-on: ubuntu-latest @@ -44,9 +60,9 @@ jobs: with: { go-version: '1.22' } - uses: oven-sh/setup-bun@v1 - - run: make build-ui - - - run: make test + - run: | + mkdir -p ui/dist + touch ui/dist/index.html - run: make test-coverage diff --git a/.golangci.yml b/.golangci.yml index 92d4e29..76f10f1 100644 --- a/.golangci.yml +++ b/.golangci.yml @@ -5,39 +5,20 @@ linters: enable-all: true disable: # complexity - - cyclop - - gocognit - - gocyclo - maintidx - - nestif + - funlen # disable - - bodyclose - depguard - - dupl - err113 - - execinquery - exhaustruct - - funlen - - gochecknoglobals - - gochecknoinits - - goconst - - godox - gomnd - - gomoddirectives - ireturn - lll - - makezero - mnd - - paralleltest - - perfsprint - - prealloc - - tagalign - - tagliatelle - testpackage - varnamelen - wrapcheck - - wsl linters-settings: gci: sections: diff --git a/Makefile b/Makefile index 3881103..cbc206f 100644 --- a/Makefile +++ b/Makefile @@ -43,7 +43,6 @@ build-ui: dev: @echo "Running..." rm -rf catalyst_data - go run . bootstrap go run . admin create admin@catalyst-soar.com 1234567890 go run . set-feature-flags dev go run . fake-data diff --git a/app/app.go b/app/app.go index 68e80b4..50bcfa3 100644 --- a/app/app.go +++ b/app/app.go @@ -5,38 +5,40 @@ import ( "strings" "github.com/pocketbase/pocketbase" - "github.com/pocketbase/pocketbase/core" "github.com/SecurityBrewery/catalyst/migrations" "github.com/SecurityBrewery/catalyst/reaction" "github.com/SecurityBrewery/catalyst/webhook" ) -func init() { +func init() { //nolint:gochecknoinits migrations.Register() } -func App(dir string) *pocketbase.PocketBase { +func App(dir string, test bool) (*pocketbase.PocketBase, error) { app := pocketbase.NewWithConfig(pocketbase.Config{ - DefaultDev: dev(), + DefaultDev: test || dev(), DefaultDataDir: dir, }) - BindHooks(app) + webhook.BindHooks(app) + reaction.BindHooks(app, test) + + app.OnBeforeServe().Add(addRoutes()) // Register additional commands - app.RootCmd.AddCommand(bootstrapCmd(app)) app.RootCmd.AddCommand(fakeDataCmd(app)) app.RootCmd.AddCommand(setFeatureFlagsCmd(app)) - return app -} + if err := app.Bootstrap(); err != nil { + return nil, err + } -func BindHooks(app core.App) { - webhook.BindHooks(app) - reaction.BindHooks(app) + if err := MigrateDBs(app); err != nil { + return nil, err + } - app.OnBeforeServe().Add(addRoutes()) + return app, nil } func dev() bool { diff --git a/app/bootstrap.go b/app/bootstrap.go deleted file mode 100644 index 684efc5..0000000 --- a/app/bootstrap.go +++ /dev/null @@ -1,25 +0,0 @@ -package app - -import ( - "github.com/pocketbase/pocketbase/core" - "github.com/spf13/cobra" -) - -func Bootstrap(app core.App) error { - if err := app.Bootstrap(); err != nil { - return err - } - - return MigrateDBs(app) -} - -func bootstrapCmd(app core.App) *cobra.Command { - return &cobra.Command{ - Use: "bootstrap", - Run: func(_ *cobra.Command, _ []string) { - if err := Bootstrap(app); err != nil { - app.Logger().Error(err.Error()) - } - }, - } -} diff --git a/app/flags.go b/app/flags.go index 4267c31..9a4b749 100644 --- a/app/flags.go +++ b/app/flags.go @@ -16,7 +16,7 @@ func Flags(app core.App) ([]string, error) { return nil, err } - var flags []string + flags := make([]string, 0, len(records)) for _, r := range records { flags = append(flags, r.GetString("name")) @@ -36,7 +36,7 @@ func SetFlags(app core.App, args []string) error { return err } - var existingFlags []string + var existingFlags []string //nolint:prealloc for _, featureRecord := range featureRecords { // remove feature flags that are not in the args diff --git a/app/flags_test.go b/app/flags_test.go index 0256ff5..c814f90 100644 --- a/app/flags_test.go +++ b/app/flags_test.go @@ -11,7 +11,9 @@ import ( ) func Test_flags(t *testing.T) { - catalystApp, cleanup := catalystTesting.App(t) + t.Parallel() + + catalystApp, _, cleanup := catalystTesting.App(t) defer cleanup() got, err := app.Flags(catalystApp) @@ -22,9 +24,12 @@ func Test_flags(t *testing.T) { } func Test_setFlags(t *testing.T) { - catalystApp, cleanup := catalystTesting.App(t) + t.Parallel() + + catalystApp, _, cleanup := catalystTesting.App(t) defer cleanup() + // stage 1 require.NoError(t, app.SetFlags(catalystApp, []string{"test"})) got, err := app.Flags(catalystApp) @@ -32,10 +37,19 @@ func Test_setFlags(t *testing.T) { assert.ElementsMatch(t, []string{"test"}, got) + // stage 2 require.NoError(t, app.SetFlags(catalystApp, []string{"test2"})) got, err = app.Flags(catalystApp) require.NoError(t, err) assert.ElementsMatch(t, []string{"test2"}, got) + + // stage 3 + require.NoError(t, app.SetFlags(catalystApp, []string{"test", "test2"})) + + got, err = app.Flags(catalystApp) + require.NoError(t, err) + + assert.ElementsMatch(t, []string{"test", "test2"}, got) } diff --git a/app/migrate.go b/app/migrate.go index ddc03f7..46222fc 100644 --- a/app/migrate.go +++ b/app/migrate.go @@ -33,13 +33,13 @@ func MigrateDBs(app core.App) error { return nil } -// this fix ignores some errors that come from upstream migrations. -var ignoreErrors = []string{ - "1673167670_multi_match_migrate", - "1660821103_add_user_ip_column", -} - func isIgnored(err error) bool { + // this fix ignores some errors that come from upstream migrations. + ignoreErrors := []string{ + "1673167670_multi_match_migrate", + "1660821103_add_user_ip_column", + } + for _, ignore := range ignoreErrors { if strings.Contains(err.Error(), ignore) { return true diff --git a/app/migrate_internal_test.go b/app/migrate_internal_test.go new file mode 100644 index 0000000..c277919 --- /dev/null +++ b/app/migrate_internal_test.go @@ -0,0 +1,39 @@ +package app + +import ( + "errors" + "testing" + + "github.com/stretchr/testify/assert" +) + +func Test_isIgnored(t *testing.T) { + t.Parallel() + + type args struct { + err error + } + + tests := []struct { + name string + args args + want bool + }{ + { + name: "error is ignored", + args: args{err: errors.New("1673167670_multi_match_migrate")}, + want: true, + }, + { + name: "error is not ignored", + args: args{err: errors.New("1673167670_multi_match")}, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + t.Parallel() + + assert.Equalf(t, tt.want, isIgnored(tt.args.err), "isIgnored(%v)", tt.args.err) + }) + } +} diff --git a/app/migrate_test.go b/app/migrate_test.go index 442cb58..158b78f 100644 --- a/app/migrate_test.go +++ b/app/migrate_test.go @@ -11,7 +11,9 @@ import ( ) func Test_MigrateDBsDown(t *testing.T) { - catalystApp, cleanup := catalystTesting.App(t) + t.Parallel() + + catalystApp, _, cleanup := catalystTesting.App(t) defer cleanup() _, err := catalystApp.Dao().FindCollectionByNameOrId(migrations.ReactionCollectionName) diff --git a/app/routes.go b/app/routes.go index d3bc493..dcd9dc3 100644 --- a/app/routes.go +++ b/app/routes.go @@ -38,11 +38,10 @@ func staticFiles() func(echo.Context) error { return func(c echo.Context) error { if dev() { u, _ := url.Parse("http://localhost:3000/") - proxy := httputil.NewSingleHostReverseProxy(u) c.Request().Host = c.Request().URL.Host - proxy.ServeHTTP(c.Response(), c.Request()) + httputil.NewSingleHostReverseProxy(u).ServeHTTP(c.Response(), c.Request()) return nil } diff --git a/app/routes_test.go b/app/routes_test.go new file mode 100644 index 0000000..497ff14 --- /dev/null +++ b/app/routes_test.go @@ -0,0 +1,21 @@ +package app + +import ( + "net/http" + "net/http/httptest" + "testing" + + "github.com/labstack/echo/v5" + "github.com/stretchr/testify/require" +) + +func Test_staticFiles(t *testing.T) { + t.Parallel() + + e := echo.New() + req := httptest.NewRequest(http.MethodGet, "/", nil) + rec := httptest.NewRecorder() + c := e.NewContext(req, rec) + + require.NoError(t, staticFiles()(c)) +} diff --git a/fakedata/records.go b/fakedata/records.go index f1b9be8..827072c 100644 --- a/fakedata/records.go +++ b/fakedata/records.go @@ -1,6 +1,7 @@ package fakedata import ( + "encoding/json" "fmt" "strings" "time" @@ -50,14 +51,12 @@ func Records(app core.App, userCount int, ticketCount int) ([]*models.Record, er users := userRecords(app.Dao(), userCount) tickets := ticketRecords(app.Dao(), users, types, ticketCount) - webhooks := webhookRecords(app.Dao()) reactions := reactionRecords(app.Dao()) var records []*models.Record records = append(records, users...) records = append(records, types...) records = append(records, tickets...) - records = append(records, webhooks...) records = append(records, reactions...) return records, nil @@ -69,7 +68,7 @@ func userRecords(dao *daos.Dao, count int) []*models.Record { panic(err) } - var records []*models.Record + records := make([]*models.Record, 0, count) // create the test user if _, err := dao.FindRecordById(migrations.UserCollectionName, "u_test"); err != nil { @@ -105,7 +104,7 @@ func ticketRecords(dao *daos.Dao, users, types []*models.Record, count int) []*m panic(err) } - var records []*models.Record + records := make([]*models.Record, 0, count) created := time.Now() number := gofakeit.Number(200*count, 300*count) @@ -135,114 +134,168 @@ func ticketRecords(dao *daos.Dao, users, types []*models.Record, count int) []*m records = append(records, record) // Add comments - for range gofakeit.IntN(5) { - commentCollection, err := dao.FindCollectionByNameOrId(migrations.CommentCollectionName) - if err != nil { - panic(err) - } - - commentCreated := gofakeit.DateRange(created, time.Now()) - commentUpdated := gofakeit.DateRange(commentCreated, time.Now()) - - commentRecord := models.NewRecord(commentCollection) - commentRecord.SetId("c_" + security.PseudorandomString(10)) - commentRecord.Set("created", commentCreated.Format("2006-01-02T15:04:05Z")) - commentRecord.Set("updated", commentUpdated.Format("2006-01-02T15:04:05Z")) - commentRecord.Set("ticket", record.GetId()) - commentRecord.Set("author", random(users).GetId()) - commentRecord.Set("message", fakeTicketComment()) - - records = append(records, commentRecord) - } - - // Add timeline - for range gofakeit.IntN(5) { - timelineCollection, err := dao.FindCollectionByNameOrId(migrations.TimelineCollectionName) - if err != nil { - panic(err) - } - - timelineCreated := gofakeit.DateRange(created, time.Now()) - timelineUpdated := gofakeit.DateRange(timelineCreated, time.Now()) - - timelineRecord := models.NewRecord(timelineCollection) - timelineRecord.SetId("tl_" + security.PseudorandomString(10)) - timelineRecord.Set("created", timelineCreated.Format("2006-01-02T15:04:05Z")) - timelineRecord.Set("updated", timelineUpdated.Format("2006-01-02T15:04:05Z")) - timelineRecord.Set("ticket", record.GetId()) - timelineRecord.Set("time", gofakeit.DateRange(created, time.Now()).Format("2006-01-02T15:04:05Z")) - timelineRecord.Set("message", fakeTicketTimelineMessage()) - - records = append(records, timelineRecord) - } - - // Add tasks - for range gofakeit.IntN(5) { - taskCollection, err := dao.FindCollectionByNameOrId(migrations.TaskCollectionName) - if err != nil { - panic(err) - } - - taskCreated := gofakeit.DateRange(created, time.Now()) - taskUpdated := gofakeit.DateRange(taskCreated, time.Now()) - - taskRecord := models.NewRecord(taskCollection) - taskRecord.SetId("ts_" + security.PseudorandomString(10)) - taskRecord.Set("created", taskCreated.Format("2006-01-02T15:04:05Z")) - taskRecord.Set("updated", taskUpdated.Format("2006-01-02T15:04:05Z")) - taskRecord.Set("ticket", record.GetId()) - taskRecord.Set("name", fakeTicketTask()) - taskRecord.Set("open", gofakeit.Bool()) - taskRecord.Set("owner", random(users).GetId()) - - records = append(records, taskRecord) - } - - // Add links - for range gofakeit.IntN(5) { - linkCollection, err := dao.FindCollectionByNameOrId(migrations.LinkCollectionName) - if err != nil { - panic(err) - } - - linkCreated := gofakeit.DateRange(created, time.Now()) - linkUpdated := gofakeit.DateRange(linkCreated, time.Now()) - - linkRecord := models.NewRecord(linkCollection) - linkRecord.SetId("l_" + security.PseudorandomString(10)) - linkRecord.Set("created", linkCreated.Format("2006-01-02T15:04:05Z")) - linkRecord.Set("updated", linkUpdated.Format("2006-01-02T15:04:05Z")) - linkRecord.Set("ticket", record.GetId()) - linkRecord.Set("url", gofakeit.URL()) - linkRecord.Set("name", random([]string{"Blog", "Forum", "Wiki", "Documentation"})) - - records = append(records, linkRecord) - } + records = append(records, commentRecords(dao, users, created, record)...) + records = append(records, timelineRecords(dao, created, record)...) + records = append(records, taskRecords(dao, users, created, record)...) + records = append(records, linkRecords(dao, created, record)...) } return records } -func webhookRecords(dao *daos.Dao) []*models.Record { - collection, err := dao.FindCollectionByNameOrId(migrations.WebhookCollectionName) +func commentRecords(dao *daos.Dao, users []*models.Record, created time.Time, record *models.Record) []*models.Record { + commentCollection, err := dao.FindCollectionByNameOrId(migrations.CommentCollectionName) if err != nil { panic(err) } - record := models.NewRecord(collection) - record.SetId("w_" + security.PseudorandomString(10)) - record.Set("name", "Test Webhook") - record.Set("collection", "tickets") - record.Set("destination", "http://localhost:8080/webhook") + records := make([]*models.Record, 0, 5) - return []*models.Record{record} + for range gofakeit.IntN(5) { + commentCreated := gofakeit.DateRange(created, time.Now()) + commentUpdated := gofakeit.DateRange(commentCreated, time.Now()) + + commentRecord := models.NewRecord(commentCollection) + commentRecord.SetId("c_" + security.PseudorandomString(10)) + commentRecord.Set("created", commentCreated.Format("2006-01-02T15:04:05Z")) + commentRecord.Set("updated", commentUpdated.Format("2006-01-02T15:04:05Z")) + commentRecord.Set("ticket", record.GetId()) + commentRecord.Set("author", random(users).GetId()) + commentRecord.Set("message", fakeTicketComment()) + + records = append(records, commentRecord) + } + + return records } +func timelineRecords(dao *daos.Dao, created time.Time, record *models.Record) []*models.Record { + timelineCollection, err := dao.FindCollectionByNameOrId(migrations.TimelineCollectionName) + if err != nil { + panic(err) + } + + records := make([]*models.Record, 0, 5) + + for range gofakeit.IntN(5) { + timelineCreated := gofakeit.DateRange(created, time.Now()) + timelineUpdated := gofakeit.DateRange(timelineCreated, time.Now()) + + timelineRecord := models.NewRecord(timelineCollection) + timelineRecord.SetId("tl_" + security.PseudorandomString(10)) + timelineRecord.Set("created", timelineCreated.Format("2006-01-02T15:04:05Z")) + timelineRecord.Set("updated", timelineUpdated.Format("2006-01-02T15:04:05Z")) + timelineRecord.Set("ticket", record.GetId()) + timelineRecord.Set("time", gofakeit.DateRange(created, time.Now()).Format("2006-01-02T15:04:05Z")) + timelineRecord.Set("message", fakeTicketTimelineMessage()) + + records = append(records, timelineRecord) + } + + return records +} + +func taskRecords(dao *daos.Dao, users []*models.Record, created time.Time, record *models.Record) []*models.Record { + taskCollection, err := dao.FindCollectionByNameOrId(migrations.TaskCollectionName) + if err != nil { + panic(err) + } + + records := make([]*models.Record, 0, 5) + + for range gofakeit.IntN(5) { + taskCreated := gofakeit.DateRange(created, time.Now()) + taskUpdated := gofakeit.DateRange(taskCreated, time.Now()) + + taskRecord := models.NewRecord(taskCollection) + taskRecord.SetId("ts_" + security.PseudorandomString(10)) + taskRecord.Set("created", taskCreated.Format("2006-01-02T15:04:05Z")) + taskRecord.Set("updated", taskUpdated.Format("2006-01-02T15:04:05Z")) + taskRecord.Set("ticket", record.GetId()) + taskRecord.Set("name", fakeTicketTask()) + taskRecord.Set("open", gofakeit.Bool()) + taskRecord.Set("owner", random(users).GetId()) + + records = append(records, taskRecord) + } + + return records +} + +func linkRecords(dao *daos.Dao, created time.Time, record *models.Record) []*models.Record { + linkCollection, err := dao.FindCollectionByNameOrId(migrations.LinkCollectionName) + if err != nil { + panic(err) + } + + records := make([]*models.Record, 0, 5) + + for range gofakeit.IntN(5) { + linkCreated := gofakeit.DateRange(created, time.Now()) + linkUpdated := gofakeit.DateRange(linkCreated, time.Now()) + + linkRecord := models.NewRecord(linkCollection) + linkRecord.SetId("l_" + security.PseudorandomString(10)) + linkRecord.Set("created", linkCreated.Format("2006-01-02T15:04:05Z")) + linkRecord.Set("updated", linkUpdated.Format("2006-01-02T15:04:05Z")) + linkRecord.Set("ticket", record.GetId()) + linkRecord.Set("url", gofakeit.URL()) + linkRecord.Set("name", random([]string{"Blog", "Forum", "Wiki", "Documentation"})) + + records = append(records, linkRecord) + } + + return records +} + +const alertIngestPy = `import sys +import json +import random +import os + +from pocketbase import PocketBase + +# Parse the event from the webhook payload +event = json.loads(sys.argv[1]) +body = json.loads(event["body"]) + +# Connect to the PocketBase server +client = PocketBase('http://127.0.0.1:8090') +client.auth_store.save(token=os.environ["CATALYST_TOKEN"]) + +# Create a new ticket +client.collection("tickets").create({ + "name": body["name"], + "type": "alert", + "open": True, +})` + +const assignTicketsPy = `import sys +import json +import random +import os + +from pocketbase import PocketBase + +# Parse the ticket from the input +ticket = json.loads(sys.argv[1]) + +# Connect to the PocketBase server +client = PocketBase('http://127.0.0.1:8090') +client.auth_store.save(token=os.environ["CATALYST_TOKEN"]) + +# Get a random user +users = client.collection("users").get_list(1, 200) +random_user = random.choice(users.items) + +# Assign the ticket to the random user +client.collection("tickets").update(ticket["record"]["id"], { + "owner": random_user.id, +})` + const ( - triggerWebhook = `{"token":"1234567890","path":"webhook"}` - reactionPython = `{"requirements":"requests","script":"import sys\n\nprint(sys.argv[1])"}` - triggerHook = `{"collections":["tickets","comments"],"events":["create","update","delete"]}` - reactionWebhook = `{"headers":["Content-Type: application/json"],"url":"http://localhost:8080/webhook"}` + triggerWebhook = `{"token":"1234567890","path":"webhook"}` + triggerHook = `{"collections":["tickets"],"events":["create"]}` ) func reactionRecords(dao *daos.Dao) []*models.Record { @@ -253,23 +306,39 @@ func reactionRecords(dao *daos.Dao) []*models.Record { panic(err) } + alertIngestActionData, err := json.Marshal(map[string]interface{}{ + "requirements": "pocketbase", + "script": alertIngestPy, + }) + if err != nil { + panic(err) + } + record := models.NewRecord(collection) record.SetId("w_" + security.PseudorandomString(10)) record.Set("name", "Test Reaction") record.Set("trigger", "webhook") record.Set("triggerdata", triggerWebhook) record.Set("action", "python") - record.Set("actiondata", reactionPython) + record.Set("actiondata", string(alertIngestActionData)) records = append(records, record) + assignTicketsActionData, err := json.Marshal(map[string]interface{}{ + "requirements": "pocketbase", + "script": assignTicketsPy, + }) + if err != nil { + panic(err) + } + record = models.NewRecord(collection) record.SetId("w_" + security.PseudorandomString(10)) record.Set("name", "Test Reaction 2") record.Set("trigger", "hook") record.Set("triggerdata", triggerHook) - record.Set("action", "webhook") - record.Set("actiondata", reactionWebhook) + record.Set("action", "python") + record.Set("actiondata", string(assignTicketsActionData)) records = append(records, record) diff --git a/fakedata/records_test.go b/fakedata/records_test.go index d540ece..304826d 100644 --- a/fakedata/records_test.go +++ b/fakedata/records_test.go @@ -11,7 +11,9 @@ import ( ) func Test_records(t *testing.T) { - app, cleanup := catalystTesting.App(t) + t.Parallel() + + app, _, cleanup := catalystTesting.App(t) defer cleanup() got, err := fakedata.Records(app, 2, 2) @@ -21,7 +23,9 @@ func Test_records(t *testing.T) { } func TestGenerate(t *testing.T) { - app, cleanup := catalystTesting.App(t) + t.Parallel() + + app, _, cleanup := catalystTesting.App(t) defer cleanup() err := fakedata.Generate(app, 0, 0) diff --git a/fakedata/text_test.go b/fakedata/text_test.go index 9ffce2f..bc08c0d 100644 --- a/fakedata/text_test.go +++ b/fakedata/text_test.go @@ -7,22 +7,32 @@ import ( ) func Test_fakeTicketComment(t *testing.T) { + t.Parallel() + assert.NotEmpty(t, fakeTicketComment()) } func Test_fakeTicketDescription(t *testing.T) { + t.Parallel() + assert.NotEmpty(t, fakeTicketDescription()) } func Test_fakeTicketTask(t *testing.T) { + t.Parallel() + assert.NotEmpty(t, fakeTicketTask()) } func Test_fakeTicketTimelineMessage(t *testing.T) { + t.Parallel() + assert.NotEmpty(t, fakeTicketTimelineMessage()) } func Test_random(t *testing.T) { + t.Parallel() + type args[T any] struct { e []T } @@ -40,6 +50,8 @@ func Test_random(t *testing.T) { } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { + t.Parallel() + got := random(tt.args.e) assert.Contains(t, tt.args.e, got) diff --git a/go.mod b/go.mod index 3d10758..d8ffe57 100644 --- a/go.mod +++ b/go.mod @@ -4,6 +4,7 @@ go 1.22.1 require ( github.com/brianvoe/gofakeit/v7 v7.0.3 + github.com/golang-jwt/jwt/v4 v4.5.0 github.com/labstack/echo/v5 v5.0.0-20230722203903-ec5b858dab61 github.com/pocketbase/dbx v1.10.1 github.com/pocketbase/pocketbase v0.22.14 @@ -46,7 +47,6 @@ require ( github.com/go-ozzo/ozzo-validation/v4 v4.3.0 // indirect github.com/go-sql-driver/mysql v1.8.1 // indirect github.com/goccy/go-json v0.10.3 // indirect - github.com/golang-jwt/jwt/v4 v4.5.0 // indirect github.com/golang/groupcache v0.0.0-20210331224755-41bb18bfe9da // indirect github.com/google/uuid v1.6.0 // indirect github.com/googleapis/gax-go/v2 v2.12.4 // indirect diff --git a/main.go b/main.go index 9b6c95e..1f6c52c 100644 --- a/main.go +++ b/main.go @@ -7,7 +7,12 @@ import ( ) func main() { - if err := app.App("./catalyst_data").Start(); err != nil { + catalyst, err := app.App("./catalyst_data", false) + if err != nil { + log.Fatal(err) + } + + if err := catalyst.Start(); err != nil { log.Fatal(err) } } diff --git a/migrations/3_defaultdata.go b/migrations/3_defaultdata.go index 8170745..2a2ac87 100644 --- a/migrations/3_defaultdata.go +++ b/migrations/3_defaultdata.go @@ -4,7 +4,6 @@ import ( "github.com/pocketbase/dbx" "github.com/pocketbase/pocketbase/daos" "github.com/pocketbase/pocketbase/models" - "github.com/pocketbase/pocketbase/tools/security" ) func defaultDataUp(db dbx.Builder) error { @@ -30,7 +29,7 @@ func typeRecords(dao *daos.Dao) []*models.Record { var records []*models.Record record := models.NewRecord(collection) - record.SetId("y_" + security.PseudorandomString(5)) + record.SetId("incident") record.Set("singular", "Incident") record.Set("plural", "Incidents") record.Set("icon", "Flame") @@ -39,7 +38,7 @@ func typeRecords(dao *daos.Dao) []*models.Record { records = append(records, record) record = models.NewRecord(collection) - record.SetId("y_" + security.PseudorandomString(5)) + record.SetId("alert") record.Set("singular", "Alert") record.Set("plural", "Alerts") record.Set("icon", "AlertTriangle") diff --git a/migrations/5_reactions.go b/migrations/5_reactions.go index a8693ff..c4fc90f 100644 --- a/migrations/5_reactions.go +++ b/migrations/5_reactions.go @@ -21,9 +21,9 @@ func reactionsUp(db dbx.Builder) error { Schema: schema.NewSchema( &schema.SchemaField{Name: "name", Type: schema.FieldTypeText, Required: true}, &schema.SchemaField{Name: "trigger", Type: schema.FieldTypeSelect, Required: true, Options: &schema.SelectOptions{MaxSelect: 1, Values: triggers}}, - &schema.SchemaField{Name: "triggerdata", Type: schema.FieldTypeJson, Required: true}, + &schema.SchemaField{Name: "triggerdata", Type: schema.FieldTypeJson, Required: true, Options: &schema.JsonOptions{MaxSize: 50_000}}, &schema.SchemaField{Name: "action", Type: schema.FieldTypeSelect, Required: true, Options: &schema.SelectOptions{MaxSelect: 1, Values: reactions}}, - &schema.SchemaField{Name: "actiondata", Type: schema.FieldTypeJson, Required: true}, + &schema.SchemaField{Name: "actiondata", Type: schema.FieldTypeJson, Required: true, Options: &schema.JsonOptions{MaxSize: 50_000}}, ), })) } diff --git a/migrations/6_systemuser.go b/migrations/6_systemuser.go new file mode 100644 index 0000000..0ee2410 --- /dev/null +++ b/migrations/6_systemuser.go @@ -0,0 +1,37 @@ +package migrations + +import ( + "github.com/pocketbase/dbx" + "github.com/pocketbase/pocketbase/daos" + "github.com/pocketbase/pocketbase/models" +) + +const SystemUserID = "system" + +func systemuserUp(db dbx.Builder) error { + dao := daos.New(db) + + collection, err := dao.FindCollectionByNameOrId(UserCollectionName) + if err != nil { + return err + } + + record := models.NewRecord(collection) + record.SetId(SystemUserID) + record.Set("name", "system") + record.Set("username", "system") + record.Set("verified", true) + + return dao.SaveRecord(record) +} + +func systemuserDown(db dbx.Builder) error { + dao := daos.New(db) + + record, err := dao.FindRecordById(UserCollectionName, SystemUserID) + if err != nil { + return err + } + + return dao.DeleteRecord(record) +} diff --git a/migrations/migrations.go b/migrations/migrations.go index 7632d1f..c8650a9 100644 --- a/migrations/migrations.go +++ b/migrations/migrations.go @@ -10,4 +10,5 @@ func Register() { migrations.Register(defaultDataUp, nil, "1700000003_defaultdata.go") migrations.Register(viewsUp, viewsDown, "1700000004_views.go") migrations.Register(reactionsUp, reactionsDown, "1700000005_reactions.go") + migrations.Register(systemuserUp, systemuserDown, "1700000006_systemuser.go") } diff --git a/reaction/action/action.go b/reaction/action/action.go index a564134..6539305 100644 --- a/reaction/action/action.go +++ b/reaction/action/action.go @@ -4,17 +4,33 @@ import ( "context" "encoding/json" "fmt" + "time" + "github.com/golang-jwt/jwt/v4" + "github.com/pocketbase/pocketbase/core" + "github.com/pocketbase/pocketbase/tokens" + "github.com/pocketbase/pocketbase/tools/security" + + "github.com/SecurityBrewery/catalyst/migrations" "github.com/SecurityBrewery/catalyst/reaction/action/python" "github.com/SecurityBrewery/catalyst/reaction/action/webhook" ) -func Run(ctx context.Context, actionName, actionData, payload string) ([]byte, error) { +func Run(ctx context.Context, app core.App, actionName, actionData, payload string) ([]byte, error) { action, err := decode(actionName, actionData) if err != nil { return nil, err } + if a, ok := action.(authenticatedAction); ok { + token, err := systemToken(app) + if err != nil { + return nil, fmt.Errorf("failed to get system token: %w", err) + } + + a.SetToken(token) + } + return action.Run(ctx, payload) } @@ -22,6 +38,10 @@ type action interface { Run(ctx context.Context, payload string) ([]byte, error) } +type authenticatedAction interface { + SetToken(token string) +} + func decode(actionName, actionData string) (action, error) { switch actionName { case "python": @@ -42,3 +62,20 @@ func decode(actionName, actionData string) (action, error) { return nil, fmt.Errorf("action %q not found", actionName) } } + +func systemToken(app core.App) (string, error) { + authRecord, err := app.Dao().FindAuthRecordByUsername(migrations.UserCollectionName, migrations.SystemUserID) + if err != nil { + return "", fmt.Errorf("failed to find system auth record: %w", err) + } + + return security.NewJWT( + jwt.MapClaims{ + "id": authRecord.Id, + "type": tokens.TypeAuthRecord, + "collectionId": authRecord.Collection().Id, + }, + authRecord.TokenKey()+app.Settings().RecordAuthToken.Secret, + int64(time.Second*60), + ) +} diff --git a/reaction/action/python/python.go b/reaction/action/python/python.go index 1c1cc66..cbac867 100644 --- a/reaction/action/python/python.go +++ b/reaction/action/python/python.go @@ -10,8 +10,14 @@ import ( ) type Python struct { - Bootstrap string `json:"bootstrap"` - Script string `json:"script"` + Requirements string `json:"requirements"` + Script string `json:"script"` + + token string +} + +func (a *Python) SetToken(token string) { + a.token = token } func (a *Python) Run(ctx context.Context, payload string) ([]byte, error) { @@ -22,7 +28,8 @@ func (a *Python) Run(ctx context.Context, payload string) ([]byte, error) { defer os.RemoveAll(tempDir) - if b, err := pythonSetup(ctx, tempDir); err != nil { + b, err := pythonSetup(ctx, tempDir) + if err != nil { var ee *exec.ExitError if errors.As(err, &ee) { b = append(b, ee.Stderr...) @@ -31,16 +38,17 @@ func (a *Python) Run(ctx context.Context, payload string) ([]byte, error) { return nil, fmt.Errorf("failed to setup python, %w: %s", err, string(b)) } - if b, err := pythonRunBootstrap(ctx, tempDir, a.Bootstrap); err != nil { + b, err = a.pythonInstallRequirements(ctx, tempDir) + if err != nil { var ee *exec.ExitError if errors.As(err, &ee) { b = append(b, ee.Stderr...) } - return nil, fmt.Errorf("failed to run bootstrap, %w: %s", err, string(b)) + return nil, fmt.Errorf("failed to run install requirements, %w: %s", err, string(b)) } - b, err := pythonRunScript(ctx, tempDir, a.Script, payload) + b, err = a.pythonRunScript(ctx, tempDir, payload) if err != nil { var ee *exec.ExitError if errors.As(err, &ee) { @@ -63,35 +71,42 @@ func pythonSetup(ctx context.Context, tempDir string) ([]byte, error) { return exec.CommandContext(ctx, pythonPath, "-m", "venv", tempDir+"/venv").Output() } -func pythonRunBootstrap(ctx context.Context, tempDir, bootstrap string) ([]byte, error) { - hasBootstrap := len(strings.TrimSpace(bootstrap)) > 0 +func (a *Python) pythonInstallRequirements(ctx context.Context, tempDir string) ([]byte, error) { + hasRequirements := len(strings.TrimSpace(a.Requirements)) > 0 - if !hasBootstrap { + if !hasRequirements { return nil, nil } - bootstrapPath := tempDir + "/requirements.txt" + requirementsPath := tempDir + "/requirements.txt" - if err := os.WriteFile(bootstrapPath, []byte(bootstrap), 0o600); err != nil { + if err := os.WriteFile(requirementsPath, []byte(a.Requirements), 0o600); err != nil { return nil, err } // install dependencies pipPath := tempDir + "/venv/bin/pip" - return exec.CommandContext(ctx, pipPath, "install", "-r", bootstrapPath).Output() + return exec.CommandContext(ctx, pipPath, "install", "-r", requirementsPath).Output() } -func pythonRunScript(ctx context.Context, tempDir, script, payload string) ([]byte, error) { +func (a *Python) pythonRunScript(ctx context.Context, tempDir, payload string) ([]byte, error) { scriptPath := tempDir + "/script.py" - if err := os.WriteFile(scriptPath, []byte(script), 0o600); err != nil { + if err := os.WriteFile(scriptPath, []byte(a.Script), 0o600); err != nil { return nil, err } pythonPath := tempDir + "/venv/bin/python" - return exec.CommandContext(ctx, pythonPath, scriptPath, payload).Output() + cmd := exec.CommandContext(ctx, pythonPath, scriptPath, payload) + + cmd.Env = []string{} + if a.token != "" { + cmd.Env = append(cmd.Env, "CATALYST_TOKEN="+a.token) + } + + return cmd.Output() } func findExec(name ...string) (string, error) { diff --git a/reaction/action/python/python_test.go b/reaction/action/python/python_test.go index 7615f96..60f2fc5 100644 --- a/reaction/action/python/python_test.go +++ b/reaction/action/python/python_test.go @@ -1,16 +1,20 @@ -package python +package python_test import ( "context" "testing" "github.com/stretchr/testify/assert" + + "github.com/SecurityBrewery/catalyst/reaction/action/python" ) func TestPython_Run(t *testing.T) { + t.Parallel() + type fields struct { - Bootstrap string - Script string + Requirements string + Script string } type args struct { @@ -68,14 +72,28 @@ func TestPython_Run(t *testing.T) { want: nil, wantErr: assert.Error, }, + { + name: "requests", + fields: fields{ + Requirements: "requests", + Script: "import requests\nprint(requests.get('https://xkcd.com/2961/info.0.json').text)", + }, + args: args{ + payload: "test", + }, + want: []byte("{\"month\": \"7\", \"num\": 2961, \"link\": \"\", \"year\": \"2024\", \"news\": \"\", \"safe_title\": \"CrowdStrike\", \"transcript\": \"\", \"alt\": \"We were going to try swordfighting, but all my compiling is on hold.\", \"img\": \"https://imgs.xkcd.com/comics/crowdstrike.png\", \"title\": \"CrowdStrike\", \"day\": \"19\"}\n"), + wantErr: assert.NoError, + }, } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { + t.Parallel() + ctx := context.Background() - a := &Python{ - Bootstrap: tt.fields.Bootstrap, - Script: tt.fields.Script, + a := &python.Python{ + Requirements: tt.fields.Requirements, + Script: tt.fields.Script, } got, err := a.Run(ctx, tt.args.payload) tt.wantErr(t, err) diff --git a/reaction/action/webhook/payload_test.go b/reaction/action/webhook/payload_test.go index cb838ee..a041d82 100644 --- a/reaction/action/webhook/payload_test.go +++ b/reaction/action/webhook/payload_test.go @@ -1,12 +1,16 @@ -package webhook +package webhook_test import ( "bytes" "io" "testing" + + "github.com/SecurityBrewery/catalyst/reaction/action/webhook" ) func TestEncodeBody(t *testing.T) { + t.Parallel() + type args struct { requestBody io.Reader } @@ -36,7 +40,9 @@ func TestEncodeBody(t *testing.T) { } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { - got, got1 := EncodeBody(tt.args.requestBody) + t.Parallel() + + got, got1 := webhook.EncodeBody(tt.args.requestBody) if got != tt.want { t.Errorf("EncodeBody() got = %v, want %v", got, tt.want) } diff --git a/reaction/action/webhook/webhook_test.go b/reaction/action/webhook/webhook_test.go index eea1087..7594ed6 100644 --- a/reaction/action/webhook/webhook_test.go +++ b/reaction/action/webhook/webhook_test.go @@ -16,11 +16,15 @@ import ( ) func TestWebhook_Run(t *testing.T) { + t.Parallel() + server := catalystTesting.NewRecordingServer() go http.ListenAndServe("127.0.0.1:12347", server) //nolint:gosec,errcheck - time.Sleep(1 * time.Second) + if err := catalystTesting.WaitForStatus("http://127.0.0.1:12347/health", http.StatusOK, 5*time.Second); err != nil { + t.Fatal(err) + } type fields struct { Headers map[string]string @@ -50,10 +54,10 @@ func TestWebhook_Run(t *testing.T) { want: map[string]any{ "statusCode": 200, "headers": map[string]any{ - "Content-Length": []any{"13"}, - "Content-Type": []any{"text/plain; charset=utf-8"}, + "Content-Length": []any{"14"}, + "Content-Type": []any{"application/json; charset=UTF-8"}, }, - "body": `{"test":true}`, + "body": "{\"test\":true}\n", "isBase64Encoded": false, }, wantErr: assert.NoError, @@ -61,6 +65,8 @@ func TestWebhook_Run(t *testing.T) { } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { + t.Parallel() + ctx := context.Background() a := &webhook.Webhook{ diff --git a/reaction/trigger.go b/reaction/trigger.go index 235522c..11d1c66 100644 --- a/reaction/trigger.go +++ b/reaction/trigger.go @@ -1,13 +1,13 @@ package reaction import ( - "github.com/pocketbase/pocketbase/core" + "github.com/pocketbase/pocketbase" "github.com/SecurityBrewery/catalyst/reaction/trigger/hook" "github.com/SecurityBrewery/catalyst/reaction/trigger/webhook" ) -func BindHooks(app core.App) { - hook.BindHooks(app) - webhook.BindHooks(app) +func BindHooks(pb *pocketbase.PocketBase, test bool) { + hook.BindHooks(pb, test) + webhook.BindHooks(pb) } diff --git a/reaction/trigger/hook/hook.go b/reaction/trigger/hook/hook.go index 36bf048..6058ef2 100644 --- a/reaction/trigger/hook/hook.go +++ b/reaction/trigger/hook/hook.go @@ -1,12 +1,15 @@ package hook import ( + "context" "encoding/json" "fmt" + "log/slog" "slices" "github.com/labstack/echo/v5" "github.com/pocketbase/dbx" + "github.com/pocketbase/pocketbase" "github.com/pocketbase/pocketbase/apis" "github.com/pocketbase/pocketbase/core" "github.com/pocketbase/pocketbase/daos" @@ -22,43 +25,40 @@ type Hook struct { Events []string `json:"events"` } -func BindHooks(app core.App) { - app.OnRecordAfterCreateRequest().Add(func(e *core.RecordCreateEvent) error { - if err := hook(app.Dao(), "create", e.Collection.Name, e.Record, e.HttpContext); err != nil { - app.Logger().Error("failed to find hook reaction", "error", err.Error()) - } - - return nil +func BindHooks(pb *pocketbase.PocketBase, test bool) { + pb.App.OnRecordAfterCreateRequest().Add(func(e *core.RecordCreateEvent) error { + return hook(e.HttpContext, pb.App, "create", e.Collection.Name, e.Record, test) }) - app.OnRecordAfterUpdateRequest().Add(func(e *core.RecordUpdateEvent) error { - if err := hook(app.Dao(), "update", e.Collection.Name, e.Record, e.HttpContext); err != nil { - app.Logger().Error("failed to find hook reaction", "error", err.Error()) - } - - return nil + pb.App.OnRecordAfterUpdateRequest().Add(func(e *core.RecordUpdateEvent) error { + return hook(e.HttpContext, pb.App, "update", e.Collection.Name, e.Record, test) }) - app.OnRecordAfterDeleteRequest().Add(func(e *core.RecordDeleteEvent) error { - if err := hook(app.Dao(), "delete", e.Collection.Name, e.Record, e.HttpContext); err != nil { - app.Logger().Error("failed to find hook reaction", "error", err.Error()) - } - - return nil + pb.App.OnRecordAfterDeleteRequest().Add(func(e *core.RecordDeleteEvent) error { + return hook(e.HttpContext, pb.App, "delete", e.Collection.Name, e.Record, test) }) } -func hook(dao *daos.Dao, event, collection string, record *models.Record, ctx echo.Context) error { +func hook(ctx echo.Context, app core.App, event, collection string, record *models.Record, test bool) error { auth, _ := ctx.Get(apis.ContextAuthRecordKey).(*models.Record) admin, _ := ctx.Get(apis.ContextAdminKey).(*models.Admin) - hook, found, err := findByHookTrigger(dao, collection, event) - if err != nil { - return fmt.Errorf("failed to find hook reaction: %w", err) + if !test { + go mustRunHook(app, collection, event, record, auth, admin) + } else { + mustRunHook(app, collection, event, record, auth, admin) } - if !found { - return nil - } + return nil +} +func mustRunHook(app core.App, collection, event string, record, auth *models.Record, admin *models.Admin) { + ctx := context.Background() + + if err := runHook(ctx, app, collection, event, record, auth, admin); err != nil { + slog.ErrorContext(ctx, fmt.Sprintf("failed to run hook reaction: %v", err)) + } +} + +func runHook(ctx context.Context, app core.App, collection, event string, record, auth *models.Record, admin *models.Admin) error { payload, err := json.Marshal(&webhook.Payload{ Action: event, Collection: collection, @@ -67,10 +67,19 @@ func hook(dao *daos.Dao, event, collection string, record *models.Record, ctx ec Admin: admin, }) if err != nil { - return fmt.Errorf("failed to marshal payload: %w", err) + return fmt.Errorf("failed to marshal webhook payload: %w", err) } - _, err = action.Run(ctx.Request().Context(), hook.GetString("action"), hook.GetString("actiondata"), string(payload)) + hook, found, err := findByHookTrigger(app.Dao(), collection, event) + if err != nil { + return fmt.Errorf("failed to find hook by trigger: %w", err) + } + + if !found { + return nil + } + + _, err = action.Run(ctx, app, hook.GetString("action"), hook.GetString("actiondata"), string(payload)) if err != nil { return fmt.Errorf("failed to run hook reaction: %w", err) } @@ -81,7 +90,7 @@ func hook(dao *daos.Dao, event, collection string, record *models.Record, ctx ec func findByHookTrigger(dao *daos.Dao, collection, event string) (*models.Record, bool, error) { records, err := dao.FindRecordsByExpr(migrations.ReactionCollectionName, dbx.HashExp{"trigger": "hook"}) if err != nil { - return nil, false, err + return nil, false, fmt.Errorf("failed to find hook reaction: %w", err) } if len(records) == 0 { diff --git a/reaction/trigger/webhook/request.go b/reaction/trigger/webhook/request.go index af19c98..33d8cb6 100644 --- a/reaction/trigger/webhook/request.go +++ b/reaction/trigger/webhook/request.go @@ -15,8 +15,8 @@ type Request struct { IsBase64Encoded bool `json:"isBase64Encoded"` } -// isJSON checks if the data is JSON. -func isJSON(data []byte) bool { +// IsJSON checks if the data is JSON. +func IsJSON(data []byte) bool { var msg json.RawMessage return json.Unmarshal(data, &msg) == nil diff --git a/reaction/trigger/webhook/request_test.go b/reaction/trigger/webhook/request_test.go index cb8dbc2..31dbc03 100644 --- a/reaction/trigger/webhook/request_test.go +++ b/reaction/trigger/webhook/request_test.go @@ -1,8 +1,14 @@ -package webhook +package webhook_test -import "testing" +import ( + "testing" + + "github.com/SecurityBrewery/catalyst/reaction/trigger/webhook" +) func Test_isJSON(t *testing.T) { + t.Parallel() + type args struct { data []byte } @@ -29,7 +35,9 @@ func Test_isJSON(t *testing.T) { } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { - if got := isJSON(tt.args.data); got != tt.want { + t.Parallel() + + if got := webhook.IsJSON(tt.args.data); got != tt.want { t.Errorf("isJSON() = %v, want %v", got, tt.want) } }) diff --git a/reaction/trigger/webhook/webhook.go b/reaction/trigger/webhook/webhook.go index b145aca..f92a0ad 100644 --- a/reaction/trigger/webhook/webhook.go +++ b/reaction/trigger/webhook/webhook.go @@ -9,6 +9,7 @@ import ( "github.com/labstack/echo/v5" "github.com/pocketbase/dbx" + "github.com/pocketbase/pocketbase" "github.com/pocketbase/pocketbase/apis" "github.com/pocketbase/pocketbase/core" "github.com/pocketbase/pocketbase/daos" @@ -26,22 +27,22 @@ type Webhook struct { const prefix = "/reaction/" -func BindHooks(app core.App) { - app.OnBeforeServe().Add(func(e *core.ServeEvent) error { - e.Router.Any(prefix+"*", handle(e.App.Dao())) +func BindHooks(pb *pocketbase.PocketBase) { + pb.OnBeforeServe().Add(func(e *core.ServeEvent) error { + e.Router.Any(prefix+"*", handle(e.App)) return nil }) } -func handle(dao *daos.Dao) func(c echo.Context) error { +func handle(app core.App) func(c echo.Context) error { return func(c echo.Context) error { - record, payload, apiErr := parseRequest(dao, c.Request()) + record, payload, apiErr := parseRequest(app.Dao(), c.Request()) if apiErr != nil { return apiErr } - output, err := action.Run(c.Request().Context(), record.GetString("action"), record.GetString("actiondata"), string(payload)) + output, err := action.Run(c.Request().Context(), app, record.GetString("action"), record.GetString("actiondata"), string(payload)) if err != nil { return apis.NewApiError(http.StatusInternalServerError, err.Error(), nil) } @@ -138,7 +139,7 @@ func writeOutput(c echo.Context, output []byte) error { } } - if isJSON(output) { + if IsJSON(output) { return c.JSON(http.StatusOK, json.RawMessage(output)) } diff --git a/testing/collection_reaction_test.go b/testing/collection_reaction_test.go index 995401c..4a950fe 100644 --- a/testing/collection_reaction_test.go +++ b/testing/collection_reaction_test.go @@ -6,18 +6,16 @@ import ( ) func TestReactionsCollection(t *testing.T) { - baseApp, adminToken, analystToken, baseAppCleanup := BaseApp(t) - defer baseAppCleanup() + t.Parallel() - testSets := []authMatrixText{ + testSets := []catalystTest{ { baseTest: BaseTest{ - Name: "ListReactions", - Method: http.MethodGet, - URL: "/api/collections/reactions/records", - TestAppFactory: AppFactory(baseApp), + Name: "ListReactions", + Method: http.MethodGet, + URL: "/api/collections/reactions/records", }, - authBasedExpectations: []AuthBasedExpectation{ + userTests: []UserTest{ { Name: "Unauthorized", ExpectedStatus: http.StatusOK, @@ -29,7 +27,7 @@ func TestReactionsCollection(t *testing.T) { }, { Name: "Analyst", - RequestHeaders: map[string]string{"Authorization": analystToken}, + AuthRecord: analystEmail, ExpectedStatus: http.StatusOK, ExpectedContent: []string{ `"totalItems":3`, @@ -42,7 +40,7 @@ func TestReactionsCollection(t *testing.T) { }, { Name: "Admin", - RequestHeaders: map[string]string{"Authorization": adminToken}, + Admin: adminEmail, ExpectedStatus: http.StatusOK, ExpectedContent: []string{ `"totalItems":3`, @@ -68,9 +66,8 @@ func TestReactionsCollection(t *testing.T) { "action": "python", "actiondata": map[string]any{"script": "print('Hello, World!')"}, }), - TestAppFactory: AppFactory(baseApp), }, - authBasedExpectations: []AuthBasedExpectation{ + userTests: []UserTest{ { Name: "Unauthorized", ExpectedStatus: http.StatusBadRequest, @@ -80,7 +77,7 @@ func TestReactionsCollection(t *testing.T) { }, { Name: "Analyst", - RequestHeaders: map[string]string{"Authorization": analystToken}, + AuthRecord: analystEmail, ExpectedStatus: http.StatusOK, ExpectedContent: []string{ `"name":"test"`, @@ -97,7 +94,7 @@ func TestReactionsCollection(t *testing.T) { }, { Name: "Admin", - RequestHeaders: map[string]string{"Authorization": adminToken}, + Admin: adminEmail, ExpectedStatus: http.StatusOK, ExpectedContent: []string{ `"name":"test"`, @@ -120,9 +117,8 @@ func TestReactionsCollection(t *testing.T) { Method: http.MethodGet, RequestHeaders: map[string]string{"Content-Type": "application/json"}, URL: "/api/collections/reactions/records/r_reaction", - TestAppFactory: AppFactory(baseApp), }, - authBasedExpectations: []AuthBasedExpectation{ + userTests: []UserTest{ { Name: "Unauthorized", ExpectedStatus: http.StatusNotFound, @@ -132,7 +128,7 @@ func TestReactionsCollection(t *testing.T) { }, { Name: "Analyst", - RequestHeaders: map[string]string{"Authorization": analystToken}, + AuthRecord: analystEmail, ExpectedStatus: http.StatusOK, ExpectedContent: []string{ `"id":"r_reaction"`, @@ -141,7 +137,7 @@ func TestReactionsCollection(t *testing.T) { }, { Name: "Admin", - RequestHeaders: map[string]string{"Authorization": adminToken}, + Admin: adminEmail, ExpectedStatus: http.StatusOK, ExpectedContent: []string{ `"id":"r_reaction"`, @@ -157,9 +153,8 @@ func TestReactionsCollection(t *testing.T) { RequestHeaders: map[string]string{"Content-Type": "application/json"}, URL: "/api/collections/reactions/records/r_reaction", Body: s(map[string]any{"name": "update"}), - TestAppFactory: AppFactory(baseApp), }, - authBasedExpectations: []AuthBasedExpectation{ + userTests: []UserTest{ { Name: "Unauthorized", ExpectedStatus: http.StatusNotFound, @@ -169,7 +164,7 @@ func TestReactionsCollection(t *testing.T) { }, { Name: "Analyst", - RequestHeaders: map[string]string{"Authorization": analystToken}, + AuthRecord: analystEmail, ExpectedStatus: http.StatusOK, ExpectedContent: []string{ `"id":"r_reaction"`, @@ -184,7 +179,7 @@ func TestReactionsCollection(t *testing.T) { }, { Name: "Admin", - RequestHeaders: map[string]string{"Authorization": adminToken}, + Admin: adminEmail, ExpectedStatus: http.StatusOK, ExpectedContent: []string{ `"id":"r_reaction"`, @@ -201,12 +196,11 @@ func TestReactionsCollection(t *testing.T) { }, { baseTest: BaseTest{ - Name: "DeleteReaction", - Method: http.MethodDelete, - URL: "/api/collections/reactions/records/r_reaction", - TestAppFactory: AppFactory(baseApp), + Name: "DeleteReaction", + Method: http.MethodDelete, + URL: "/api/collections/reactions/records/r_reaction", }, - authBasedExpectations: []AuthBasedExpectation{ + userTests: []UserTest{ { Name: "Unauthorized", ExpectedStatus: http.StatusNotFound, @@ -216,7 +210,7 @@ func TestReactionsCollection(t *testing.T) { }, { Name: "Analyst", - RequestHeaders: map[string]string{"Authorization": analystToken}, + AuthRecord: analystEmail, ExpectedStatus: http.StatusNoContent, ExpectedEvents: map[string]int{ "OnModelAfterDelete": 1, @@ -227,7 +221,7 @@ func TestReactionsCollection(t *testing.T) { }, { Name: "Admin", - RequestHeaders: map[string]string{"Authorization": adminToken}, + Admin: adminEmail, ExpectedStatus: http.StatusNoContent, ExpectedEvents: map[string]int{ "OnModelAfterDelete": 1, @@ -241,9 +235,14 @@ func TestReactionsCollection(t *testing.T) { } for _, testSet := range testSets { t.Run(testSet.baseTest.Name, func(t *testing.T) { - for _, authBasedExpectation := range testSet.authBasedExpectations { - scenario := mergeScenario(testSet.baseTest, authBasedExpectation) - scenario.Test(t) + t.Parallel() + + for _, userTest := range testSet.userTests { + t.Run(userTest.Name, func(t *testing.T) { + t.Parallel() + + runMatrixTest(t, testSet.baseTest, userTest) + }) } }) } diff --git a/testing/counter.go b/testing/counter.go new file mode 100644 index 0000000..c9c633a --- /dev/null +++ b/testing/counter.go @@ -0,0 +1,36 @@ +package testing + +import "sync" + +type Counter struct { + mux sync.Mutex + counts map[string]int +} + +func NewCounter() *Counter { + return &Counter{ + counts: make(map[string]int), + } +} + +func (c *Counter) Increment(name string) { + c.mux.Lock() + defer c.mux.Unlock() + + if _, ok := c.counts[name]; !ok { + c.counts[name] = 0 + } + + c.counts[name]++ +} + +func (c *Counter) Count(name string) int { + c.mux.Lock() + defer c.mux.Unlock() + + if _, ok := c.counts[name]; !ok { + return 0 + } + + return c.counts[name] +} diff --git a/testing/counter_test.go b/testing/counter_test.go new file mode 100644 index 0000000..3663331 --- /dev/null +++ b/testing/counter_test.go @@ -0,0 +1,41 @@ +package testing + +import ( + "testing" + + "github.com/stretchr/testify/assert" +) + +func TestCounter(t *testing.T) { + t.Parallel() + + type args struct { + name string + repeat int + } + + tests := []struct { + name string + args args + want int + }{ + { + name: "Test Counter", + args: args{name: "test", repeat: 5}, + want: 5, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + t.Parallel() + + c := NewCounter() + + for range tt.args.repeat { + c.Increment(tt.args.name) + } + + assert.Equal(t, tt.want, c.Count(tt.args.name)) + }) + } +} diff --git a/testing/http.go b/testing/http.go new file mode 100644 index 0000000..faeced0 --- /dev/null +++ b/testing/http.go @@ -0,0 +1,37 @@ +package testing + +import ( + "context" + "errors" + "net/http" + "time" +) + +func WaitForStatus(url string, status int, timeout time.Duration) error { + ctx, cancel := context.WithTimeout(context.Background(), timeout) + defer cancel() + + start := time.Now() + + for { + if time.Since(start) > timeout { + return errors.New("timeout") + } + + req, err := http.NewRequestWithContext(ctx, http.MethodGet, url, nil) + if err != nil { + return err + } + + resp, err := http.DefaultClient.Do(req) + if err == nil && resp.StatusCode == status { + resp.Body.Close() + + break + } + + time.Sleep(100 * time.Millisecond) + } + + return nil +} diff --git a/testing/reaction_test.go b/testing/reaction_test.go index a1f0b85..ec4e0f6 100644 --- a/testing/reaction_test.go +++ b/testing/reaction_test.go @@ -3,93 +3,80 @@ package testing import ( "net/http" "testing" + "time" "github.com/stretchr/testify/require" ) func TestWebhookReactions(t *testing.T) { - baseApp, adminToken, analystToken, baseAppCleanup := BaseApp(t) - defer baseAppCleanup() + t.Parallel() server := NewRecordingServer() go http.ListenAndServe("127.0.0.1:12345", server) //nolint:gosec,errcheck - testSets := []authMatrixText{ + if err := WaitForStatus("http://127.0.0.1:12345/health", http.StatusOK, 5*time.Second); err != nil { + t.Fatal(err) + } + + testSets := []catalystTest{ { baseTest: BaseTest{ Name: "TriggerWebhookReaction", Method: http.MethodGet, + RequestHeaders: map[string]string{"Authorization": "Bearer 1234567890"}, URL: "/reaction/test", - TestAppFactory: AppFactory(baseApp), }, - authBasedExpectations: []AuthBasedExpectation{ + userTests: []UserTest{ { Name: "Unauthorized", ExpectedStatus: http.StatusOK, ExpectedContent: []string{`Hello, World!`}, }, - { - Name: "Analyst", - RequestHeaders: map[string]string{"Authorization": analystToken}, - ExpectedStatus: http.StatusOK, - ExpectedContent: []string{`Hello, World!`}, - }, - { - Name: "Admin", - RequestHeaders: map[string]string{"Authorization": adminToken}, - ExpectedStatus: http.StatusOK, - ExpectedContent: []string{`Hello, World!`}, - }, }, }, { baseTest: BaseTest{ - Name: "TriggerWebhookReaction2", - Method: http.MethodGet, - URL: "/reaction/test2", - TestAppFactory: AppFactory(baseApp), + Name: "TriggerWebhookReaction2", + Method: http.MethodGet, + URL: "/reaction/test2", }, - authBasedExpectations: []AuthBasedExpectation{ + userTests: []UserTest{ { Name: "Unauthorized", ExpectedStatus: http.StatusOK, ExpectedContent: []string{`"test":true`}, }, - { - Name: "Analyst", - RequestHeaders: map[string]string{"Authorization": analystToken}, - ExpectedStatus: http.StatusOK, - ExpectedContent: []string{`"test":true`}, - }, - { - Name: "Admin", - RequestHeaders: map[string]string{"Authorization": adminToken}, - ExpectedStatus: http.StatusOK, - ExpectedContent: []string{`"test":true`}, - }, }, }, } for _, testSet := range testSets { t.Run(testSet.baseTest.Name, func(t *testing.T) { - for _, authBasedExpectation := range testSet.authBasedExpectations { - scenario := mergeScenario(testSet.baseTest, authBasedExpectation) - scenario.Test(t) + t.Parallel() + + for _, userTest := range testSet.userTests { + t.Run(userTest.Name, func(t *testing.T) { + t.Parallel() + + runMatrixTest(t, testSet.baseTest, userTest) + }) } }) } } func TestHookReactions(t *testing.T) { - baseApp, _, analystToken, baseAppCleanup := BaseApp(t) - defer baseAppCleanup() + t.Parallel() server := NewRecordingServer() go http.ListenAndServe("127.0.0.1:12346", server) //nolint:gosec,errcheck - testSets := []authMatrixText{ + if err := WaitForStatus("http://127.0.0.1:12346/health", http.StatusOK, 5*time.Second); err != nil { + t.Fatal(err) + } + + testSets := []catalystTest{ { baseTest: BaseTest{ Name: "TriggerHookReaction", @@ -99,9 +86,8 @@ func TestHookReactions(t *testing.T) { Body: s(map[string]any{ "name": "test", }), - TestAppFactory: AppFactory(baseApp), }, - authBasedExpectations: []AuthBasedExpectation{ + userTests: []UserTest{ // { // Name: "Unauthorized", // ExpectedStatus: http.StatusOK, @@ -109,7 +95,7 @@ func TestHookReactions(t *testing.T) { // }, { Name: "Analyst", - RequestHeaders: map[string]string{"Authorization": analystToken}, + AuthRecord: analystEmail, ExpectedStatus: http.StatusOK, ExpectedContent: []string{ `"collectionName":"tickets"`, @@ -133,12 +119,17 @@ func TestHookReactions(t *testing.T) { } for _, testSet := range testSets { t.Run(testSet.baseTest.Name, func(t *testing.T) { - for _, authBasedExpectation := range testSet.authBasedExpectations { - scenario := mergeScenario(testSet.baseTest, authBasedExpectation) - scenario.Test(t) - } + t.Parallel() - require.NotEmpty(t, server.Entries) + for _, userTest := range testSet.userTests { + t.Run(userTest.Name, func(t *testing.T) { + t.Parallel() + + runMatrixTest(t, testSet.baseTest, userTest) + + require.NotEmpty(t, server.Entries) + }) + } }) } } diff --git a/testing/recordingserver.go b/testing/recordingserver.go index b388a01..a15b3f6 100644 --- a/testing/recordingserver.go +++ b/testing/recordingserver.go @@ -1,19 +1,38 @@ package testing -import "net/http" +import ( + "net/http" + + "github.com/labstack/echo/v5" +) type RecordingServer struct { + server *echo.Echo + Entries []string } func NewRecordingServer() *RecordingServer { - return &RecordingServer{} + e := echo.New() + + e.GET("/health", func(c echo.Context) error { + return c.JSON(http.StatusOK, map[string]any{ + "status": "ok", + }) + }) + e.Any("/*", func(c echo.Context) error { + return c.JSON(http.StatusOK, map[string]any{ + "test": true, + }) + }) + + return &RecordingServer{ + server: e, + } } func (s *RecordingServer) ServeHTTP(w http.ResponseWriter, r *http.Request) { s.Entries = append(s.Entries, r.URL.Path) - w.WriteHeader(http.StatusOK) - w.Header().Set("Content-Type", "application/json") - w.Write([]byte(`{"test":true}`)) //nolint:errcheck + s.server.ServeHTTP(w, r) } diff --git a/testing/routes_test.go b/testing/routes_test.go index b0c5ad6..d999d35 100644 --- a/testing/routes_test.go +++ b/testing/routes_test.go @@ -6,63 +6,60 @@ import ( ) func Test_Routes(t *testing.T) { - baseApp, adminToken, analystToken, baseAppCleanup := BaseApp(t) - defer baseAppCleanup() + t.Parallel() - testSets := []authMatrixText{ + testSets := []catalystTest{ { baseTest: BaseTest{ - Name: "Root", - Method: http.MethodGet, - URL: "/", - TestAppFactory: AppFactory(baseApp), + Name: "Root", + Method: http.MethodGet, + URL: "/", }, - authBasedExpectations: []AuthBasedExpectation{ + userTests: []UserTest{ { Name: "Unauthorized", ExpectedStatus: http.StatusFound, }, { Name: "Analyst", - RequestHeaders: map[string]string{"Authorization": analystToken}, + AuthRecord: analystEmail, ExpectedStatus: http.StatusFound, }, { Name: "Admin", - RequestHeaders: map[string]string{"Authorization": adminToken}, + Admin: adminEmail, ExpectedStatus: http.StatusFound, }, }, }, { baseTest: BaseTest{ - Name: "Config", - Method: http.MethodGet, - URL: "/api/config", - TestAppFactory: AppFactory(baseApp), + Name: "Config", + Method: http.MethodGet, + URL: "/api/config", }, - authBasedExpectations: []AuthBasedExpectation{ + userTests: []UserTest{ { Name: "Unauthorized", ExpectedStatus: http.StatusOK, ExpectedContent: []string{ - `"flags":null`, + `"flags":[]`, }, }, { Name: "Analyst", - RequestHeaders: map[string]string{"Authorization": analystToken}, + AuthRecord: analystEmail, ExpectedStatus: http.StatusOK, ExpectedContent: []string{ - `"flags":null`, + `"flags":[]`, }, }, { Name: "Admin", - RequestHeaders: map[string]string{"Authorization": adminToken}, + Admin: adminEmail, ExpectedStatus: http.StatusOK, ExpectedContent: []string{ - `"flags":null`, + `"flags":[]`, }, }, }, @@ -70,9 +67,14 @@ func Test_Routes(t *testing.T) { } for _, testSet := range testSets { t.Run(testSet.baseTest.Name, func(t *testing.T) { - for _, authBasedExpectation := range testSet.authBasedExpectations { - scenario := mergeScenario(testSet.baseTest, authBasedExpectation) - scenario.Test(t) + t.Parallel() + + for _, userTest := range testSet.userTests { + t.Run(userTest.Name, func(t *testing.T) { + t.Parallel() + + runMatrixTest(t, testSet.baseTest, userTest) + }) } }) } diff --git a/testing/testapp.go b/testing/testapp.go new file mode 100644 index 0000000..3a3651a --- /dev/null +++ b/testing/testapp.go @@ -0,0 +1,160 @@ +package testing + +import ( + "fmt" + "os" + "testing" + + "github.com/pocketbase/pocketbase" + "github.com/pocketbase/pocketbase/core" + "github.com/pocketbase/pocketbase/tokens" + + "github.com/SecurityBrewery/catalyst/app" + "github.com/SecurityBrewery/catalyst/migrations" +) + +func App(t *testing.T) (*pocketbase.PocketBase, *Counter, func()) { + t.Helper() + + temp, err := os.MkdirTemp("", "catalyst_test_data") + if err != nil { + t.Fatal(err) + } + + baseApp, err := app.App(temp, true) + if err != nil { + t.Fatal(err) + } + + baseApp.Settings().Logs.MaxDays = 0 + + defaultTestData(t, baseApp) + + counter := countEvents(baseApp) + + return baseApp, counter, func() { _ = os.RemoveAll(temp) } +} + +func generateAdminToken(t *testing.T, baseApp core.App, email string) (string, error) { + t.Helper() + + admin, err := baseApp.Dao().FindAdminByEmail(email) + if err != nil { + return "", fmt.Errorf("failed to find admin: %w", err) + } + + return tokens.NewAdminAuthToken(baseApp, admin) +} + +func generateRecordToken(t *testing.T, baseApp core.App, email string) (string, error) { + t.Helper() + + record, err := baseApp.Dao().FindAuthRecordByEmail(migrations.UserCollectionName, email) + if err != nil { + return "", fmt.Errorf("failed to find record: %w", err) + } + + return tokens.NewRecordAuthToken(baseApp, record) +} + +func countEvents(t *pocketbase.PocketBase) *Counter { + c := NewCounter() + + t.OnBeforeApiError().Add(count[*core.ApiErrorEvent](c, "OnBeforeApiError")) + t.OnBeforeApiError().Add(count[*core.ApiErrorEvent](c, "OnBeforeApiError")) + t.OnAfterApiError().Add(count[*core.ApiErrorEvent](c, "OnAfterApiError")) + t.OnModelBeforeCreate().Add(count[*core.ModelEvent](c, "OnModelBeforeCreate")) + t.OnModelAfterCreate().Add(count[*core.ModelEvent](c, "OnModelAfterCreate")) + t.OnModelBeforeUpdate().Add(count[*core.ModelEvent](c, "OnModelBeforeUpdate")) + t.OnModelAfterUpdate().Add(count[*core.ModelEvent](c, "OnModelAfterUpdate")) + t.OnModelBeforeDelete().Add(count[*core.ModelEvent](c, "OnModelBeforeDelete")) + t.OnModelAfterDelete().Add(count[*core.ModelEvent](c, "OnModelAfterDelete")) + t.OnRecordsListRequest().Add(count[*core.RecordsListEvent](c, "OnRecordsListRequest")) + t.OnRecordViewRequest().Add(count[*core.RecordViewEvent](c, "OnRecordViewRequest")) + t.OnRecordBeforeCreateRequest().Add(count[*core.RecordCreateEvent](c, "OnRecordBeforeCreateRequest")) + t.OnRecordAfterCreateRequest().Add(count[*core.RecordCreateEvent](c, "OnRecordAfterCreateRequest")) + t.OnRecordBeforeUpdateRequest().Add(count[*core.RecordUpdateEvent](c, "OnRecordBeforeUpdateRequest")) + t.OnRecordAfterUpdateRequest().Add(count[*core.RecordUpdateEvent](c, "OnRecordAfterUpdateRequest")) + t.OnRecordBeforeDeleteRequest().Add(count[*core.RecordDeleteEvent](c, "OnRecordBeforeDeleteRequest")) + t.OnRecordAfterDeleteRequest().Add(count[*core.RecordDeleteEvent](c, "OnRecordAfterDeleteRequest")) + t.OnRecordAuthRequest().Add(count[*core.RecordAuthEvent](c, "OnRecordAuthRequest")) + t.OnRecordBeforeAuthWithPasswordRequest().Add(count[*core.RecordAuthWithPasswordEvent](c, "OnRecordBeforeAuthWithPasswordRequest")) + t.OnRecordAfterAuthWithPasswordRequest().Add(count[*core.RecordAuthWithPasswordEvent](c, "OnRecordAfterAuthWithPasswordRequest")) + t.OnRecordBeforeAuthWithOAuth2Request().Add(count[*core.RecordAuthWithOAuth2Event](c, "OnRecordBeforeAuthWithOAuth2Request")) + t.OnRecordAfterAuthWithOAuth2Request().Add(count[*core.RecordAuthWithOAuth2Event](c, "OnRecordAfterAuthWithOAuth2Request")) + t.OnRecordBeforeAuthRefreshRequest().Add(count[*core.RecordAuthRefreshEvent](c, "OnRecordBeforeAuthRefreshRequest")) + t.OnRecordAfterAuthRefreshRequest().Add(count[*core.RecordAuthRefreshEvent](c, "OnRecordAfterAuthRefreshRequest")) + t.OnRecordBeforeRequestPasswordResetRequest().Add(count[*core.RecordRequestPasswordResetEvent](c, "OnRecordBeforeRequestPasswordResetRequest")) + t.OnRecordAfterRequestPasswordResetRequest().Add(count[*core.RecordRequestPasswordResetEvent](c, "OnRecordAfterRequestPasswordResetRequest")) + t.OnRecordBeforeConfirmPasswordResetRequest().Add(count[*core.RecordConfirmPasswordResetEvent](c, "OnRecordBeforeConfirmPasswordResetRequest")) + t.OnRecordAfterConfirmPasswordResetRequest().Add(count[*core.RecordConfirmPasswordResetEvent](c, "OnRecordAfterConfirmPasswordResetRequest")) + t.OnRecordBeforeRequestVerificationRequest().Add(count[*core.RecordRequestVerificationEvent](c, "OnRecordBeforeRequestVerificationRequest")) + t.OnRecordAfterRequestVerificationRequest().Add(count[*core.RecordRequestVerificationEvent](c, "OnRecordAfterRequestVerificationRequest")) + t.OnRecordBeforeConfirmVerificationRequest().Add(count[*core.RecordConfirmVerificationEvent](c, "OnRecordBeforeConfirmVerificationRequest")) + t.OnRecordAfterConfirmVerificationRequest().Add(count[*core.RecordConfirmVerificationEvent](c, "OnRecordAfterConfirmVerificationRequest")) + t.OnRecordBeforeRequestEmailChangeRequest().Add(count[*core.RecordRequestEmailChangeEvent](c, "OnRecordBeforeRequestEmailChangeRequest")) + t.OnRecordAfterRequestEmailChangeRequest().Add(count[*core.RecordRequestEmailChangeEvent](c, "OnRecordAfterRequestEmailChangeRequest")) + t.OnRecordBeforeConfirmEmailChangeRequest().Add(count[*core.RecordConfirmEmailChangeEvent](c, "OnRecordBeforeConfirmEmailChangeRequest")) + t.OnRecordAfterConfirmEmailChangeRequest().Add(count[*core.RecordConfirmEmailChangeEvent](c, "OnRecordAfterConfirmEmailChangeRequest")) + t.OnRecordListExternalAuthsRequest().Add(count[*core.RecordListExternalAuthsEvent](c, "OnRecordListExternalAuthsRequest")) + t.OnRecordBeforeUnlinkExternalAuthRequest().Add(count[*core.RecordUnlinkExternalAuthEvent](c, "OnRecordBeforeUnlinkExternalAuthRequest")) + t.OnRecordAfterUnlinkExternalAuthRequest().Add(count[*core.RecordUnlinkExternalAuthEvent](c, "OnRecordAfterUnlinkExternalAuthRequest")) + t.OnMailerBeforeAdminResetPasswordSend().Add(count[*core.MailerAdminEvent](c, "OnMailerBeforeAdminResetPasswordSend")) + t.OnMailerAfterAdminResetPasswordSend().Add(count[*core.MailerAdminEvent](c, "OnMailerAfterAdminResetPasswordSend")) + t.OnMailerBeforeRecordResetPasswordSend().Add(count[*core.MailerRecordEvent](c, "OnMailerBeforeRecordResetPasswordSend")) + t.OnMailerAfterRecordResetPasswordSend().Add(count[*core.MailerRecordEvent](c, "OnMailerAfterRecordResetPasswordSend")) + t.OnMailerBeforeRecordVerificationSend().Add(count[*core.MailerRecordEvent](c, "OnMailerBeforeRecordVerificationSend")) + t.OnMailerAfterRecordVerificationSend().Add(count[*core.MailerRecordEvent](c, "OnMailerAfterRecordVerificationSend")) + t.OnMailerBeforeRecordChangeEmailSend().Add(count[*core.MailerRecordEvent](c, "OnMailerBeforeRecordChangeEmailSend")) + t.OnMailerAfterRecordChangeEmailSend().Add(count[*core.MailerRecordEvent](c, "OnMailerAfterRecordChangeEmailSend")) + t.OnRealtimeConnectRequest().Add(count[*core.RealtimeConnectEvent](c, "OnRealtimeConnectRequest")) + t.OnRealtimeDisconnectRequest().Add(count[*core.RealtimeDisconnectEvent](c, "OnRealtimeDisconnectRequest")) + t.OnRealtimeBeforeMessageSend().Add(count[*core.RealtimeMessageEvent](c, "OnRealtimeBeforeMessageSend")) + t.OnRealtimeAfterMessageSend().Add(count[*core.RealtimeMessageEvent](c, "OnRealtimeAfterMessageSend")) + t.OnRealtimeBeforeSubscribeRequest().Add(count[*core.RealtimeSubscribeEvent](c, "OnRealtimeBeforeSubscribeRequest")) + t.OnRealtimeAfterSubscribeRequest().Add(count[*core.RealtimeSubscribeEvent](c, "OnRealtimeAfterSubscribeRequest")) + t.OnSettingsListRequest().Add(count[*core.SettingsListEvent](c, "OnSettingsListRequest")) + t.OnSettingsBeforeUpdateRequest().Add(count[*core.SettingsUpdateEvent](c, "OnSettingsBeforeUpdateRequest")) + t.OnSettingsAfterUpdateRequest().Add(count[*core.SettingsUpdateEvent](c, "OnSettingsAfterUpdateRequest")) + t.OnCollectionsListRequest().Add(count[*core.CollectionsListEvent](c, "OnCollectionsListRequest")) + t.OnCollectionViewRequest().Add(count[*core.CollectionViewEvent](c, "OnCollectionViewRequest")) + t.OnCollectionBeforeCreateRequest().Add(count[*core.CollectionCreateEvent](c, "OnCollectionBeforeCreateRequest")) + t.OnCollectionAfterCreateRequest().Add(count[*core.CollectionCreateEvent](c, "OnCollectionAfterCreateRequest")) + t.OnCollectionBeforeUpdateRequest().Add(count[*core.CollectionUpdateEvent](c, "OnCollectionBeforeUpdateRequest")) + t.OnCollectionAfterUpdateRequest().Add(count[*core.CollectionUpdateEvent](c, "OnCollectionAfterUpdateRequest")) + t.OnCollectionBeforeDeleteRequest().Add(count[*core.CollectionDeleteEvent](c, "OnCollectionBeforeDeleteRequest")) + t.OnCollectionAfterDeleteRequest().Add(count[*core.CollectionDeleteEvent](c, "OnCollectionAfterDeleteRequest")) + t.OnCollectionsBeforeImportRequest().Add(count[*core.CollectionsImportEvent](c, "OnCollectionsBeforeImportRequest")) + t.OnCollectionsAfterImportRequest().Add(count[*core.CollectionsImportEvent](c, "OnCollectionsAfterImportRequest")) + t.OnAdminsListRequest().Add(count[*core.AdminsListEvent](c, "OnAdminsListRequest")) + t.OnAdminViewRequest().Add(count[*core.AdminViewEvent](c, "OnAdminViewRequest")) + t.OnAdminBeforeCreateRequest().Add(count[*core.AdminCreateEvent](c, "OnAdminBeforeCreateRequest")) + t.OnAdminAfterCreateRequest().Add(count[*core.AdminCreateEvent](c, "OnAdminAfterCreateRequest")) + t.OnAdminBeforeUpdateRequest().Add(count[*core.AdminUpdateEvent](c, "OnAdminBeforeUpdateRequest")) + t.OnAdminAfterUpdateRequest().Add(count[*core.AdminUpdateEvent](c, "OnAdminAfterUpdateRequest")) + t.OnAdminBeforeDeleteRequest().Add(count[*core.AdminDeleteEvent](c, "OnAdminBeforeDeleteRequest")) + t.OnAdminAfterDeleteRequest().Add(count[*core.AdminDeleteEvent](c, "OnAdminAfterDeleteRequest")) + t.OnAdminAuthRequest().Add(count[*core.AdminAuthEvent](c, "OnAdminAuthRequest")) + t.OnAdminBeforeAuthWithPasswordRequest().Add(count[*core.AdminAuthWithPasswordEvent](c, "OnAdminBeforeAuthWithPasswordRequest")) + t.OnAdminAfterAuthWithPasswordRequest().Add(count[*core.AdminAuthWithPasswordEvent](c, "OnAdminAfterAuthWithPasswordRequest")) + t.OnAdminBeforeAuthRefreshRequest().Add(count[*core.AdminAuthRefreshEvent](c, "OnAdminBeforeAuthRefreshRequest")) + t.OnAdminAfterAuthRefreshRequest().Add(count[*core.AdminAuthRefreshEvent](c, "OnAdminAfterAuthRefreshRequest")) + t.OnAdminBeforeRequestPasswordResetRequest().Add(count[*core.AdminRequestPasswordResetEvent](c, "OnAdminBeforeRequestPasswordResetRequest")) + t.OnAdminAfterRequestPasswordResetRequest().Add(count[*core.AdminRequestPasswordResetEvent](c, "OnAdminAfterRequestPasswordResetRequest")) + t.OnAdminBeforeConfirmPasswordResetRequest().Add(count[*core.AdminConfirmPasswordResetEvent](c, "OnAdminBeforeConfirmPasswordResetRequest")) + t.OnAdminAfterConfirmPasswordResetRequest().Add(count[*core.AdminConfirmPasswordResetEvent](c, "OnAdminAfterConfirmPasswordResetRequest")) + t.OnFileDownloadRequest().Add(count[*core.FileDownloadEvent](c, "OnFileDownloadRequest")) + t.OnFileBeforeTokenRequest().Add(count[*core.FileTokenEvent](c, "OnFileBeforeTokenRequest")) + t.OnFileAfterTokenRequest().Add(count[*core.FileTokenEvent](c, "OnFileAfterTokenRequest")) + t.OnFileAfterTokenRequest().Add(count[*core.FileTokenEvent](c, "OnFileAfterTokenRequest")) + + return c +} + +func count[T any](c *Counter, name string) func(_ T) error { + return func(_ T) error { + c.Increment(name) + + return nil + } +} diff --git a/testing/testdata.go b/testing/testdata.go index 55431f5..57032ec 100644 --- a/testing/testdata.go +++ b/testing/testdata.go @@ -19,6 +19,7 @@ func defaultTestData(t *testing.T, app core.App) { adminTestData(t, app) userTestData(t, app) + ticketTestData(t, app) reactionTestData(t, app) } @@ -57,6 +58,30 @@ func userTestData(t *testing.T, app core.App) { } } +func ticketTestData(t *testing.T, app core.App) { + t.Helper() + + collection, err := app.Dao().FindCollectionByNameOrId(migrations.TicketCollectionName) + if err != nil { + t.Fatal(err) + } + + record := models.NewRecord(collection) + record.SetId("t_test") + + record.Set("name", "Test Ticket") + record.Set("type", "incident") + record.Set("description", "This is a test ticket.") + record.Set("open", true) + record.Set("schema", `{"type":"object","properties":{"tlp":{"title":"TLP","type":"string"}}}`) + record.Set("state", `{"tlp":"AMBER"}`) + record.Set("owner", "u_bob_analyst") + + if err := app.Dao().SaveRecord(record); err != nil { + t.Fatal(err) + } +} + func reactionTestData(t *testing.T, app core.App) { t.Helper() @@ -69,9 +94,9 @@ func reactionTestData(t *testing.T, app core.App) { record.SetId("r_reaction") record.Set("name", "Reaction") record.Set("trigger", "webhook") - record.Set("triggerdata", `{"path":"test"}`) + record.Set("triggerdata", `{"token":"1234567890","path":"test"}`) record.Set("action", "python") - record.Set("actiondata", `{"bootstrap":"requests","script":"print('Hello, World!')"}`) + record.Set("actiondata", `{"requirements":"requests","script":"print('Hello, World!')"}`) if err := app.Dao().SaveRecord(record); err != nil { t.Fatal(err) @@ -95,7 +120,7 @@ func reactionTestData(t *testing.T, app core.App) { record.Set("trigger", "hook") record.Set("triggerdata", `{"collections":["tickets"],"events":["create"]}`) record.Set("action", "python") - record.Set("actiondata", `{"bootstrap":"requests","script":"import requests\nrequests.post('http://127.0.0.1:12346/test', json={'test':True})"}`) + record.Set("actiondata", `{"requirements":"requests","script":"import requests\nrequests.post('http://127.0.0.1:12346/test', json={'test':True})"}`) if err := app.Dao().SaveRecord(record); err != nil { t.Fatal(err) diff --git a/testing/testing.go b/testing/testing.go index 2264623..5b0e9e6 100644 --- a/testing/testing.go +++ b/testing/testing.go @@ -3,162 +3,95 @@ package testing import ( "bytes" "encoding/json" - "os" + "fmt" + "net/http/httptest" "testing" + "github.com/pocketbase/pocketbase/apis" "github.com/pocketbase/pocketbase/core" - "github.com/pocketbase/pocketbase/tests" - "github.com/pocketbase/pocketbase/tokens" - - "github.com/SecurityBrewery/catalyst/app" - "github.com/SecurityBrewery/catalyst/migrations" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" ) -func BaseApp(t *testing.T) (core.App, string, string, func()) { - t.Helper() - - temp, err := os.MkdirTemp("", "catalyst_test_data") - if err != nil { - t.Fatal(err) - } - - baseApp := app.App(temp) - - if err := app.Bootstrap(baseApp); err != nil { - t.Fatal(err) - } - - defaultTestData(t, baseApp) - - adminToken, err := generateAdminToken(t, baseApp, adminEmail) - if err != nil { - t.Fatal(err) - } - - analystToken, err := generateRecordToken(t, baseApp, analystEmail) - if err != nil { - t.Fatal(err) - } - - return baseApp, adminToken, analystToken, func() { _ = os.RemoveAll(temp) } -} - -func AppFactory(baseApp core.App) func(t *testing.T) *tests.TestApp { - return func(t *testing.T) *tests.TestApp { - t.Helper() - - testApp, err := tests.NewTestApp(baseApp.DataDir()) - if err != nil { - t.Fatal(err) - } - - app.BindHooks(testApp) - - if err := app.Bootstrap(testApp); err != nil { - t.Fatal(err) - } - - return testApp - } -} - -func App(t *testing.T) (*tests.TestApp, func()) { - t.Helper() - - baseApp, _, _, cleanup := BaseApp(t) - - testApp := AppFactory(baseApp)(t) - - return testApp, cleanup -} - -func generateAdminToken(t *testing.T, baseApp core.App, email string) (string, error) { - t.Helper() - - app, err := tests.NewTestApp(baseApp.DataDir()) - if err != nil { - return "", err - } - defer app.Cleanup() - - admin, err := app.Dao().FindAdminByEmail(email) - if err != nil { - return "", err - } - - return tokens.NewAdminAuthToken(app, admin) -} - -func generateRecordToken(t *testing.T, baseApp core.App, email string) (string, error) { - t.Helper() - - app, err := tests.NewTestApp(baseApp.DataDir()) - if err != nil { - t.Fatal(err) - } - defer app.Cleanup() - - record, err := app.Dao().FindAuthRecordByEmail(migrations.UserCollectionName, email) - if err != nil { - return "", err - } - - return tokens.NewRecordAuthToken(app, record) -} - type BaseTest struct { Name string Method string RequestHeaders map[string]string URL string Body string - TestAppFactory func(t *testing.T) *tests.TestApp } -type AuthBasedExpectation struct { +type UserTest struct { Name string - RequestHeaders map[string]string + AuthRecord string + Admin string ExpectedStatus int ExpectedContent []string NotExpectedContent []string ExpectedEvents map[string]int } -type authMatrixText struct { - baseTest BaseTest - authBasedExpectations []AuthBasedExpectation +type catalystTest struct { + baseTest BaseTest + userTests []UserTest } -func mergeScenario(base BaseTest, expectation AuthBasedExpectation) tests.ApiScenario { - return tests.ApiScenario{ - Name: expectation.Name, - Method: base.Method, - Url: base.URL, - Body: bytes.NewBufferString(base.Body), - TestAppFactory: base.TestAppFactory, +func runMatrixTest(t *testing.T, baseTest BaseTest, userTest UserTest) { + t.Helper() - RequestHeaders: mergeMaps(base.RequestHeaders, expectation.RequestHeaders), - ExpectedStatus: expectation.ExpectedStatus, - ExpectedContent: expectation.ExpectedContent, - NotExpectedContent: expectation.NotExpectedContent, - ExpectedEvents: expectation.ExpectedEvents, - } -} + baseApp, counter, baseAppCleanup := App(t) + defer baseAppCleanup() -func mergeMaps(a, b map[string]string) map[string]string { - if a == nil { - return b + server, err := apis.InitApi(baseApp) + require.NoError(t, err) + + if err := baseApp.OnBeforeServe().Trigger(&core.ServeEvent{ + App: baseApp, + Router: server, + }); err != nil { + t.Fatal(fmt.Errorf("failed to trigger OnBeforeServe: %w", err)) } - if b == nil { - return a + recorder := httptest.NewRecorder() + body := bytes.NewBufferString(baseTest.Body) + req := httptest.NewRequest(baseTest.Method, baseTest.URL, body) + + for k, v := range baseTest.RequestHeaders { + req.Header.Set(k, v) } - for k, v := range b { - a[k] = v + if userTest.AuthRecord != "" { + token, err := generateRecordToken(t, baseApp, userTest.AuthRecord) + require.NoError(t, err) + + req.Header.Set("Authorization", token) } - return a + if userTest.Admin != "" { + token, err := generateAdminToken(t, baseApp, userTest.Admin) + require.NoError(t, err) + + req.Header.Set("Authorization", token) + } + + server.ServeHTTP(recorder, req) + + res := recorder.Result() + defer res.Body.Close() + + assert.Equal(t, userTest.ExpectedStatus, res.StatusCode) + + for _, expectedContent := range userTest.ExpectedContent { + assert.Contains(t, recorder.Body.String(), expectedContent) + } + + for _, notExpectedContent := range userTest.NotExpectedContent { + assert.NotContains(t, recorder.Body.String(), notExpectedContent) + } + + for event, count := range userTest.ExpectedEvents { + assert.Equal(t, count, counter.Count(event)) + } } func b(data map[string]any) []byte { diff --git a/ui/src/components/reaction/ReactionDisplay.vue b/ui/src/components/reaction/ReactionDisplay.vue index 869515d..3e05bbc 100644 --- a/ui/src/components/reaction/ReactionDisplay.vue +++ b/ui/src/components/reaction/ReactionDisplay.vue @@ -4,13 +4,17 @@ import DeleteDialog from '@/components/common/DeleteDialog.vue' import ReactionForm from '@/components/reaction/ReactionForm.vue' import { ScrollArea } from '@/components/ui/scroll-area' import { Separator } from '@/components/ui/separator' +import { toast } from '@/components/ui/toast' import { useMutation, useQuery, useQueryClient } from '@tanstack/vue-query' +import { onMounted, onUnmounted } from 'vue' +import { useRouter } from 'vue-router' import { pb } from '@/lib/pocketbase' import type { Reaction } from '@/lib/types' import { handleError } from '@/lib/utils' +const router = useRouter() const queryClient = useQueryClient() const props = defineProps<{ @@ -32,6 +36,35 @@ const updateReactionMutation = useMutation({ onSuccess: () => queryClient.invalidateQueries({ queryKey: ['reactions'] }), onError: handleError }) + +onMounted(() => { + pb.collection('reactions').subscribe(props.id, (data) => { + if (data.action === 'delete') { + toast({ + title: 'Reaction deleted', + description: 'The reaction has been deleted.', + variant: 'destructive' + }) + + router.push({ name: 'reactions' }) + + return + } + + if (data.action === 'update') { + toast({ + title: 'Reaction updated', + description: 'The reaction has been updated.' + }) + + queryClient.invalidateQueries({ queryKey: ['reactions', props.id] }) + } + }) +}) + +onUnmounted(() => { + pb.collection('reactions').unsubscribe(props.id) +}) diff --git a/ui/src/components/reaction/ReactionList.vue b/ui/src/components/reaction/ReactionList.vue index b1f3ee5..044aa73 100644 --- a/ui/src/components/reaction/ReactionList.vue +++ b/ui/src/components/reaction/ReactionList.vue @@ -1,16 +1,19 @@