feat: improve python actions (#1083)

This commit is contained in:
Jonas Plum
2024-07-21 02:56:43 +02:00
committed by GitHub
parent 81bfbb2072
commit 91429effe2
55 changed files with 1143 additions and 585 deletions

View File

@@ -14,7 +14,11 @@ jobs:
with: { go-version: '1.22' } with: { go-version: '1.22' }
- uses: oven-sh/setup-bun@v1 - uses: oven-sh/setup-bun@v1
- run: make build-ui - run: |
bun install
mkdir -p dist
touch dist/index.html
working-directory: ui
- run: make install - run: make install
- run: make fmt - run: make fmt
@@ -28,13 +32,25 @@ jobs:
- uses: actions/checkout@v4 - uses: actions/checkout@v4
- uses: actions/setup-go@v5 - uses: actions/setup-go@v5
with: { go-version: '1.22' } with: { go-version: '1.22' }
- uses: oven-sh/setup-bun@v1
- run: make build-ui - run: |
mkdir -p ui/dist
touch ui/dist/index.html
- uses: golangci/golangci-lint-action@v6 - uses: golangci/golangci-lint-action@v6
with: { version: 'v1.59' } with: { version: 'v1.59' }
build:
name: Build
runs-on: ubuntu-latest
steps:
- uses: actions/checkout@v4
- uses: actions/setup-go@v5
with: { go-version: '1.22' }
- uses: oven-sh/setup-bun@v1
- run: make build-ui
test: test:
name: Test name: Test
runs-on: ubuntu-latest runs-on: ubuntu-latest
@@ -44,9 +60,9 @@ jobs:
with: { go-version: '1.22' } with: { go-version: '1.22' }
- uses: oven-sh/setup-bun@v1 - uses: oven-sh/setup-bun@v1
- run: make build-ui - run: |
mkdir -p ui/dist
- run: make test touch ui/dist/index.html
- run: make test-coverage - run: make test-coverage

View File

@@ -5,39 +5,20 @@ linters:
enable-all: true enable-all: true
disable: disable:
# complexity # complexity
- cyclop
- gocognit
- gocyclo
- maintidx - maintidx
- nestif - funlen
# disable # disable
- bodyclose
- depguard - depguard
- dupl
- err113 - err113
- execinquery
- exhaustruct - exhaustruct
- funlen
- gochecknoglobals
- gochecknoinits
- goconst
- godox
- gomnd - gomnd
- gomoddirectives
- ireturn - ireturn
- lll - lll
- makezero
- mnd - mnd
- paralleltest
- perfsprint
- prealloc
- tagalign
- tagliatelle
- testpackage - testpackage
- varnamelen - varnamelen
- wrapcheck - wrapcheck
- wsl
linters-settings: linters-settings:
gci: gci:
sections: sections:

View File

@@ -43,7 +43,6 @@ build-ui:
dev: dev:
@echo "Running..." @echo "Running..."
rm -rf catalyst_data rm -rf catalyst_data
go run . bootstrap
go run . admin create admin@catalyst-soar.com 1234567890 go run . admin create admin@catalyst-soar.com 1234567890
go run . set-feature-flags dev go run . set-feature-flags dev
go run . fake-data go run . fake-data

View File

@@ -5,38 +5,40 @@ import (
"strings" "strings"
"github.com/pocketbase/pocketbase" "github.com/pocketbase/pocketbase"
"github.com/pocketbase/pocketbase/core"
"github.com/SecurityBrewery/catalyst/migrations" "github.com/SecurityBrewery/catalyst/migrations"
"github.com/SecurityBrewery/catalyst/reaction" "github.com/SecurityBrewery/catalyst/reaction"
"github.com/SecurityBrewery/catalyst/webhook" "github.com/SecurityBrewery/catalyst/webhook"
) )
func init() { func init() { //nolint:gochecknoinits
migrations.Register() migrations.Register()
} }
func App(dir string) *pocketbase.PocketBase { func App(dir string, test bool) (*pocketbase.PocketBase, error) {
app := pocketbase.NewWithConfig(pocketbase.Config{ app := pocketbase.NewWithConfig(pocketbase.Config{
DefaultDev: dev(), DefaultDev: test || dev(),
DefaultDataDir: dir, DefaultDataDir: dir,
}) })
BindHooks(app) webhook.BindHooks(app)
reaction.BindHooks(app, test)
app.OnBeforeServe().Add(addRoutes())
// Register additional commands // Register additional commands
app.RootCmd.AddCommand(bootstrapCmd(app))
app.RootCmd.AddCommand(fakeDataCmd(app)) app.RootCmd.AddCommand(fakeDataCmd(app))
app.RootCmd.AddCommand(setFeatureFlagsCmd(app)) app.RootCmd.AddCommand(setFeatureFlagsCmd(app))
return app if err := app.Bootstrap(); err != nil {
return nil, err
} }
func BindHooks(app core.App) { if err := MigrateDBs(app); err != nil {
webhook.BindHooks(app) return nil, err
reaction.BindHooks(app) }
app.OnBeforeServe().Add(addRoutes()) return app, nil
} }
func dev() bool { func dev() bool {

View File

@@ -1,25 +0,0 @@
package app
import (
"github.com/pocketbase/pocketbase/core"
"github.com/spf13/cobra"
)
func Bootstrap(app core.App) error {
if err := app.Bootstrap(); err != nil {
return err
}
return MigrateDBs(app)
}
func bootstrapCmd(app core.App) *cobra.Command {
return &cobra.Command{
Use: "bootstrap",
Run: func(_ *cobra.Command, _ []string) {
if err := Bootstrap(app); err != nil {
app.Logger().Error(err.Error())
}
},
}
}

View File

@@ -16,7 +16,7 @@ func Flags(app core.App) ([]string, error) {
return nil, err return nil, err
} }
var flags []string flags := make([]string, 0, len(records))
for _, r := range records { for _, r := range records {
flags = append(flags, r.GetString("name")) flags = append(flags, r.GetString("name"))
@@ -36,7 +36,7 @@ func SetFlags(app core.App, args []string) error {
return err return err
} }
var existingFlags []string var existingFlags []string //nolint:prealloc
for _, featureRecord := range featureRecords { for _, featureRecord := range featureRecords {
// remove feature flags that are not in the args // remove feature flags that are not in the args

View File

@@ -11,7 +11,9 @@ import (
) )
func Test_flags(t *testing.T) { func Test_flags(t *testing.T) {
catalystApp, cleanup := catalystTesting.App(t) t.Parallel()
catalystApp, _, cleanup := catalystTesting.App(t)
defer cleanup() defer cleanup()
got, err := app.Flags(catalystApp) got, err := app.Flags(catalystApp)
@@ -22,9 +24,12 @@ func Test_flags(t *testing.T) {
} }
func Test_setFlags(t *testing.T) { func Test_setFlags(t *testing.T) {
catalystApp, cleanup := catalystTesting.App(t) t.Parallel()
catalystApp, _, cleanup := catalystTesting.App(t)
defer cleanup() defer cleanup()
// stage 1
require.NoError(t, app.SetFlags(catalystApp, []string{"test"})) require.NoError(t, app.SetFlags(catalystApp, []string{"test"}))
got, err := app.Flags(catalystApp) got, err := app.Flags(catalystApp)
@@ -32,10 +37,19 @@ func Test_setFlags(t *testing.T) {
assert.ElementsMatch(t, []string{"test"}, got) assert.ElementsMatch(t, []string{"test"}, got)
// stage 2
require.NoError(t, app.SetFlags(catalystApp, []string{"test2"})) require.NoError(t, app.SetFlags(catalystApp, []string{"test2"}))
got, err = app.Flags(catalystApp) got, err = app.Flags(catalystApp)
require.NoError(t, err) require.NoError(t, err)
assert.ElementsMatch(t, []string{"test2"}, got) assert.ElementsMatch(t, []string{"test2"}, got)
// stage 3
require.NoError(t, app.SetFlags(catalystApp, []string{"test", "test2"}))
got, err = app.Flags(catalystApp)
require.NoError(t, err)
assert.ElementsMatch(t, []string{"test", "test2"}, got)
} }

View File

@@ -33,13 +33,13 @@ func MigrateDBs(app core.App) error {
return nil return nil
} }
func isIgnored(err error) bool {
// this fix ignores some errors that come from upstream migrations. // this fix ignores some errors that come from upstream migrations.
var ignoreErrors = []string{ ignoreErrors := []string{
"1673167670_multi_match_migrate", "1673167670_multi_match_migrate",
"1660821103_add_user_ip_column", "1660821103_add_user_ip_column",
} }
func isIgnored(err error) bool {
for _, ignore := range ignoreErrors { for _, ignore := range ignoreErrors {
if strings.Contains(err.Error(), ignore) { if strings.Contains(err.Error(), ignore) {
return true return true

View File

@@ -0,0 +1,39 @@
package app
import (
"errors"
"testing"
"github.com/stretchr/testify/assert"
)
func Test_isIgnored(t *testing.T) {
t.Parallel()
type args struct {
err error
}
tests := []struct {
name string
args args
want bool
}{
{
name: "error is ignored",
args: args{err: errors.New("1673167670_multi_match_migrate")},
want: true,
},
{
name: "error is not ignored",
args: args{err: errors.New("1673167670_multi_match")},
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
t.Parallel()
assert.Equalf(t, tt.want, isIgnored(tt.args.err), "isIgnored(%v)", tt.args.err)
})
}
}

View File

@@ -11,7 +11,9 @@ import (
) )
func Test_MigrateDBsDown(t *testing.T) { func Test_MigrateDBsDown(t *testing.T) {
catalystApp, cleanup := catalystTesting.App(t) t.Parallel()
catalystApp, _, cleanup := catalystTesting.App(t)
defer cleanup() defer cleanup()
_, err := catalystApp.Dao().FindCollectionByNameOrId(migrations.ReactionCollectionName) _, err := catalystApp.Dao().FindCollectionByNameOrId(migrations.ReactionCollectionName)

View File

@@ -38,11 +38,10 @@ func staticFiles() func(echo.Context) error {
return func(c echo.Context) error { return func(c echo.Context) error {
if dev() { if dev() {
u, _ := url.Parse("http://localhost:3000/") u, _ := url.Parse("http://localhost:3000/")
proxy := httputil.NewSingleHostReverseProxy(u)
c.Request().Host = c.Request().URL.Host c.Request().Host = c.Request().URL.Host
proxy.ServeHTTP(c.Response(), c.Request()) httputil.NewSingleHostReverseProxy(u).ServeHTTP(c.Response(), c.Request())
return nil return nil
} }

21
app/routes_test.go Normal file
View File

@@ -0,0 +1,21 @@
package app
import (
"net/http"
"net/http/httptest"
"testing"
"github.com/labstack/echo/v5"
"github.com/stretchr/testify/require"
)
func Test_staticFiles(t *testing.T) {
t.Parallel()
e := echo.New()
req := httptest.NewRequest(http.MethodGet, "/", nil)
rec := httptest.NewRecorder()
c := e.NewContext(req, rec)
require.NoError(t, staticFiles()(c))
}

View File

@@ -1,6 +1,7 @@
package fakedata package fakedata
import ( import (
"encoding/json"
"fmt" "fmt"
"strings" "strings"
"time" "time"
@@ -50,14 +51,12 @@ func Records(app core.App, userCount int, ticketCount int) ([]*models.Record, er
users := userRecords(app.Dao(), userCount) users := userRecords(app.Dao(), userCount)
tickets := ticketRecords(app.Dao(), users, types, ticketCount) tickets := ticketRecords(app.Dao(), users, types, ticketCount)
webhooks := webhookRecords(app.Dao())
reactions := reactionRecords(app.Dao()) reactions := reactionRecords(app.Dao())
var records []*models.Record var records []*models.Record
records = append(records, users...) records = append(records, users...)
records = append(records, types...) records = append(records, types...)
records = append(records, tickets...) records = append(records, tickets...)
records = append(records, webhooks...)
records = append(records, reactions...) records = append(records, reactions...)
return records, nil return records, nil
@@ -69,7 +68,7 @@ func userRecords(dao *daos.Dao, count int) []*models.Record {
panic(err) panic(err)
} }
var records []*models.Record records := make([]*models.Record, 0, count)
// create the test user // create the test user
if _, err := dao.FindRecordById(migrations.UserCollectionName, "u_test"); err != nil { if _, err := dao.FindRecordById(migrations.UserCollectionName, "u_test"); err != nil {
@@ -105,7 +104,7 @@ func ticketRecords(dao *daos.Dao, users, types []*models.Record, count int) []*m
panic(err) panic(err)
} }
var records []*models.Record records := make([]*models.Record, 0, count)
created := time.Now() created := time.Now()
number := gofakeit.Number(200*count, 300*count) number := gofakeit.Number(200*count, 300*count)
@@ -135,12 +134,24 @@ func ticketRecords(dao *daos.Dao, users, types []*models.Record, count int) []*m
records = append(records, record) records = append(records, record)
// Add comments // Add comments
for range gofakeit.IntN(5) { records = append(records, commentRecords(dao, users, created, record)...)
records = append(records, timelineRecords(dao, created, record)...)
records = append(records, taskRecords(dao, users, created, record)...)
records = append(records, linkRecords(dao, created, record)...)
}
return records
}
func commentRecords(dao *daos.Dao, users []*models.Record, created time.Time, record *models.Record) []*models.Record {
commentCollection, err := dao.FindCollectionByNameOrId(migrations.CommentCollectionName) commentCollection, err := dao.FindCollectionByNameOrId(migrations.CommentCollectionName)
if err != nil { if err != nil {
panic(err) panic(err)
} }
records := make([]*models.Record, 0, 5)
for range gofakeit.IntN(5) {
commentCreated := gofakeit.DateRange(created, time.Now()) commentCreated := gofakeit.DateRange(created, time.Now())
commentUpdated := gofakeit.DateRange(commentCreated, time.Now()) commentUpdated := gofakeit.DateRange(commentCreated, time.Now())
@@ -155,13 +166,18 @@ func ticketRecords(dao *daos.Dao, users, types []*models.Record, count int) []*m
records = append(records, commentRecord) records = append(records, commentRecord)
} }
// Add timeline return records
for range gofakeit.IntN(5) { }
func timelineRecords(dao *daos.Dao, created time.Time, record *models.Record) []*models.Record {
timelineCollection, err := dao.FindCollectionByNameOrId(migrations.TimelineCollectionName) timelineCollection, err := dao.FindCollectionByNameOrId(migrations.TimelineCollectionName)
if err != nil { if err != nil {
panic(err) panic(err)
} }
records := make([]*models.Record, 0, 5)
for range gofakeit.IntN(5) {
timelineCreated := gofakeit.DateRange(created, time.Now()) timelineCreated := gofakeit.DateRange(created, time.Now())
timelineUpdated := gofakeit.DateRange(timelineCreated, time.Now()) timelineUpdated := gofakeit.DateRange(timelineCreated, time.Now())
@@ -176,13 +192,18 @@ func ticketRecords(dao *daos.Dao, users, types []*models.Record, count int) []*m
records = append(records, timelineRecord) records = append(records, timelineRecord)
} }
// Add tasks return records
for range gofakeit.IntN(5) { }
func taskRecords(dao *daos.Dao, users []*models.Record, created time.Time, record *models.Record) []*models.Record {
taskCollection, err := dao.FindCollectionByNameOrId(migrations.TaskCollectionName) taskCollection, err := dao.FindCollectionByNameOrId(migrations.TaskCollectionName)
if err != nil { if err != nil {
panic(err) panic(err)
} }
records := make([]*models.Record, 0, 5)
for range gofakeit.IntN(5) {
taskCreated := gofakeit.DateRange(created, time.Now()) taskCreated := gofakeit.DateRange(created, time.Now())
taskUpdated := gofakeit.DateRange(taskCreated, time.Now()) taskUpdated := gofakeit.DateRange(taskCreated, time.Now())
@@ -198,13 +219,18 @@ func ticketRecords(dao *daos.Dao, users, types []*models.Record, count int) []*m
records = append(records, taskRecord) records = append(records, taskRecord)
} }
// Add links return records
for range gofakeit.IntN(5) { }
func linkRecords(dao *daos.Dao, created time.Time, record *models.Record) []*models.Record {
linkCollection, err := dao.FindCollectionByNameOrId(migrations.LinkCollectionName) linkCollection, err := dao.FindCollectionByNameOrId(migrations.LinkCollectionName)
if err != nil { if err != nil {
panic(err) panic(err)
} }
records := make([]*models.Record, 0, 5)
for range gofakeit.IntN(5) {
linkCreated := gofakeit.DateRange(created, time.Now()) linkCreated := gofakeit.DateRange(created, time.Now())
linkUpdated := gofakeit.DateRange(linkCreated, time.Now()) linkUpdated := gofakeit.DateRange(linkCreated, time.Now())
@@ -218,31 +244,58 @@ func ticketRecords(dao *daos.Dao, users, types []*models.Record, count int) []*m
records = append(records, linkRecord) records = append(records, linkRecord)
} }
}
return records return records
} }
func webhookRecords(dao *daos.Dao) []*models.Record { const alertIngestPy = `import sys
collection, err := dao.FindCollectionByNameOrId(migrations.WebhookCollectionName) import json
if err != nil { import random
panic(err) import os
}
record := models.NewRecord(collection) from pocketbase import PocketBase
record.SetId("w_" + security.PseudorandomString(10))
record.Set("name", "Test Webhook")
record.Set("collection", "tickets")
record.Set("destination", "http://localhost:8080/webhook")
return []*models.Record{record} # Parse the event from the webhook payload
} event = json.loads(sys.argv[1])
body = json.loads(event["body"])
# Connect to the PocketBase server
client = PocketBase('http://127.0.0.1:8090')
client.auth_store.save(token=os.environ["CATALYST_TOKEN"])
# Create a new ticket
client.collection("tickets").create({
"name": body["name"],
"type": "alert",
"open": True,
})`
const assignTicketsPy = `import sys
import json
import random
import os
from pocketbase import PocketBase
# Parse the ticket from the input
ticket = json.loads(sys.argv[1])
# Connect to the PocketBase server
client = PocketBase('http://127.0.0.1:8090')
client.auth_store.save(token=os.environ["CATALYST_TOKEN"])
# Get a random user
users = client.collection("users").get_list(1, 200)
random_user = random.choice(users.items)
# Assign the ticket to the random user
client.collection("tickets").update(ticket["record"]["id"], {
"owner": random_user.id,
})`
const ( const (
triggerWebhook = `{"token":"1234567890","path":"webhook"}` triggerWebhook = `{"token":"1234567890","path":"webhook"}`
reactionPython = `{"requirements":"requests","script":"import sys\n\nprint(sys.argv[1])"}` triggerHook = `{"collections":["tickets"],"events":["create"]}`
triggerHook = `{"collections":["tickets","comments"],"events":["create","update","delete"]}`
reactionWebhook = `{"headers":["Content-Type: application/json"],"url":"http://localhost:8080/webhook"}`
) )
func reactionRecords(dao *daos.Dao) []*models.Record { func reactionRecords(dao *daos.Dao) []*models.Record {
@@ -253,23 +306,39 @@ func reactionRecords(dao *daos.Dao) []*models.Record {
panic(err) panic(err)
} }
alertIngestActionData, err := json.Marshal(map[string]interface{}{
"requirements": "pocketbase",
"script": alertIngestPy,
})
if err != nil {
panic(err)
}
record := models.NewRecord(collection) record := models.NewRecord(collection)
record.SetId("w_" + security.PseudorandomString(10)) record.SetId("w_" + security.PseudorandomString(10))
record.Set("name", "Test Reaction") record.Set("name", "Test Reaction")
record.Set("trigger", "webhook") record.Set("trigger", "webhook")
record.Set("triggerdata", triggerWebhook) record.Set("triggerdata", triggerWebhook)
record.Set("action", "python") record.Set("action", "python")
record.Set("actiondata", reactionPython) record.Set("actiondata", string(alertIngestActionData))
records = append(records, record) records = append(records, record)
assignTicketsActionData, err := json.Marshal(map[string]interface{}{
"requirements": "pocketbase",
"script": assignTicketsPy,
})
if err != nil {
panic(err)
}
record = models.NewRecord(collection) record = models.NewRecord(collection)
record.SetId("w_" + security.PseudorandomString(10)) record.SetId("w_" + security.PseudorandomString(10))
record.Set("name", "Test Reaction 2") record.Set("name", "Test Reaction 2")
record.Set("trigger", "hook") record.Set("trigger", "hook")
record.Set("triggerdata", triggerHook) record.Set("triggerdata", triggerHook)
record.Set("action", "webhook") record.Set("action", "python")
record.Set("actiondata", reactionWebhook) record.Set("actiondata", string(assignTicketsActionData))
records = append(records, record) records = append(records, record)

View File

@@ -11,7 +11,9 @@ import (
) )
func Test_records(t *testing.T) { func Test_records(t *testing.T) {
app, cleanup := catalystTesting.App(t) t.Parallel()
app, _, cleanup := catalystTesting.App(t)
defer cleanup() defer cleanup()
got, err := fakedata.Records(app, 2, 2) got, err := fakedata.Records(app, 2, 2)
@@ -21,7 +23,9 @@ func Test_records(t *testing.T) {
} }
func TestGenerate(t *testing.T) { func TestGenerate(t *testing.T) {
app, cleanup := catalystTesting.App(t) t.Parallel()
app, _, cleanup := catalystTesting.App(t)
defer cleanup() defer cleanup()
err := fakedata.Generate(app, 0, 0) err := fakedata.Generate(app, 0, 0)

View File

@@ -7,22 +7,32 @@ import (
) )
func Test_fakeTicketComment(t *testing.T) { func Test_fakeTicketComment(t *testing.T) {
t.Parallel()
assert.NotEmpty(t, fakeTicketComment()) assert.NotEmpty(t, fakeTicketComment())
} }
func Test_fakeTicketDescription(t *testing.T) { func Test_fakeTicketDescription(t *testing.T) {
t.Parallel()
assert.NotEmpty(t, fakeTicketDescription()) assert.NotEmpty(t, fakeTicketDescription())
} }
func Test_fakeTicketTask(t *testing.T) { func Test_fakeTicketTask(t *testing.T) {
t.Parallel()
assert.NotEmpty(t, fakeTicketTask()) assert.NotEmpty(t, fakeTicketTask())
} }
func Test_fakeTicketTimelineMessage(t *testing.T) { func Test_fakeTicketTimelineMessage(t *testing.T) {
t.Parallel()
assert.NotEmpty(t, fakeTicketTimelineMessage()) assert.NotEmpty(t, fakeTicketTimelineMessage())
} }
func Test_random(t *testing.T) { func Test_random(t *testing.T) {
t.Parallel()
type args[T any] struct { type args[T any] struct {
e []T e []T
} }
@@ -40,6 +50,8 @@ func Test_random(t *testing.T) {
} }
for _, tt := range tests { for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) { t.Run(tt.name, func(t *testing.T) {
t.Parallel()
got := random(tt.args.e) got := random(tt.args.e)
assert.Contains(t, tt.args.e, got) assert.Contains(t, tt.args.e, got)

2
go.mod
View File

@@ -4,6 +4,7 @@ go 1.22.1
require ( require (
github.com/brianvoe/gofakeit/v7 v7.0.3 github.com/brianvoe/gofakeit/v7 v7.0.3
github.com/golang-jwt/jwt/v4 v4.5.0
github.com/labstack/echo/v5 v5.0.0-20230722203903-ec5b858dab61 github.com/labstack/echo/v5 v5.0.0-20230722203903-ec5b858dab61
github.com/pocketbase/dbx v1.10.1 github.com/pocketbase/dbx v1.10.1
github.com/pocketbase/pocketbase v0.22.14 github.com/pocketbase/pocketbase v0.22.14
@@ -46,7 +47,6 @@ require (
github.com/go-ozzo/ozzo-validation/v4 v4.3.0 // indirect github.com/go-ozzo/ozzo-validation/v4 v4.3.0 // indirect
github.com/go-sql-driver/mysql v1.8.1 // indirect github.com/go-sql-driver/mysql v1.8.1 // indirect
github.com/goccy/go-json v0.10.3 // indirect github.com/goccy/go-json v0.10.3 // indirect
github.com/golang-jwt/jwt/v4 v4.5.0 // indirect
github.com/golang/groupcache v0.0.0-20210331224755-41bb18bfe9da // indirect github.com/golang/groupcache v0.0.0-20210331224755-41bb18bfe9da // indirect
github.com/google/uuid v1.6.0 // indirect github.com/google/uuid v1.6.0 // indirect
github.com/googleapis/gax-go/v2 v2.12.4 // indirect github.com/googleapis/gax-go/v2 v2.12.4 // indirect

View File

@@ -7,7 +7,12 @@ import (
) )
func main() { func main() {
if err := app.App("./catalyst_data").Start(); err != nil { catalyst, err := app.App("./catalyst_data", false)
if err != nil {
log.Fatal(err)
}
if err := catalyst.Start(); err != nil {
log.Fatal(err) log.Fatal(err)
} }
} }

View File

@@ -4,7 +4,6 @@ import (
"github.com/pocketbase/dbx" "github.com/pocketbase/dbx"
"github.com/pocketbase/pocketbase/daos" "github.com/pocketbase/pocketbase/daos"
"github.com/pocketbase/pocketbase/models" "github.com/pocketbase/pocketbase/models"
"github.com/pocketbase/pocketbase/tools/security"
) )
func defaultDataUp(db dbx.Builder) error { func defaultDataUp(db dbx.Builder) error {
@@ -30,7 +29,7 @@ func typeRecords(dao *daos.Dao) []*models.Record {
var records []*models.Record var records []*models.Record
record := models.NewRecord(collection) record := models.NewRecord(collection)
record.SetId("y_" + security.PseudorandomString(5)) record.SetId("incident")
record.Set("singular", "Incident") record.Set("singular", "Incident")
record.Set("plural", "Incidents") record.Set("plural", "Incidents")
record.Set("icon", "Flame") record.Set("icon", "Flame")
@@ -39,7 +38,7 @@ func typeRecords(dao *daos.Dao) []*models.Record {
records = append(records, record) records = append(records, record)
record = models.NewRecord(collection) record = models.NewRecord(collection)
record.SetId("y_" + security.PseudorandomString(5)) record.SetId("alert")
record.Set("singular", "Alert") record.Set("singular", "Alert")
record.Set("plural", "Alerts") record.Set("plural", "Alerts")
record.Set("icon", "AlertTriangle") record.Set("icon", "AlertTriangle")

View File

@@ -21,9 +21,9 @@ func reactionsUp(db dbx.Builder) error {
Schema: schema.NewSchema( Schema: schema.NewSchema(
&schema.SchemaField{Name: "name", Type: schema.FieldTypeText, Required: true}, &schema.SchemaField{Name: "name", Type: schema.FieldTypeText, Required: true},
&schema.SchemaField{Name: "trigger", Type: schema.FieldTypeSelect, Required: true, Options: &schema.SelectOptions{MaxSelect: 1, Values: triggers}}, &schema.SchemaField{Name: "trigger", Type: schema.FieldTypeSelect, Required: true, Options: &schema.SelectOptions{MaxSelect: 1, Values: triggers}},
&schema.SchemaField{Name: "triggerdata", Type: schema.FieldTypeJson, Required: true}, &schema.SchemaField{Name: "triggerdata", Type: schema.FieldTypeJson, Required: true, Options: &schema.JsonOptions{MaxSize: 50_000}},
&schema.SchemaField{Name: "action", Type: schema.FieldTypeSelect, Required: true, Options: &schema.SelectOptions{MaxSelect: 1, Values: reactions}}, &schema.SchemaField{Name: "action", Type: schema.FieldTypeSelect, Required: true, Options: &schema.SelectOptions{MaxSelect: 1, Values: reactions}},
&schema.SchemaField{Name: "actiondata", Type: schema.FieldTypeJson, Required: true}, &schema.SchemaField{Name: "actiondata", Type: schema.FieldTypeJson, Required: true, Options: &schema.JsonOptions{MaxSize: 50_000}},
), ),
})) }))
} }

View File

@@ -0,0 +1,37 @@
package migrations
import (
"github.com/pocketbase/dbx"
"github.com/pocketbase/pocketbase/daos"
"github.com/pocketbase/pocketbase/models"
)
const SystemUserID = "system"
func systemuserUp(db dbx.Builder) error {
dao := daos.New(db)
collection, err := dao.FindCollectionByNameOrId(UserCollectionName)
if err != nil {
return err
}
record := models.NewRecord(collection)
record.SetId(SystemUserID)
record.Set("name", "system")
record.Set("username", "system")
record.Set("verified", true)
return dao.SaveRecord(record)
}
func systemuserDown(db dbx.Builder) error {
dao := daos.New(db)
record, err := dao.FindRecordById(UserCollectionName, SystemUserID)
if err != nil {
return err
}
return dao.DeleteRecord(record)
}

View File

@@ -10,4 +10,5 @@ func Register() {
migrations.Register(defaultDataUp, nil, "1700000003_defaultdata.go") migrations.Register(defaultDataUp, nil, "1700000003_defaultdata.go")
migrations.Register(viewsUp, viewsDown, "1700000004_views.go") migrations.Register(viewsUp, viewsDown, "1700000004_views.go")
migrations.Register(reactionsUp, reactionsDown, "1700000005_reactions.go") migrations.Register(reactionsUp, reactionsDown, "1700000005_reactions.go")
migrations.Register(systemuserUp, systemuserDown, "1700000006_systemuser.go")
} }

View File

@@ -4,17 +4,33 @@ import (
"context" "context"
"encoding/json" "encoding/json"
"fmt" "fmt"
"time"
"github.com/golang-jwt/jwt/v4"
"github.com/pocketbase/pocketbase/core"
"github.com/pocketbase/pocketbase/tokens"
"github.com/pocketbase/pocketbase/tools/security"
"github.com/SecurityBrewery/catalyst/migrations"
"github.com/SecurityBrewery/catalyst/reaction/action/python" "github.com/SecurityBrewery/catalyst/reaction/action/python"
"github.com/SecurityBrewery/catalyst/reaction/action/webhook" "github.com/SecurityBrewery/catalyst/reaction/action/webhook"
) )
func Run(ctx context.Context, actionName, actionData, payload string) ([]byte, error) { func Run(ctx context.Context, app core.App, actionName, actionData, payload string) ([]byte, error) {
action, err := decode(actionName, actionData) action, err := decode(actionName, actionData)
if err != nil { if err != nil {
return nil, err return nil, err
} }
if a, ok := action.(authenticatedAction); ok {
token, err := systemToken(app)
if err != nil {
return nil, fmt.Errorf("failed to get system token: %w", err)
}
a.SetToken(token)
}
return action.Run(ctx, payload) return action.Run(ctx, payload)
} }
@@ -22,6 +38,10 @@ type action interface {
Run(ctx context.Context, payload string) ([]byte, error) Run(ctx context.Context, payload string) ([]byte, error)
} }
type authenticatedAction interface {
SetToken(token string)
}
func decode(actionName, actionData string) (action, error) { func decode(actionName, actionData string) (action, error) {
switch actionName { switch actionName {
case "python": case "python":
@@ -42,3 +62,20 @@ func decode(actionName, actionData string) (action, error) {
return nil, fmt.Errorf("action %q not found", actionName) return nil, fmt.Errorf("action %q not found", actionName)
} }
} }
func systemToken(app core.App) (string, error) {
authRecord, err := app.Dao().FindAuthRecordByUsername(migrations.UserCollectionName, migrations.SystemUserID)
if err != nil {
return "", fmt.Errorf("failed to find system auth record: %w", err)
}
return security.NewJWT(
jwt.MapClaims{
"id": authRecord.Id,
"type": tokens.TypeAuthRecord,
"collectionId": authRecord.Collection().Id,
},
authRecord.TokenKey()+app.Settings().RecordAuthToken.Secret,
int64(time.Second*60),
)
}

View File

@@ -10,8 +10,14 @@ import (
) )
type Python struct { type Python struct {
Bootstrap string `json:"bootstrap"` Requirements string `json:"requirements"`
Script string `json:"script"` Script string `json:"script"`
token string
}
func (a *Python) SetToken(token string) {
a.token = token
} }
func (a *Python) Run(ctx context.Context, payload string) ([]byte, error) { func (a *Python) Run(ctx context.Context, payload string) ([]byte, error) {
@@ -22,7 +28,8 @@ func (a *Python) Run(ctx context.Context, payload string) ([]byte, error) {
defer os.RemoveAll(tempDir) defer os.RemoveAll(tempDir)
if b, err := pythonSetup(ctx, tempDir); err != nil { b, err := pythonSetup(ctx, tempDir)
if err != nil {
var ee *exec.ExitError var ee *exec.ExitError
if errors.As(err, &ee) { if errors.As(err, &ee) {
b = append(b, ee.Stderr...) b = append(b, ee.Stderr...)
@@ -31,16 +38,17 @@ func (a *Python) Run(ctx context.Context, payload string) ([]byte, error) {
return nil, fmt.Errorf("failed to setup python, %w: %s", err, string(b)) return nil, fmt.Errorf("failed to setup python, %w: %s", err, string(b))
} }
if b, err := pythonRunBootstrap(ctx, tempDir, a.Bootstrap); err != nil { b, err = a.pythonInstallRequirements(ctx, tempDir)
if err != nil {
var ee *exec.ExitError var ee *exec.ExitError
if errors.As(err, &ee) { if errors.As(err, &ee) {
b = append(b, ee.Stderr...) b = append(b, ee.Stderr...)
} }
return nil, fmt.Errorf("failed to run bootstrap, %w: %s", err, string(b)) return nil, fmt.Errorf("failed to run install requirements, %w: %s", err, string(b))
} }
b, err := pythonRunScript(ctx, tempDir, a.Script, payload) b, err = a.pythonRunScript(ctx, tempDir, payload)
if err != nil { if err != nil {
var ee *exec.ExitError var ee *exec.ExitError
if errors.As(err, &ee) { if errors.As(err, &ee) {
@@ -63,35 +71,42 @@ func pythonSetup(ctx context.Context, tempDir string) ([]byte, error) {
return exec.CommandContext(ctx, pythonPath, "-m", "venv", tempDir+"/venv").Output() return exec.CommandContext(ctx, pythonPath, "-m", "venv", tempDir+"/venv").Output()
} }
func pythonRunBootstrap(ctx context.Context, tempDir, bootstrap string) ([]byte, error) { func (a *Python) pythonInstallRequirements(ctx context.Context, tempDir string) ([]byte, error) {
hasBootstrap := len(strings.TrimSpace(bootstrap)) > 0 hasRequirements := len(strings.TrimSpace(a.Requirements)) > 0
if !hasBootstrap { if !hasRequirements {
return nil, nil return nil, nil
} }
bootstrapPath := tempDir + "/requirements.txt" requirementsPath := tempDir + "/requirements.txt"
if err := os.WriteFile(bootstrapPath, []byte(bootstrap), 0o600); err != nil { if err := os.WriteFile(requirementsPath, []byte(a.Requirements), 0o600); err != nil {
return nil, err return nil, err
} }
// install dependencies // install dependencies
pipPath := tempDir + "/venv/bin/pip" pipPath := tempDir + "/venv/bin/pip"
return exec.CommandContext(ctx, pipPath, "install", "-r", bootstrapPath).Output() return exec.CommandContext(ctx, pipPath, "install", "-r", requirementsPath).Output()
} }
func pythonRunScript(ctx context.Context, tempDir, script, payload string) ([]byte, error) { func (a *Python) pythonRunScript(ctx context.Context, tempDir, payload string) ([]byte, error) {
scriptPath := tempDir + "/script.py" scriptPath := tempDir + "/script.py"
if err := os.WriteFile(scriptPath, []byte(script), 0o600); err != nil { if err := os.WriteFile(scriptPath, []byte(a.Script), 0o600); err != nil {
return nil, err return nil, err
} }
pythonPath := tempDir + "/venv/bin/python" pythonPath := tempDir + "/venv/bin/python"
return exec.CommandContext(ctx, pythonPath, scriptPath, payload).Output() cmd := exec.CommandContext(ctx, pythonPath, scriptPath, payload)
cmd.Env = []string{}
if a.token != "" {
cmd.Env = append(cmd.Env, "CATALYST_TOKEN="+a.token)
}
return cmd.Output()
} }
func findExec(name ...string) (string, error) { func findExec(name ...string) (string, error) {

View File

@@ -1,15 +1,19 @@
package python package python_test
import ( import (
"context" "context"
"testing" "testing"
"github.com/stretchr/testify/assert" "github.com/stretchr/testify/assert"
"github.com/SecurityBrewery/catalyst/reaction/action/python"
) )
func TestPython_Run(t *testing.T) { func TestPython_Run(t *testing.T) {
t.Parallel()
type fields struct { type fields struct {
Bootstrap string Requirements string
Script string Script string
} }
@@ -68,13 +72,27 @@ func TestPython_Run(t *testing.T) {
want: nil, want: nil,
wantErr: assert.Error, wantErr: assert.Error,
}, },
{
name: "requests",
fields: fields{
Requirements: "requests",
Script: "import requests\nprint(requests.get('https://xkcd.com/2961/info.0.json').text)",
},
args: args{
payload: "test",
},
want: []byte("{\"month\": \"7\", \"num\": 2961, \"link\": \"\", \"year\": \"2024\", \"news\": \"\", \"safe_title\": \"CrowdStrike\", \"transcript\": \"\", \"alt\": \"We were going to try swordfighting, but all my compiling is on hold.\", \"img\": \"https://imgs.xkcd.com/comics/crowdstrike.png\", \"title\": \"CrowdStrike\", \"day\": \"19\"}\n"),
wantErr: assert.NoError,
},
} }
for _, tt := range tests { for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) { t.Run(tt.name, func(t *testing.T) {
t.Parallel()
ctx := context.Background() ctx := context.Background()
a := &Python{ a := &python.Python{
Bootstrap: tt.fields.Bootstrap, Requirements: tt.fields.Requirements,
Script: tt.fields.Script, Script: tt.fields.Script,
} }
got, err := a.Run(ctx, tt.args.payload) got, err := a.Run(ctx, tt.args.payload)

View File

@@ -1,12 +1,16 @@
package webhook package webhook_test
import ( import (
"bytes" "bytes"
"io" "io"
"testing" "testing"
"github.com/SecurityBrewery/catalyst/reaction/action/webhook"
) )
func TestEncodeBody(t *testing.T) { func TestEncodeBody(t *testing.T) {
t.Parallel()
type args struct { type args struct {
requestBody io.Reader requestBody io.Reader
} }
@@ -36,7 +40,9 @@ func TestEncodeBody(t *testing.T) {
} }
for _, tt := range tests { for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) { t.Run(tt.name, func(t *testing.T) {
got, got1 := EncodeBody(tt.args.requestBody) t.Parallel()
got, got1 := webhook.EncodeBody(tt.args.requestBody)
if got != tt.want { if got != tt.want {
t.Errorf("EncodeBody() got = %v, want %v", got, tt.want) t.Errorf("EncodeBody() got = %v, want %v", got, tt.want)
} }

View File

@@ -16,11 +16,15 @@ import (
) )
func TestWebhook_Run(t *testing.T) { func TestWebhook_Run(t *testing.T) {
t.Parallel()
server := catalystTesting.NewRecordingServer() server := catalystTesting.NewRecordingServer()
go http.ListenAndServe("127.0.0.1:12347", server) //nolint:gosec,errcheck go http.ListenAndServe("127.0.0.1:12347", server) //nolint:gosec,errcheck
time.Sleep(1 * time.Second) if err := catalystTesting.WaitForStatus("http://127.0.0.1:12347/health", http.StatusOK, 5*time.Second); err != nil {
t.Fatal(err)
}
type fields struct { type fields struct {
Headers map[string]string Headers map[string]string
@@ -50,10 +54,10 @@ func TestWebhook_Run(t *testing.T) {
want: map[string]any{ want: map[string]any{
"statusCode": 200, "statusCode": 200,
"headers": map[string]any{ "headers": map[string]any{
"Content-Length": []any{"13"}, "Content-Length": []any{"14"},
"Content-Type": []any{"text/plain; charset=utf-8"}, "Content-Type": []any{"application/json; charset=UTF-8"},
}, },
"body": `{"test":true}`, "body": "{\"test\":true}\n",
"isBase64Encoded": false, "isBase64Encoded": false,
}, },
wantErr: assert.NoError, wantErr: assert.NoError,
@@ -61,6 +65,8 @@ func TestWebhook_Run(t *testing.T) {
} }
for _, tt := range tests { for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) { t.Run(tt.name, func(t *testing.T) {
t.Parallel()
ctx := context.Background() ctx := context.Background()
a := &webhook.Webhook{ a := &webhook.Webhook{

View File

@@ -1,13 +1,13 @@
package reaction package reaction
import ( import (
"github.com/pocketbase/pocketbase/core" "github.com/pocketbase/pocketbase"
"github.com/SecurityBrewery/catalyst/reaction/trigger/hook" "github.com/SecurityBrewery/catalyst/reaction/trigger/hook"
"github.com/SecurityBrewery/catalyst/reaction/trigger/webhook" "github.com/SecurityBrewery/catalyst/reaction/trigger/webhook"
) )
func BindHooks(app core.App) { func BindHooks(pb *pocketbase.PocketBase, test bool) {
hook.BindHooks(app) hook.BindHooks(pb, test)
webhook.BindHooks(app) webhook.BindHooks(pb)
} }

View File

@@ -1,12 +1,15 @@
package hook package hook
import ( import (
"context"
"encoding/json" "encoding/json"
"fmt" "fmt"
"log/slog"
"slices" "slices"
"github.com/labstack/echo/v5" "github.com/labstack/echo/v5"
"github.com/pocketbase/dbx" "github.com/pocketbase/dbx"
"github.com/pocketbase/pocketbase"
"github.com/pocketbase/pocketbase/apis" "github.com/pocketbase/pocketbase/apis"
"github.com/pocketbase/pocketbase/core" "github.com/pocketbase/pocketbase/core"
"github.com/pocketbase/pocketbase/daos" "github.com/pocketbase/pocketbase/daos"
@@ -22,43 +25,40 @@ type Hook struct {
Events []string `json:"events"` Events []string `json:"events"`
} }
func BindHooks(app core.App) { func BindHooks(pb *pocketbase.PocketBase, test bool) {
app.OnRecordAfterCreateRequest().Add(func(e *core.RecordCreateEvent) error { pb.App.OnRecordAfterCreateRequest().Add(func(e *core.RecordCreateEvent) error {
if err := hook(app.Dao(), "create", e.Collection.Name, e.Record, e.HttpContext); err != nil { return hook(e.HttpContext, pb.App, "create", e.Collection.Name, e.Record, test)
app.Logger().Error("failed to find hook reaction", "error", err.Error())
}
return nil
}) })
app.OnRecordAfterUpdateRequest().Add(func(e *core.RecordUpdateEvent) error { pb.App.OnRecordAfterUpdateRequest().Add(func(e *core.RecordUpdateEvent) error {
if err := hook(app.Dao(), "update", e.Collection.Name, e.Record, e.HttpContext); err != nil { return hook(e.HttpContext, pb.App, "update", e.Collection.Name, e.Record, test)
app.Logger().Error("failed to find hook reaction", "error", err.Error())
}
return nil
}) })
app.OnRecordAfterDeleteRequest().Add(func(e *core.RecordDeleteEvent) error { pb.App.OnRecordAfterDeleteRequest().Add(func(e *core.RecordDeleteEvent) error {
if err := hook(app.Dao(), "delete", e.Collection.Name, e.Record, e.HttpContext); err != nil { return hook(e.HttpContext, pb.App, "delete", e.Collection.Name, e.Record, test)
app.Logger().Error("failed to find hook reaction", "error", err.Error())
}
return nil
}) })
} }
func hook(dao *daos.Dao, event, collection string, record *models.Record, ctx echo.Context) error { func hook(ctx echo.Context, app core.App, event, collection string, record *models.Record, test bool) error {
auth, _ := ctx.Get(apis.ContextAuthRecordKey).(*models.Record) auth, _ := ctx.Get(apis.ContextAuthRecordKey).(*models.Record)
admin, _ := ctx.Get(apis.ContextAdminKey).(*models.Admin) admin, _ := ctx.Get(apis.ContextAdminKey).(*models.Admin)
hook, found, err := findByHookTrigger(dao, collection, event) if !test {
if err != nil { go mustRunHook(app, collection, event, record, auth, admin)
return fmt.Errorf("failed to find hook reaction: %w", err) } else {
mustRunHook(app, collection, event, record, auth, admin)
} }
if !found {
return nil return nil
} }
func mustRunHook(app core.App, collection, event string, record, auth *models.Record, admin *models.Admin) {
ctx := context.Background()
if err := runHook(ctx, app, collection, event, record, auth, admin); err != nil {
slog.ErrorContext(ctx, fmt.Sprintf("failed to run hook reaction: %v", err))
}
}
func runHook(ctx context.Context, app core.App, collection, event string, record, auth *models.Record, admin *models.Admin) error {
payload, err := json.Marshal(&webhook.Payload{ payload, err := json.Marshal(&webhook.Payload{
Action: event, Action: event,
Collection: collection, Collection: collection,
@@ -67,10 +67,19 @@ func hook(dao *daos.Dao, event, collection string, record *models.Record, ctx ec
Admin: admin, Admin: admin,
}) })
if err != nil { if err != nil {
return fmt.Errorf("failed to marshal payload: %w", err) return fmt.Errorf("failed to marshal webhook payload: %w", err)
} }
_, err = action.Run(ctx.Request().Context(), hook.GetString("action"), hook.GetString("actiondata"), string(payload)) hook, found, err := findByHookTrigger(app.Dao(), collection, event)
if err != nil {
return fmt.Errorf("failed to find hook by trigger: %w", err)
}
if !found {
return nil
}
_, err = action.Run(ctx, app, hook.GetString("action"), hook.GetString("actiondata"), string(payload))
if err != nil { if err != nil {
return fmt.Errorf("failed to run hook reaction: %w", err) return fmt.Errorf("failed to run hook reaction: %w", err)
} }
@@ -81,7 +90,7 @@ func hook(dao *daos.Dao, event, collection string, record *models.Record, ctx ec
func findByHookTrigger(dao *daos.Dao, collection, event string) (*models.Record, bool, error) { func findByHookTrigger(dao *daos.Dao, collection, event string) (*models.Record, bool, error) {
records, err := dao.FindRecordsByExpr(migrations.ReactionCollectionName, dbx.HashExp{"trigger": "hook"}) records, err := dao.FindRecordsByExpr(migrations.ReactionCollectionName, dbx.HashExp{"trigger": "hook"})
if err != nil { if err != nil {
return nil, false, err return nil, false, fmt.Errorf("failed to find hook reaction: %w", err)
} }
if len(records) == 0 { if len(records) == 0 {

View File

@@ -15,8 +15,8 @@ type Request struct {
IsBase64Encoded bool `json:"isBase64Encoded"` IsBase64Encoded bool `json:"isBase64Encoded"`
} }
// isJSON checks if the data is JSON. // IsJSON checks if the data is JSON.
func isJSON(data []byte) bool { func IsJSON(data []byte) bool {
var msg json.RawMessage var msg json.RawMessage
return json.Unmarshal(data, &msg) == nil return json.Unmarshal(data, &msg) == nil

View File

@@ -1,8 +1,14 @@
package webhook package webhook_test
import "testing" import (
"testing"
"github.com/SecurityBrewery/catalyst/reaction/trigger/webhook"
)
func Test_isJSON(t *testing.T) { func Test_isJSON(t *testing.T) {
t.Parallel()
type args struct { type args struct {
data []byte data []byte
} }
@@ -29,7 +35,9 @@ func Test_isJSON(t *testing.T) {
} }
for _, tt := range tests { for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) { t.Run(tt.name, func(t *testing.T) {
if got := isJSON(tt.args.data); got != tt.want { t.Parallel()
if got := webhook.IsJSON(tt.args.data); got != tt.want {
t.Errorf("isJSON() = %v, want %v", got, tt.want) t.Errorf("isJSON() = %v, want %v", got, tt.want)
} }
}) })

View File

@@ -9,6 +9,7 @@ import (
"github.com/labstack/echo/v5" "github.com/labstack/echo/v5"
"github.com/pocketbase/dbx" "github.com/pocketbase/dbx"
"github.com/pocketbase/pocketbase"
"github.com/pocketbase/pocketbase/apis" "github.com/pocketbase/pocketbase/apis"
"github.com/pocketbase/pocketbase/core" "github.com/pocketbase/pocketbase/core"
"github.com/pocketbase/pocketbase/daos" "github.com/pocketbase/pocketbase/daos"
@@ -26,22 +27,22 @@ type Webhook struct {
const prefix = "/reaction/" const prefix = "/reaction/"
func BindHooks(app core.App) { func BindHooks(pb *pocketbase.PocketBase) {
app.OnBeforeServe().Add(func(e *core.ServeEvent) error { pb.OnBeforeServe().Add(func(e *core.ServeEvent) error {
e.Router.Any(prefix+"*", handle(e.App.Dao())) e.Router.Any(prefix+"*", handle(e.App))
return nil return nil
}) })
} }
func handle(dao *daos.Dao) func(c echo.Context) error { func handle(app core.App) func(c echo.Context) error {
return func(c echo.Context) error { return func(c echo.Context) error {
record, payload, apiErr := parseRequest(dao, c.Request()) record, payload, apiErr := parseRequest(app.Dao(), c.Request())
if apiErr != nil { if apiErr != nil {
return apiErr return apiErr
} }
output, err := action.Run(c.Request().Context(), record.GetString("action"), record.GetString("actiondata"), string(payload)) output, err := action.Run(c.Request().Context(), app, record.GetString("action"), record.GetString("actiondata"), string(payload))
if err != nil { if err != nil {
return apis.NewApiError(http.StatusInternalServerError, err.Error(), nil) return apis.NewApiError(http.StatusInternalServerError, err.Error(), nil)
} }
@@ -138,7 +139,7 @@ func writeOutput(c echo.Context, output []byte) error {
} }
} }
if isJSON(output) { if IsJSON(output) {
return c.JSON(http.StatusOK, json.RawMessage(output)) return c.JSON(http.StatusOK, json.RawMessage(output))
} }

View File

@@ -6,18 +6,16 @@ import (
) )
func TestReactionsCollection(t *testing.T) { func TestReactionsCollection(t *testing.T) {
baseApp, adminToken, analystToken, baseAppCleanup := BaseApp(t) t.Parallel()
defer baseAppCleanup()
testSets := []authMatrixText{ testSets := []catalystTest{
{ {
baseTest: BaseTest{ baseTest: BaseTest{
Name: "ListReactions", Name: "ListReactions",
Method: http.MethodGet, Method: http.MethodGet,
URL: "/api/collections/reactions/records", URL: "/api/collections/reactions/records",
TestAppFactory: AppFactory(baseApp),
}, },
authBasedExpectations: []AuthBasedExpectation{ userTests: []UserTest{
{ {
Name: "Unauthorized", Name: "Unauthorized",
ExpectedStatus: http.StatusOK, ExpectedStatus: http.StatusOK,
@@ -29,7 +27,7 @@ func TestReactionsCollection(t *testing.T) {
}, },
{ {
Name: "Analyst", Name: "Analyst",
RequestHeaders: map[string]string{"Authorization": analystToken}, AuthRecord: analystEmail,
ExpectedStatus: http.StatusOK, ExpectedStatus: http.StatusOK,
ExpectedContent: []string{ ExpectedContent: []string{
`"totalItems":3`, `"totalItems":3`,
@@ -42,7 +40,7 @@ func TestReactionsCollection(t *testing.T) {
}, },
{ {
Name: "Admin", Name: "Admin",
RequestHeaders: map[string]string{"Authorization": adminToken}, Admin: adminEmail,
ExpectedStatus: http.StatusOK, ExpectedStatus: http.StatusOK,
ExpectedContent: []string{ ExpectedContent: []string{
`"totalItems":3`, `"totalItems":3`,
@@ -68,9 +66,8 @@ func TestReactionsCollection(t *testing.T) {
"action": "python", "action": "python",
"actiondata": map[string]any{"script": "print('Hello, World!')"}, "actiondata": map[string]any{"script": "print('Hello, World!')"},
}), }),
TestAppFactory: AppFactory(baseApp),
}, },
authBasedExpectations: []AuthBasedExpectation{ userTests: []UserTest{
{ {
Name: "Unauthorized", Name: "Unauthorized",
ExpectedStatus: http.StatusBadRequest, ExpectedStatus: http.StatusBadRequest,
@@ -80,7 +77,7 @@ func TestReactionsCollection(t *testing.T) {
}, },
{ {
Name: "Analyst", Name: "Analyst",
RequestHeaders: map[string]string{"Authorization": analystToken}, AuthRecord: analystEmail,
ExpectedStatus: http.StatusOK, ExpectedStatus: http.StatusOK,
ExpectedContent: []string{ ExpectedContent: []string{
`"name":"test"`, `"name":"test"`,
@@ -97,7 +94,7 @@ func TestReactionsCollection(t *testing.T) {
}, },
{ {
Name: "Admin", Name: "Admin",
RequestHeaders: map[string]string{"Authorization": adminToken}, Admin: adminEmail,
ExpectedStatus: http.StatusOK, ExpectedStatus: http.StatusOK,
ExpectedContent: []string{ ExpectedContent: []string{
`"name":"test"`, `"name":"test"`,
@@ -120,9 +117,8 @@ func TestReactionsCollection(t *testing.T) {
Method: http.MethodGet, Method: http.MethodGet,
RequestHeaders: map[string]string{"Content-Type": "application/json"}, RequestHeaders: map[string]string{"Content-Type": "application/json"},
URL: "/api/collections/reactions/records/r_reaction", URL: "/api/collections/reactions/records/r_reaction",
TestAppFactory: AppFactory(baseApp),
}, },
authBasedExpectations: []AuthBasedExpectation{ userTests: []UserTest{
{ {
Name: "Unauthorized", Name: "Unauthorized",
ExpectedStatus: http.StatusNotFound, ExpectedStatus: http.StatusNotFound,
@@ -132,7 +128,7 @@ func TestReactionsCollection(t *testing.T) {
}, },
{ {
Name: "Analyst", Name: "Analyst",
RequestHeaders: map[string]string{"Authorization": analystToken}, AuthRecord: analystEmail,
ExpectedStatus: http.StatusOK, ExpectedStatus: http.StatusOK,
ExpectedContent: []string{ ExpectedContent: []string{
`"id":"r_reaction"`, `"id":"r_reaction"`,
@@ -141,7 +137,7 @@ func TestReactionsCollection(t *testing.T) {
}, },
{ {
Name: "Admin", Name: "Admin",
RequestHeaders: map[string]string{"Authorization": adminToken}, Admin: adminEmail,
ExpectedStatus: http.StatusOK, ExpectedStatus: http.StatusOK,
ExpectedContent: []string{ ExpectedContent: []string{
`"id":"r_reaction"`, `"id":"r_reaction"`,
@@ -157,9 +153,8 @@ func TestReactionsCollection(t *testing.T) {
RequestHeaders: map[string]string{"Content-Type": "application/json"}, RequestHeaders: map[string]string{"Content-Type": "application/json"},
URL: "/api/collections/reactions/records/r_reaction", URL: "/api/collections/reactions/records/r_reaction",
Body: s(map[string]any{"name": "update"}), Body: s(map[string]any{"name": "update"}),
TestAppFactory: AppFactory(baseApp),
}, },
authBasedExpectations: []AuthBasedExpectation{ userTests: []UserTest{
{ {
Name: "Unauthorized", Name: "Unauthorized",
ExpectedStatus: http.StatusNotFound, ExpectedStatus: http.StatusNotFound,
@@ -169,7 +164,7 @@ func TestReactionsCollection(t *testing.T) {
}, },
{ {
Name: "Analyst", Name: "Analyst",
RequestHeaders: map[string]string{"Authorization": analystToken}, AuthRecord: analystEmail,
ExpectedStatus: http.StatusOK, ExpectedStatus: http.StatusOK,
ExpectedContent: []string{ ExpectedContent: []string{
`"id":"r_reaction"`, `"id":"r_reaction"`,
@@ -184,7 +179,7 @@ func TestReactionsCollection(t *testing.T) {
}, },
{ {
Name: "Admin", Name: "Admin",
RequestHeaders: map[string]string{"Authorization": adminToken}, Admin: adminEmail,
ExpectedStatus: http.StatusOK, ExpectedStatus: http.StatusOK,
ExpectedContent: []string{ ExpectedContent: []string{
`"id":"r_reaction"`, `"id":"r_reaction"`,
@@ -204,9 +199,8 @@ func TestReactionsCollection(t *testing.T) {
Name: "DeleteReaction", Name: "DeleteReaction",
Method: http.MethodDelete, Method: http.MethodDelete,
URL: "/api/collections/reactions/records/r_reaction", URL: "/api/collections/reactions/records/r_reaction",
TestAppFactory: AppFactory(baseApp),
}, },
authBasedExpectations: []AuthBasedExpectation{ userTests: []UserTest{
{ {
Name: "Unauthorized", Name: "Unauthorized",
ExpectedStatus: http.StatusNotFound, ExpectedStatus: http.StatusNotFound,
@@ -216,7 +210,7 @@ func TestReactionsCollection(t *testing.T) {
}, },
{ {
Name: "Analyst", Name: "Analyst",
RequestHeaders: map[string]string{"Authorization": analystToken}, AuthRecord: analystEmail,
ExpectedStatus: http.StatusNoContent, ExpectedStatus: http.StatusNoContent,
ExpectedEvents: map[string]int{ ExpectedEvents: map[string]int{
"OnModelAfterDelete": 1, "OnModelAfterDelete": 1,
@@ -227,7 +221,7 @@ func TestReactionsCollection(t *testing.T) {
}, },
{ {
Name: "Admin", Name: "Admin",
RequestHeaders: map[string]string{"Authorization": adminToken}, Admin: adminEmail,
ExpectedStatus: http.StatusNoContent, ExpectedStatus: http.StatusNoContent,
ExpectedEvents: map[string]int{ ExpectedEvents: map[string]int{
"OnModelAfterDelete": 1, "OnModelAfterDelete": 1,
@@ -241,9 +235,14 @@ func TestReactionsCollection(t *testing.T) {
} }
for _, testSet := range testSets { for _, testSet := range testSets {
t.Run(testSet.baseTest.Name, func(t *testing.T) { t.Run(testSet.baseTest.Name, func(t *testing.T) {
for _, authBasedExpectation := range testSet.authBasedExpectations { t.Parallel()
scenario := mergeScenario(testSet.baseTest, authBasedExpectation)
scenario.Test(t) for _, userTest := range testSet.userTests {
t.Run(userTest.Name, func(t *testing.T) {
t.Parallel()
runMatrixTest(t, testSet.baseTest, userTest)
})
} }
}) })
} }

36
testing/counter.go Normal file
View File

@@ -0,0 +1,36 @@
package testing
import "sync"
type Counter struct {
mux sync.Mutex
counts map[string]int
}
func NewCounter() *Counter {
return &Counter{
counts: make(map[string]int),
}
}
func (c *Counter) Increment(name string) {
c.mux.Lock()
defer c.mux.Unlock()
if _, ok := c.counts[name]; !ok {
c.counts[name] = 0
}
c.counts[name]++
}
func (c *Counter) Count(name string) int {
c.mux.Lock()
defer c.mux.Unlock()
if _, ok := c.counts[name]; !ok {
return 0
}
return c.counts[name]
}

41
testing/counter_test.go Normal file
View File

@@ -0,0 +1,41 @@
package testing
import (
"testing"
"github.com/stretchr/testify/assert"
)
func TestCounter(t *testing.T) {
t.Parallel()
type args struct {
name string
repeat int
}
tests := []struct {
name string
args args
want int
}{
{
name: "Test Counter",
args: args{name: "test", repeat: 5},
want: 5,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
t.Parallel()
c := NewCounter()
for range tt.args.repeat {
c.Increment(tt.args.name)
}
assert.Equal(t, tt.want, c.Count(tt.args.name))
})
}
}

37
testing/http.go Normal file
View File

@@ -0,0 +1,37 @@
package testing
import (
"context"
"errors"
"net/http"
"time"
)
func WaitForStatus(url string, status int, timeout time.Duration) error {
ctx, cancel := context.WithTimeout(context.Background(), timeout)
defer cancel()
start := time.Now()
for {
if time.Since(start) > timeout {
return errors.New("timeout")
}
req, err := http.NewRequestWithContext(ctx, http.MethodGet, url, nil)
if err != nil {
return err
}
resp, err := http.DefaultClient.Do(req)
if err == nil && resp.StatusCode == status {
resp.Body.Close()
break
}
time.Sleep(100 * time.Millisecond)
}
return nil
}

View File

@@ -3,44 +3,36 @@ package testing
import ( import (
"net/http" "net/http"
"testing" "testing"
"time"
"github.com/stretchr/testify/require" "github.com/stretchr/testify/require"
) )
func TestWebhookReactions(t *testing.T) { func TestWebhookReactions(t *testing.T) {
baseApp, adminToken, analystToken, baseAppCleanup := BaseApp(t) t.Parallel()
defer baseAppCleanup()
server := NewRecordingServer() server := NewRecordingServer()
go http.ListenAndServe("127.0.0.1:12345", server) //nolint:gosec,errcheck go http.ListenAndServe("127.0.0.1:12345", server) //nolint:gosec,errcheck
testSets := []authMatrixText{ if err := WaitForStatus("http://127.0.0.1:12345/health", http.StatusOK, 5*time.Second); err != nil {
t.Fatal(err)
}
testSets := []catalystTest{
{ {
baseTest: BaseTest{ baseTest: BaseTest{
Name: "TriggerWebhookReaction", Name: "TriggerWebhookReaction",
Method: http.MethodGet, Method: http.MethodGet,
RequestHeaders: map[string]string{"Authorization": "Bearer 1234567890"},
URL: "/reaction/test", URL: "/reaction/test",
TestAppFactory: AppFactory(baseApp),
}, },
authBasedExpectations: []AuthBasedExpectation{ userTests: []UserTest{
{ {
Name: "Unauthorized", Name: "Unauthorized",
ExpectedStatus: http.StatusOK, ExpectedStatus: http.StatusOK,
ExpectedContent: []string{`Hello, World!`}, ExpectedContent: []string{`Hello, World!`},
}, },
{
Name: "Analyst",
RequestHeaders: map[string]string{"Authorization": analystToken},
ExpectedStatus: http.StatusOK,
ExpectedContent: []string{`Hello, World!`},
},
{
Name: "Admin",
RequestHeaders: map[string]string{"Authorization": adminToken},
ExpectedStatus: http.StatusOK,
ExpectedContent: []string{`Hello, World!`},
},
}, },
}, },
{ {
@@ -48,48 +40,43 @@ func TestWebhookReactions(t *testing.T) {
Name: "TriggerWebhookReaction2", Name: "TriggerWebhookReaction2",
Method: http.MethodGet, Method: http.MethodGet,
URL: "/reaction/test2", URL: "/reaction/test2",
TestAppFactory: AppFactory(baseApp),
}, },
authBasedExpectations: []AuthBasedExpectation{ userTests: []UserTest{
{ {
Name: "Unauthorized", Name: "Unauthorized",
ExpectedStatus: http.StatusOK, ExpectedStatus: http.StatusOK,
ExpectedContent: []string{`"test":true`}, ExpectedContent: []string{`"test":true`},
}, },
{
Name: "Analyst",
RequestHeaders: map[string]string{"Authorization": analystToken},
ExpectedStatus: http.StatusOK,
ExpectedContent: []string{`"test":true`},
},
{
Name: "Admin",
RequestHeaders: map[string]string{"Authorization": adminToken},
ExpectedStatus: http.StatusOK,
ExpectedContent: []string{`"test":true`},
},
}, },
}, },
} }
for _, testSet := range testSets { for _, testSet := range testSets {
t.Run(testSet.baseTest.Name, func(t *testing.T) { t.Run(testSet.baseTest.Name, func(t *testing.T) {
for _, authBasedExpectation := range testSet.authBasedExpectations { t.Parallel()
scenario := mergeScenario(testSet.baseTest, authBasedExpectation)
scenario.Test(t) for _, userTest := range testSet.userTests {
t.Run(userTest.Name, func(t *testing.T) {
t.Parallel()
runMatrixTest(t, testSet.baseTest, userTest)
})
} }
}) })
} }
} }
func TestHookReactions(t *testing.T) { func TestHookReactions(t *testing.T) {
baseApp, _, analystToken, baseAppCleanup := BaseApp(t) t.Parallel()
defer baseAppCleanup()
server := NewRecordingServer() server := NewRecordingServer()
go http.ListenAndServe("127.0.0.1:12346", server) //nolint:gosec,errcheck go http.ListenAndServe("127.0.0.1:12346", server) //nolint:gosec,errcheck
testSets := []authMatrixText{ if err := WaitForStatus("http://127.0.0.1:12346/health", http.StatusOK, 5*time.Second); err != nil {
t.Fatal(err)
}
testSets := []catalystTest{
{ {
baseTest: BaseTest{ baseTest: BaseTest{
Name: "TriggerHookReaction", Name: "TriggerHookReaction",
@@ -99,9 +86,8 @@ func TestHookReactions(t *testing.T) {
Body: s(map[string]any{ Body: s(map[string]any{
"name": "test", "name": "test",
}), }),
TestAppFactory: AppFactory(baseApp),
}, },
authBasedExpectations: []AuthBasedExpectation{ userTests: []UserTest{
// { // {
// Name: "Unauthorized", // Name: "Unauthorized",
// ExpectedStatus: http.StatusOK, // ExpectedStatus: http.StatusOK,
@@ -109,7 +95,7 @@ func TestHookReactions(t *testing.T) {
// }, // },
{ {
Name: "Analyst", Name: "Analyst",
RequestHeaders: map[string]string{"Authorization": analystToken}, AuthRecord: analystEmail,
ExpectedStatus: http.StatusOK, ExpectedStatus: http.StatusOK,
ExpectedContent: []string{ ExpectedContent: []string{
`"collectionName":"tickets"`, `"collectionName":"tickets"`,
@@ -133,12 +119,17 @@ func TestHookReactions(t *testing.T) {
} }
for _, testSet := range testSets { for _, testSet := range testSets {
t.Run(testSet.baseTest.Name, func(t *testing.T) { t.Run(testSet.baseTest.Name, func(t *testing.T) {
for _, authBasedExpectation := range testSet.authBasedExpectations { t.Parallel()
scenario := mergeScenario(testSet.baseTest, authBasedExpectation)
scenario.Test(t) for _, userTest := range testSet.userTests {
} t.Run(userTest.Name, func(t *testing.T) {
t.Parallel()
runMatrixTest(t, testSet.baseTest, userTest)
require.NotEmpty(t, server.Entries) require.NotEmpty(t, server.Entries)
}) })
} }
})
}
} }

View File

@@ -1,19 +1,38 @@
package testing package testing
import "net/http" import (
"net/http"
"github.com/labstack/echo/v5"
)
type RecordingServer struct { type RecordingServer struct {
server *echo.Echo
Entries []string Entries []string
} }
func NewRecordingServer() *RecordingServer { func NewRecordingServer() *RecordingServer {
return &RecordingServer{} e := echo.New()
e.GET("/health", func(c echo.Context) error {
return c.JSON(http.StatusOK, map[string]any{
"status": "ok",
})
})
e.Any("/*", func(c echo.Context) error {
return c.JSON(http.StatusOK, map[string]any{
"test": true,
})
})
return &RecordingServer{
server: e,
}
} }
func (s *RecordingServer) ServeHTTP(w http.ResponseWriter, r *http.Request) { func (s *RecordingServer) ServeHTTP(w http.ResponseWriter, r *http.Request) {
s.Entries = append(s.Entries, r.URL.Path) s.Entries = append(s.Entries, r.URL.Path)
w.WriteHeader(http.StatusOK) s.server.ServeHTTP(w, r)
w.Header().Set("Content-Type", "application/json")
w.Write([]byte(`{"test":true}`)) //nolint:errcheck
} }

View File

@@ -6,30 +6,28 @@ import (
) )
func Test_Routes(t *testing.T) { func Test_Routes(t *testing.T) {
baseApp, adminToken, analystToken, baseAppCleanup := BaseApp(t) t.Parallel()
defer baseAppCleanup()
testSets := []authMatrixText{ testSets := []catalystTest{
{ {
baseTest: BaseTest{ baseTest: BaseTest{
Name: "Root", Name: "Root",
Method: http.MethodGet, Method: http.MethodGet,
URL: "/", URL: "/",
TestAppFactory: AppFactory(baseApp),
}, },
authBasedExpectations: []AuthBasedExpectation{ userTests: []UserTest{
{ {
Name: "Unauthorized", Name: "Unauthorized",
ExpectedStatus: http.StatusFound, ExpectedStatus: http.StatusFound,
}, },
{ {
Name: "Analyst", Name: "Analyst",
RequestHeaders: map[string]string{"Authorization": analystToken}, AuthRecord: analystEmail,
ExpectedStatus: http.StatusFound, ExpectedStatus: http.StatusFound,
}, },
{ {
Name: "Admin", Name: "Admin",
RequestHeaders: map[string]string{"Authorization": adminToken}, Admin: adminEmail,
ExpectedStatus: http.StatusFound, ExpectedStatus: http.StatusFound,
}, },
}, },
@@ -39,30 +37,29 @@ func Test_Routes(t *testing.T) {
Name: "Config", Name: "Config",
Method: http.MethodGet, Method: http.MethodGet,
URL: "/api/config", URL: "/api/config",
TestAppFactory: AppFactory(baseApp),
}, },
authBasedExpectations: []AuthBasedExpectation{ userTests: []UserTest{
{ {
Name: "Unauthorized", Name: "Unauthorized",
ExpectedStatus: http.StatusOK, ExpectedStatus: http.StatusOK,
ExpectedContent: []string{ ExpectedContent: []string{
`"flags":null`, `"flags":[]`,
}, },
}, },
{ {
Name: "Analyst", Name: "Analyst",
RequestHeaders: map[string]string{"Authorization": analystToken}, AuthRecord: analystEmail,
ExpectedStatus: http.StatusOK, ExpectedStatus: http.StatusOK,
ExpectedContent: []string{ ExpectedContent: []string{
`"flags":null`, `"flags":[]`,
}, },
}, },
{ {
Name: "Admin", Name: "Admin",
RequestHeaders: map[string]string{"Authorization": adminToken}, Admin: adminEmail,
ExpectedStatus: http.StatusOK, ExpectedStatus: http.StatusOK,
ExpectedContent: []string{ ExpectedContent: []string{
`"flags":null`, `"flags":[]`,
}, },
}, },
}, },
@@ -70,9 +67,14 @@ func Test_Routes(t *testing.T) {
} }
for _, testSet := range testSets { for _, testSet := range testSets {
t.Run(testSet.baseTest.Name, func(t *testing.T) { t.Run(testSet.baseTest.Name, func(t *testing.T) {
for _, authBasedExpectation := range testSet.authBasedExpectations { t.Parallel()
scenario := mergeScenario(testSet.baseTest, authBasedExpectation)
scenario.Test(t) for _, userTest := range testSet.userTests {
t.Run(userTest.Name, func(t *testing.T) {
t.Parallel()
runMatrixTest(t, testSet.baseTest, userTest)
})
} }
}) })
} }

160
testing/testapp.go Normal file
View File

@@ -0,0 +1,160 @@
package testing
import (
"fmt"
"os"
"testing"
"github.com/pocketbase/pocketbase"
"github.com/pocketbase/pocketbase/core"
"github.com/pocketbase/pocketbase/tokens"
"github.com/SecurityBrewery/catalyst/app"
"github.com/SecurityBrewery/catalyst/migrations"
)
func App(t *testing.T) (*pocketbase.PocketBase, *Counter, func()) {
t.Helper()
temp, err := os.MkdirTemp("", "catalyst_test_data")
if err != nil {
t.Fatal(err)
}
baseApp, err := app.App(temp, true)
if err != nil {
t.Fatal(err)
}
baseApp.Settings().Logs.MaxDays = 0
defaultTestData(t, baseApp)
counter := countEvents(baseApp)
return baseApp, counter, func() { _ = os.RemoveAll(temp) }
}
func generateAdminToken(t *testing.T, baseApp core.App, email string) (string, error) {
t.Helper()
admin, err := baseApp.Dao().FindAdminByEmail(email)
if err != nil {
return "", fmt.Errorf("failed to find admin: %w", err)
}
return tokens.NewAdminAuthToken(baseApp, admin)
}
func generateRecordToken(t *testing.T, baseApp core.App, email string) (string, error) {
t.Helper()
record, err := baseApp.Dao().FindAuthRecordByEmail(migrations.UserCollectionName, email)
if err != nil {
return "", fmt.Errorf("failed to find record: %w", err)
}
return tokens.NewRecordAuthToken(baseApp, record)
}
func countEvents(t *pocketbase.PocketBase) *Counter {
c := NewCounter()
t.OnBeforeApiError().Add(count[*core.ApiErrorEvent](c, "OnBeforeApiError"))
t.OnBeforeApiError().Add(count[*core.ApiErrorEvent](c, "OnBeforeApiError"))
t.OnAfterApiError().Add(count[*core.ApiErrorEvent](c, "OnAfterApiError"))
t.OnModelBeforeCreate().Add(count[*core.ModelEvent](c, "OnModelBeforeCreate"))
t.OnModelAfterCreate().Add(count[*core.ModelEvent](c, "OnModelAfterCreate"))
t.OnModelBeforeUpdate().Add(count[*core.ModelEvent](c, "OnModelBeforeUpdate"))
t.OnModelAfterUpdate().Add(count[*core.ModelEvent](c, "OnModelAfterUpdate"))
t.OnModelBeforeDelete().Add(count[*core.ModelEvent](c, "OnModelBeforeDelete"))
t.OnModelAfterDelete().Add(count[*core.ModelEvent](c, "OnModelAfterDelete"))
t.OnRecordsListRequest().Add(count[*core.RecordsListEvent](c, "OnRecordsListRequest"))
t.OnRecordViewRequest().Add(count[*core.RecordViewEvent](c, "OnRecordViewRequest"))
t.OnRecordBeforeCreateRequest().Add(count[*core.RecordCreateEvent](c, "OnRecordBeforeCreateRequest"))
t.OnRecordAfterCreateRequest().Add(count[*core.RecordCreateEvent](c, "OnRecordAfterCreateRequest"))
t.OnRecordBeforeUpdateRequest().Add(count[*core.RecordUpdateEvent](c, "OnRecordBeforeUpdateRequest"))
t.OnRecordAfterUpdateRequest().Add(count[*core.RecordUpdateEvent](c, "OnRecordAfterUpdateRequest"))
t.OnRecordBeforeDeleteRequest().Add(count[*core.RecordDeleteEvent](c, "OnRecordBeforeDeleteRequest"))
t.OnRecordAfterDeleteRequest().Add(count[*core.RecordDeleteEvent](c, "OnRecordAfterDeleteRequest"))
t.OnRecordAuthRequest().Add(count[*core.RecordAuthEvent](c, "OnRecordAuthRequest"))
t.OnRecordBeforeAuthWithPasswordRequest().Add(count[*core.RecordAuthWithPasswordEvent](c, "OnRecordBeforeAuthWithPasswordRequest"))
t.OnRecordAfterAuthWithPasswordRequest().Add(count[*core.RecordAuthWithPasswordEvent](c, "OnRecordAfterAuthWithPasswordRequest"))
t.OnRecordBeforeAuthWithOAuth2Request().Add(count[*core.RecordAuthWithOAuth2Event](c, "OnRecordBeforeAuthWithOAuth2Request"))
t.OnRecordAfterAuthWithOAuth2Request().Add(count[*core.RecordAuthWithOAuth2Event](c, "OnRecordAfterAuthWithOAuth2Request"))
t.OnRecordBeforeAuthRefreshRequest().Add(count[*core.RecordAuthRefreshEvent](c, "OnRecordBeforeAuthRefreshRequest"))
t.OnRecordAfterAuthRefreshRequest().Add(count[*core.RecordAuthRefreshEvent](c, "OnRecordAfterAuthRefreshRequest"))
t.OnRecordBeforeRequestPasswordResetRequest().Add(count[*core.RecordRequestPasswordResetEvent](c, "OnRecordBeforeRequestPasswordResetRequest"))
t.OnRecordAfterRequestPasswordResetRequest().Add(count[*core.RecordRequestPasswordResetEvent](c, "OnRecordAfterRequestPasswordResetRequest"))
t.OnRecordBeforeConfirmPasswordResetRequest().Add(count[*core.RecordConfirmPasswordResetEvent](c, "OnRecordBeforeConfirmPasswordResetRequest"))
t.OnRecordAfterConfirmPasswordResetRequest().Add(count[*core.RecordConfirmPasswordResetEvent](c, "OnRecordAfterConfirmPasswordResetRequest"))
t.OnRecordBeforeRequestVerificationRequest().Add(count[*core.RecordRequestVerificationEvent](c, "OnRecordBeforeRequestVerificationRequest"))
t.OnRecordAfterRequestVerificationRequest().Add(count[*core.RecordRequestVerificationEvent](c, "OnRecordAfterRequestVerificationRequest"))
t.OnRecordBeforeConfirmVerificationRequest().Add(count[*core.RecordConfirmVerificationEvent](c, "OnRecordBeforeConfirmVerificationRequest"))
t.OnRecordAfterConfirmVerificationRequest().Add(count[*core.RecordConfirmVerificationEvent](c, "OnRecordAfterConfirmVerificationRequest"))
t.OnRecordBeforeRequestEmailChangeRequest().Add(count[*core.RecordRequestEmailChangeEvent](c, "OnRecordBeforeRequestEmailChangeRequest"))
t.OnRecordAfterRequestEmailChangeRequest().Add(count[*core.RecordRequestEmailChangeEvent](c, "OnRecordAfterRequestEmailChangeRequest"))
t.OnRecordBeforeConfirmEmailChangeRequest().Add(count[*core.RecordConfirmEmailChangeEvent](c, "OnRecordBeforeConfirmEmailChangeRequest"))
t.OnRecordAfterConfirmEmailChangeRequest().Add(count[*core.RecordConfirmEmailChangeEvent](c, "OnRecordAfterConfirmEmailChangeRequest"))
t.OnRecordListExternalAuthsRequest().Add(count[*core.RecordListExternalAuthsEvent](c, "OnRecordListExternalAuthsRequest"))
t.OnRecordBeforeUnlinkExternalAuthRequest().Add(count[*core.RecordUnlinkExternalAuthEvent](c, "OnRecordBeforeUnlinkExternalAuthRequest"))
t.OnRecordAfterUnlinkExternalAuthRequest().Add(count[*core.RecordUnlinkExternalAuthEvent](c, "OnRecordAfterUnlinkExternalAuthRequest"))
t.OnMailerBeforeAdminResetPasswordSend().Add(count[*core.MailerAdminEvent](c, "OnMailerBeforeAdminResetPasswordSend"))
t.OnMailerAfterAdminResetPasswordSend().Add(count[*core.MailerAdminEvent](c, "OnMailerAfterAdminResetPasswordSend"))
t.OnMailerBeforeRecordResetPasswordSend().Add(count[*core.MailerRecordEvent](c, "OnMailerBeforeRecordResetPasswordSend"))
t.OnMailerAfterRecordResetPasswordSend().Add(count[*core.MailerRecordEvent](c, "OnMailerAfterRecordResetPasswordSend"))
t.OnMailerBeforeRecordVerificationSend().Add(count[*core.MailerRecordEvent](c, "OnMailerBeforeRecordVerificationSend"))
t.OnMailerAfterRecordVerificationSend().Add(count[*core.MailerRecordEvent](c, "OnMailerAfterRecordVerificationSend"))
t.OnMailerBeforeRecordChangeEmailSend().Add(count[*core.MailerRecordEvent](c, "OnMailerBeforeRecordChangeEmailSend"))
t.OnMailerAfterRecordChangeEmailSend().Add(count[*core.MailerRecordEvent](c, "OnMailerAfterRecordChangeEmailSend"))
t.OnRealtimeConnectRequest().Add(count[*core.RealtimeConnectEvent](c, "OnRealtimeConnectRequest"))
t.OnRealtimeDisconnectRequest().Add(count[*core.RealtimeDisconnectEvent](c, "OnRealtimeDisconnectRequest"))
t.OnRealtimeBeforeMessageSend().Add(count[*core.RealtimeMessageEvent](c, "OnRealtimeBeforeMessageSend"))
t.OnRealtimeAfterMessageSend().Add(count[*core.RealtimeMessageEvent](c, "OnRealtimeAfterMessageSend"))
t.OnRealtimeBeforeSubscribeRequest().Add(count[*core.RealtimeSubscribeEvent](c, "OnRealtimeBeforeSubscribeRequest"))
t.OnRealtimeAfterSubscribeRequest().Add(count[*core.RealtimeSubscribeEvent](c, "OnRealtimeAfterSubscribeRequest"))
t.OnSettingsListRequest().Add(count[*core.SettingsListEvent](c, "OnSettingsListRequest"))
t.OnSettingsBeforeUpdateRequest().Add(count[*core.SettingsUpdateEvent](c, "OnSettingsBeforeUpdateRequest"))
t.OnSettingsAfterUpdateRequest().Add(count[*core.SettingsUpdateEvent](c, "OnSettingsAfterUpdateRequest"))
t.OnCollectionsListRequest().Add(count[*core.CollectionsListEvent](c, "OnCollectionsListRequest"))
t.OnCollectionViewRequest().Add(count[*core.CollectionViewEvent](c, "OnCollectionViewRequest"))
t.OnCollectionBeforeCreateRequest().Add(count[*core.CollectionCreateEvent](c, "OnCollectionBeforeCreateRequest"))
t.OnCollectionAfterCreateRequest().Add(count[*core.CollectionCreateEvent](c, "OnCollectionAfterCreateRequest"))
t.OnCollectionBeforeUpdateRequest().Add(count[*core.CollectionUpdateEvent](c, "OnCollectionBeforeUpdateRequest"))
t.OnCollectionAfterUpdateRequest().Add(count[*core.CollectionUpdateEvent](c, "OnCollectionAfterUpdateRequest"))
t.OnCollectionBeforeDeleteRequest().Add(count[*core.CollectionDeleteEvent](c, "OnCollectionBeforeDeleteRequest"))
t.OnCollectionAfterDeleteRequest().Add(count[*core.CollectionDeleteEvent](c, "OnCollectionAfterDeleteRequest"))
t.OnCollectionsBeforeImportRequest().Add(count[*core.CollectionsImportEvent](c, "OnCollectionsBeforeImportRequest"))
t.OnCollectionsAfterImportRequest().Add(count[*core.CollectionsImportEvent](c, "OnCollectionsAfterImportRequest"))
t.OnAdminsListRequest().Add(count[*core.AdminsListEvent](c, "OnAdminsListRequest"))
t.OnAdminViewRequest().Add(count[*core.AdminViewEvent](c, "OnAdminViewRequest"))
t.OnAdminBeforeCreateRequest().Add(count[*core.AdminCreateEvent](c, "OnAdminBeforeCreateRequest"))
t.OnAdminAfterCreateRequest().Add(count[*core.AdminCreateEvent](c, "OnAdminAfterCreateRequest"))
t.OnAdminBeforeUpdateRequest().Add(count[*core.AdminUpdateEvent](c, "OnAdminBeforeUpdateRequest"))
t.OnAdminAfterUpdateRequest().Add(count[*core.AdminUpdateEvent](c, "OnAdminAfterUpdateRequest"))
t.OnAdminBeforeDeleteRequest().Add(count[*core.AdminDeleteEvent](c, "OnAdminBeforeDeleteRequest"))
t.OnAdminAfterDeleteRequest().Add(count[*core.AdminDeleteEvent](c, "OnAdminAfterDeleteRequest"))
t.OnAdminAuthRequest().Add(count[*core.AdminAuthEvent](c, "OnAdminAuthRequest"))
t.OnAdminBeforeAuthWithPasswordRequest().Add(count[*core.AdminAuthWithPasswordEvent](c, "OnAdminBeforeAuthWithPasswordRequest"))
t.OnAdminAfterAuthWithPasswordRequest().Add(count[*core.AdminAuthWithPasswordEvent](c, "OnAdminAfterAuthWithPasswordRequest"))
t.OnAdminBeforeAuthRefreshRequest().Add(count[*core.AdminAuthRefreshEvent](c, "OnAdminBeforeAuthRefreshRequest"))
t.OnAdminAfterAuthRefreshRequest().Add(count[*core.AdminAuthRefreshEvent](c, "OnAdminAfterAuthRefreshRequest"))
t.OnAdminBeforeRequestPasswordResetRequest().Add(count[*core.AdminRequestPasswordResetEvent](c, "OnAdminBeforeRequestPasswordResetRequest"))
t.OnAdminAfterRequestPasswordResetRequest().Add(count[*core.AdminRequestPasswordResetEvent](c, "OnAdminAfterRequestPasswordResetRequest"))
t.OnAdminBeforeConfirmPasswordResetRequest().Add(count[*core.AdminConfirmPasswordResetEvent](c, "OnAdminBeforeConfirmPasswordResetRequest"))
t.OnAdminAfterConfirmPasswordResetRequest().Add(count[*core.AdminConfirmPasswordResetEvent](c, "OnAdminAfterConfirmPasswordResetRequest"))
t.OnFileDownloadRequest().Add(count[*core.FileDownloadEvent](c, "OnFileDownloadRequest"))
t.OnFileBeforeTokenRequest().Add(count[*core.FileTokenEvent](c, "OnFileBeforeTokenRequest"))
t.OnFileAfterTokenRequest().Add(count[*core.FileTokenEvent](c, "OnFileAfterTokenRequest"))
t.OnFileAfterTokenRequest().Add(count[*core.FileTokenEvent](c, "OnFileAfterTokenRequest"))
return c
}
func count[T any](c *Counter, name string) func(_ T) error {
return func(_ T) error {
c.Increment(name)
return nil
}
}

View File

@@ -19,6 +19,7 @@ func defaultTestData(t *testing.T, app core.App) {
adminTestData(t, app) adminTestData(t, app)
userTestData(t, app) userTestData(t, app)
ticketTestData(t, app)
reactionTestData(t, app) reactionTestData(t, app)
} }
@@ -57,6 +58,30 @@ func userTestData(t *testing.T, app core.App) {
} }
} }
func ticketTestData(t *testing.T, app core.App) {
t.Helper()
collection, err := app.Dao().FindCollectionByNameOrId(migrations.TicketCollectionName)
if err != nil {
t.Fatal(err)
}
record := models.NewRecord(collection)
record.SetId("t_test")
record.Set("name", "Test Ticket")
record.Set("type", "incident")
record.Set("description", "This is a test ticket.")
record.Set("open", true)
record.Set("schema", `{"type":"object","properties":{"tlp":{"title":"TLP","type":"string"}}}`)
record.Set("state", `{"tlp":"AMBER"}`)
record.Set("owner", "u_bob_analyst")
if err := app.Dao().SaveRecord(record); err != nil {
t.Fatal(err)
}
}
func reactionTestData(t *testing.T, app core.App) { func reactionTestData(t *testing.T, app core.App) {
t.Helper() t.Helper()
@@ -69,9 +94,9 @@ func reactionTestData(t *testing.T, app core.App) {
record.SetId("r_reaction") record.SetId("r_reaction")
record.Set("name", "Reaction") record.Set("name", "Reaction")
record.Set("trigger", "webhook") record.Set("trigger", "webhook")
record.Set("triggerdata", `{"path":"test"}`) record.Set("triggerdata", `{"token":"1234567890","path":"test"}`)
record.Set("action", "python") record.Set("action", "python")
record.Set("actiondata", `{"bootstrap":"requests","script":"print('Hello, World!')"}`) record.Set("actiondata", `{"requirements":"requests","script":"print('Hello, World!')"}`)
if err := app.Dao().SaveRecord(record); err != nil { if err := app.Dao().SaveRecord(record); err != nil {
t.Fatal(err) t.Fatal(err)
@@ -95,7 +120,7 @@ func reactionTestData(t *testing.T, app core.App) {
record.Set("trigger", "hook") record.Set("trigger", "hook")
record.Set("triggerdata", `{"collections":["tickets"],"events":["create"]}`) record.Set("triggerdata", `{"collections":["tickets"],"events":["create"]}`)
record.Set("action", "python") record.Set("action", "python")
record.Set("actiondata", `{"bootstrap":"requests","script":"import requests\nrequests.post('http://127.0.0.1:12346/test', json={'test':True})"}`) record.Set("actiondata", `{"requirements":"requests","script":"import requests\nrequests.post('http://127.0.0.1:12346/test', json={'test':True})"}`)
if err := app.Dao().SaveRecord(record); err != nil { if err := app.Dao().SaveRecord(record); err != nil {
t.Fatal(err) t.Fatal(err)

View File

@@ -3,162 +3,95 @@ package testing
import ( import (
"bytes" "bytes"
"encoding/json" "encoding/json"
"os" "fmt"
"net/http/httptest"
"testing" "testing"
"github.com/pocketbase/pocketbase/apis"
"github.com/pocketbase/pocketbase/core" "github.com/pocketbase/pocketbase/core"
"github.com/pocketbase/pocketbase/tests" "github.com/stretchr/testify/assert"
"github.com/pocketbase/pocketbase/tokens" "github.com/stretchr/testify/require"
"github.com/SecurityBrewery/catalyst/app"
"github.com/SecurityBrewery/catalyst/migrations"
) )
func BaseApp(t *testing.T) (core.App, string, string, func()) {
t.Helper()
temp, err := os.MkdirTemp("", "catalyst_test_data")
if err != nil {
t.Fatal(err)
}
baseApp := app.App(temp)
if err := app.Bootstrap(baseApp); err != nil {
t.Fatal(err)
}
defaultTestData(t, baseApp)
adminToken, err := generateAdminToken(t, baseApp, adminEmail)
if err != nil {
t.Fatal(err)
}
analystToken, err := generateRecordToken(t, baseApp, analystEmail)
if err != nil {
t.Fatal(err)
}
return baseApp, adminToken, analystToken, func() { _ = os.RemoveAll(temp) }
}
func AppFactory(baseApp core.App) func(t *testing.T) *tests.TestApp {
return func(t *testing.T) *tests.TestApp {
t.Helper()
testApp, err := tests.NewTestApp(baseApp.DataDir())
if err != nil {
t.Fatal(err)
}
app.BindHooks(testApp)
if err := app.Bootstrap(testApp); err != nil {
t.Fatal(err)
}
return testApp
}
}
func App(t *testing.T) (*tests.TestApp, func()) {
t.Helper()
baseApp, _, _, cleanup := BaseApp(t)
testApp := AppFactory(baseApp)(t)
return testApp, cleanup
}
func generateAdminToken(t *testing.T, baseApp core.App, email string) (string, error) {
t.Helper()
app, err := tests.NewTestApp(baseApp.DataDir())
if err != nil {
return "", err
}
defer app.Cleanup()
admin, err := app.Dao().FindAdminByEmail(email)
if err != nil {
return "", err
}
return tokens.NewAdminAuthToken(app, admin)
}
func generateRecordToken(t *testing.T, baseApp core.App, email string) (string, error) {
t.Helper()
app, err := tests.NewTestApp(baseApp.DataDir())
if err != nil {
t.Fatal(err)
}
defer app.Cleanup()
record, err := app.Dao().FindAuthRecordByEmail(migrations.UserCollectionName, email)
if err != nil {
return "", err
}
return tokens.NewRecordAuthToken(app, record)
}
type BaseTest struct { type BaseTest struct {
Name string Name string
Method string Method string
RequestHeaders map[string]string RequestHeaders map[string]string
URL string URL string
Body string Body string
TestAppFactory func(t *testing.T) *tests.TestApp
} }
type AuthBasedExpectation struct { type UserTest struct {
Name string Name string
RequestHeaders map[string]string AuthRecord string
Admin string
ExpectedStatus int ExpectedStatus int
ExpectedContent []string ExpectedContent []string
NotExpectedContent []string NotExpectedContent []string
ExpectedEvents map[string]int ExpectedEvents map[string]int
} }
type authMatrixText struct { type catalystTest struct {
baseTest BaseTest baseTest BaseTest
authBasedExpectations []AuthBasedExpectation userTests []UserTest
} }
func mergeScenario(base BaseTest, expectation AuthBasedExpectation) tests.ApiScenario { func runMatrixTest(t *testing.T, baseTest BaseTest, userTest UserTest) {
return tests.ApiScenario{ t.Helper()
Name: expectation.Name,
Method: base.Method,
Url: base.URL,
Body: bytes.NewBufferString(base.Body),
TestAppFactory: base.TestAppFactory,
RequestHeaders: mergeMaps(base.RequestHeaders, expectation.RequestHeaders), baseApp, counter, baseAppCleanup := App(t)
ExpectedStatus: expectation.ExpectedStatus, defer baseAppCleanup()
ExpectedContent: expectation.ExpectedContent,
NotExpectedContent: expectation.NotExpectedContent, server, err := apis.InitApi(baseApp)
ExpectedEvents: expectation.ExpectedEvents, require.NoError(t, err)
}
if err := baseApp.OnBeforeServe().Trigger(&core.ServeEvent{
App: baseApp,
Router: server,
}); err != nil {
t.Fatal(fmt.Errorf("failed to trigger OnBeforeServe: %w", err))
} }
func mergeMaps(a, b map[string]string) map[string]string { recorder := httptest.NewRecorder()
if a == nil { body := bytes.NewBufferString(baseTest.Body)
return b req := httptest.NewRequest(baseTest.Method, baseTest.URL, body)
for k, v := range baseTest.RequestHeaders {
req.Header.Set(k, v)
} }
if b == nil { if userTest.AuthRecord != "" {
return a token, err := generateRecordToken(t, baseApp, userTest.AuthRecord)
require.NoError(t, err)
req.Header.Set("Authorization", token)
} }
for k, v := range b { if userTest.Admin != "" {
a[k] = v token, err := generateAdminToken(t, baseApp, userTest.Admin)
require.NoError(t, err)
req.Header.Set("Authorization", token)
} }
return a server.ServeHTTP(recorder, req)
res := recorder.Result()
defer res.Body.Close()
assert.Equal(t, userTest.ExpectedStatus, res.StatusCode)
for _, expectedContent := range userTest.ExpectedContent {
assert.Contains(t, recorder.Body.String(), expectedContent)
}
for _, notExpectedContent := range userTest.NotExpectedContent {
assert.NotContains(t, recorder.Body.String(), notExpectedContent)
}
for event, count := range userTest.ExpectedEvents {
assert.Equal(t, count, counter.Count(event))
}
} }
func b(data map[string]any) []byte { func b(data map[string]any) []byte {

View File

@@ -4,13 +4,17 @@ import DeleteDialog from '@/components/common/DeleteDialog.vue'
import ReactionForm from '@/components/reaction/ReactionForm.vue' import ReactionForm from '@/components/reaction/ReactionForm.vue'
import { ScrollArea } from '@/components/ui/scroll-area' import { ScrollArea } from '@/components/ui/scroll-area'
import { Separator } from '@/components/ui/separator' import { Separator } from '@/components/ui/separator'
import { toast } from '@/components/ui/toast'
import { useMutation, useQuery, useQueryClient } from '@tanstack/vue-query' import { useMutation, useQuery, useQueryClient } from '@tanstack/vue-query'
import { onMounted, onUnmounted } from 'vue'
import { useRouter } from 'vue-router'
import { pb } from '@/lib/pocketbase' import { pb } from '@/lib/pocketbase'
import type { Reaction } from '@/lib/types' import type { Reaction } from '@/lib/types'
import { handleError } from '@/lib/utils' import { handleError } from '@/lib/utils'
const router = useRouter()
const queryClient = useQueryClient() const queryClient = useQueryClient()
const props = defineProps<{ const props = defineProps<{
@@ -32,6 +36,35 @@ const updateReactionMutation = useMutation({
onSuccess: () => queryClient.invalidateQueries({ queryKey: ['reactions'] }), onSuccess: () => queryClient.invalidateQueries({ queryKey: ['reactions'] }),
onError: handleError onError: handleError
}) })
onMounted(() => {
pb.collection('reactions').subscribe(props.id, (data) => {
if (data.action === 'delete') {
toast({
title: 'Reaction deleted',
description: 'The reaction has been deleted.',
variant: 'destructive'
})
router.push({ name: 'reactions' })
return
}
if (data.action === 'update') {
toast({
title: 'Reaction updated',
description: 'The reaction has been updated.'
})
queryClient.invalidateQueries({ queryKey: ['reactions', props.id] })
}
})
})
onUnmounted(() => {
pb.collection('reactions').unsubscribe(props.id)
})
</script> </script>
<template> <template>
@@ -54,7 +87,7 @@ const updateReactionMutation = useMutation({
<ScrollArea v-if="reaction" class="flex-1"> <ScrollArea v-if="reaction" class="flex-1">
<div class="flex max-w-[640px] flex-col gap-4 p-4"> <div class="flex max-w-[640px] flex-col gap-4 p-4">
<ReactionForm :reaction="reaction" @submit="updateReactionMutation.mutate" hide-cancel /> <ReactionForm :reaction="reaction" @submit="updateReactionMutation.mutate" />
</div> </div>
</ScrollArea> </ScrollArea>
</div> </div>

View File

@@ -166,6 +166,8 @@ watch(
() => { () => {
if (equalReaction(values, props.reaction)) { if (equalReaction(values, props.reaction)) {
submitDisabledReason.value = 'Make changes to save' submitDisabledReason.value = 'Make changes to save'
} else {
submitDisabledReason.value = ''
} }
}, },
{ immediate: true } { immediate: true }
@@ -312,7 +314,7 @@ const curlExample = computed(() => {
</TooltipContent> </TooltipContent>
</Tooltip> </Tooltip>
</TooltipProvider> </TooltipProvider>
<slot name="cancel" /> <slot name="cancel"></slot>
</div> </div>
</form> </form>
</template> </template>

View File

@@ -1,16 +1,19 @@
<script setup lang="ts"> <script setup lang="ts">
import TanView from '@/components/TanView.vue' import TanView from '@/components/TanView.vue'
import ResourceListElement from '@/components/common/ResourceListElement.vue' import ResourceListElement from '@/components/common/ResourceListElement.vue'
import ReactionNewDialog from '@/components/reaction/ReactionNewDialog.vue' import { Button } from '@/components/ui/button'
import { Separator } from '@/components/ui/separator' import { Separator } from '@/components/ui/separator'
import { useQuery } from '@tanstack/vue-query' import { useQuery, useQueryClient } from '@tanstack/vue-query'
import { useRoute } from 'vue-router' import { onMounted, onUnmounted } from 'vue'
import { useRoute, useRouter } from 'vue-router'
import { pb } from '@/lib/pocketbase' import { pb } from '@/lib/pocketbase'
import type { Reaction } from '@/lib/types' import type { Reaction } from '@/lib/types'
const route = useRoute() const route = useRoute()
const router = useRouter()
const queryClient = useQueryClient()
const { const {
isPending, isPending,
@@ -47,6 +50,20 @@ const reactionNiceName = (reaction: Reaction) => {
return 'Unknown' return 'Unknown'
} }
} }
const openNew = () => {
router.push({ name: 'reactions', params: { id: 'new' } })
}
onMounted(() => {
pb.collection('reactions').subscribe('*', () => {
queryClient.invalidateQueries({ queryKey: ['reactions'] })
})
})
onUnmounted(() => {
pb.collection('reactions').unsubscribe('*')
})
</script> </script>
<template> <template>
@@ -55,7 +72,7 @@ const reactionNiceName = (reaction: Reaction) => {
<div class="flex items-center bg-background px-4 py-2"> <div class="flex items-center bg-background px-4 py-2">
<h1 class="text-xl font-bold">Reactions</h1> <h1 class="text-xl font-bold">Reactions</h1>
<div class="ml-auto"> <div class="ml-auto">
<ReactionNewDialog /> <Button variant="ghost" @click="openNew"> New Reaction </Button>
</div> </div>
</div> </div>
<Separator /> <Separator />

View File

@@ -0,0 +1,37 @@
<script setup lang="ts">
import ReactionForm from '@/components/reaction/ReactionForm.vue'
import { ScrollArea } from '@/components/ui/scroll-area'
import { Separator } from '@/components/ui/separator'
import { useMutation, useQueryClient } from '@tanstack/vue-query'
import { useRouter } from 'vue-router'
import { pb } from '@/lib/pocketbase'
import type { Reaction, Ticket } from '@/lib/types'
import { handleError } from '@/lib/utils'
const queryClient = useQueryClient()
const router = useRouter()
const addReactionMutation = useMutation({
mutationFn: (values: Reaction): Promise<Reaction> => pb.collection('reactions').create(values),
onSuccess: (data: Ticket) => {
router.push({ name: 'reactions', params: { id: data.id } })
queryClient.invalidateQueries({ queryKey: ['reactions'] })
},
onError: handleError
})
</script>
<template>
<div class="flex h-full flex-1 flex-col overflow-hidden">
<div class="flex min-h-14 items-center bg-background px-4 py-2"></div>
<Separator />
<ScrollArea class="flex-1">
<div class="flex max-w-[640px] flex-col gap-4 p-4">
<ReactionForm @submit="addReactionMutation.mutate" />
</div>
</ScrollArea>
</div>
</template>

View File

@@ -1,63 +0,0 @@
<script setup lang="ts">
import ReactionForm from '@/components/reaction/ReactionForm.vue'
import { Button } from '@/components/ui/button'
import {
Dialog,
DialogClose,
DialogContent,
DialogDescription,
DialogHeader,
DialogScrollContent,
DialogTitle,
DialogTrigger
} from '@/components/ui/dialog'
import { useMutation, useQueryClient } from '@tanstack/vue-query'
import { ref } from 'vue'
import { useRouter } from 'vue-router'
import { pb } from '@/lib/pocketbase'
import type { Reaction, Ticket } from '@/lib/types'
import { handleError } from '@/lib/utils'
const queryClient = useQueryClient()
const router = useRouter()
const isOpen = ref(false)
const addReactionMutation = useMutation({
mutationFn: (values: Reaction): Promise<Reaction> => pb.collection('reactions').create(values),
onSuccess: (data: Ticket) => {
router.push({ name: 'reactions', params: { id: data.id } })
queryClient.invalidateQueries({ queryKey: ['reactions'] })
isOpen.value = false
},
onError: handleError
})
const cancel = () => (isOpen.value = false)
</script>
<template>
<Dialog v-model:open="isOpen">
<DialogTrigger as-child>
<Button variant="ghost">New Reaction</Button>
</DialogTrigger>
<DialogContent>
<DialogHeader>
<DialogTitle>New Reaction</DialogTitle>
<DialogDescription>Create a new reaction</DialogDescription>
</DialogHeader>
<DialogScrollContent>
<ReactionForm @submit="addReactionMutation.mutate">
<template #cancel>
<DialogClose as-child>
<Button type="button" variant="secondary">Cancel</Button>
</DialogClose>
</template>
</ReactionForm>
</DialogScrollContent>
</DialogContent>
</Dialog>
</template>

View File

@@ -22,7 +22,7 @@ import { Tabs, TabsList, TabsTrigger } from '@/components/ui/tabs'
import { Edit } from 'lucide-vue-next' import { Edit } from 'lucide-vue-next'
import { useMutation, useQuery, useQueryClient } from '@tanstack/vue-query' import { useMutation, useQuery, useQueryClient } from '@tanstack/vue-query'
import { computed, ref } from 'vue' import { computed, onMounted, onUnmounted, ref } from 'vue'
import { useRoute } from 'vue-router' import { useRoute } from 'vue-router'
import { pb } from '@/lib/pocketbase' import { pb } from '@/lib/pocketbase'

View File

@@ -19,7 +19,7 @@ defineProps<{
:key="item.id" :key="item.id"
:title="item.name" :title="item.name"
:created="item.created" :created="item.created"
:subtitle="item.expand.owner.name" :subtitle="item.expand.owner ? item.expand.owner.name : ''"
:description="item.description ? item.description.substring(0, 300) : ''" :description="item.description ? item.description.substring(0, 300) : ''"
:active="route.params.id === item.id" :active="route.params.id === item.id"
:to="`/tickets/${item.expand.type.id}/${item.id}`" :to="`/tickets/${item.expand.type.id}/${item.id}`"

View File

@@ -15,7 +15,7 @@ const queryClient = useQueryClient()
const props = defineProps<{ const props = defineProps<{
ticket: Ticket ticket: Ticket
uID: string uID?: string
}>() }>()
const { const {
@@ -25,7 +25,13 @@ const {
error error
} = useQuery({ } = useQuery({
queryKey: ['tickets', props.ticket.id, 'owner', props.uID], queryKey: ['tickets', props.ticket.id, 'owner', props.uID],
queryFn: (): Promise<User> => pb.collection('users').getOne(props.uID) queryFn: (): Promise<User | null> => {
if (!props.uID) {
return Promise.resolve(null)
}
return pb.collection('users').getOne(props.uID)
}
}) })
const setTicketOwnerMutation = useMutation({ const setTicketOwnerMutation = useMutation({
@@ -48,12 +54,12 @@ const update = (user: User) => setTicketOwnerMutation.mutate(user)
<AlertTitle>Error</AlertTitle> <AlertTitle>Error</AlertTitle>
<AlertDescription>{{ error }}</AlertDescription> <AlertDescription>{{ error }}</AlertDescription>
</Alert> </Alert>
<div v-if="!user"> <UserSelect v-if="!user" @update:modelValue="update">
<Button variant="outline" role="combobox" disabled> <Button variant="outline" role="combobox">
<User2 class="mr-2 size-4 h-4 w-4 shrink-0 opacity-50" /> <User2 class="mr-2 size-4 h-4 w-4 shrink-0 opacity-50" />
{{ props.uID }} Unassigned
</Button> </Button>
</div> </UserSelect>
<UserSelect v-else :modelValue="user" @update:modelValue="update"> <UserSelect v-else :modelValue="user" @update:modelValue="update">
<Button variant="outline" role="combobox"> <Button variant="outline" role="combobox">
<User2 class="mr-2 size-4 h-4 w-4 shrink-0 opacity-50" /> <User2 class="mr-2 size-4 h-4 w-4 shrink-0 opacity-50" />

View File

@@ -20,7 +20,7 @@ const queryClient = useQueryClient()
const props = defineProps<{ const props = defineProps<{
ticket: Ticket ticket: Ticket
tasks: Array<Task> tasks?: Array<Task>
}>() }>()
const setTaskOwnerMutation = useMutation({ const setTaskOwnerMutation = useMutation({

View File

@@ -12,7 +12,7 @@ import type { Ticket, TimelineItem } from '@/lib/types'
const props = defineProps<{ const props = defineProps<{
ticket: Ticket ticket: Ticket
timeline: Array<TimelineItem> timeline?: Array<TimelineItem>
}>() }>()
const commentsByDate: ComputedRef<Record<string, Array<TimelineItem>>> = computed(() => { const commentsByDate: ComputedRef<Record<string, Array<TimelineItem>>> = computed(() => {
@@ -41,7 +41,7 @@ const commentsByDate: ComputedRef<Record<string, Array<TimelineItem>>> = compute
<template> <template>
<div class="mt-2 flex flex-col gap-2"> <div class="mt-2 flex flex-col gap-2">
<Card <Card
v-if="!props.timeline || props.timeline.length === 0" v-if="!timeline || timeline.length === 0"
class="flex h-10 items-center p-4 text-muted-foreground" class="flex h-10 items-center p-4 text-muted-foreground"
> >
No timeline entries added yet. No timeline entries added yet.
@@ -61,6 +61,6 @@ const commentsByDate: ComputedRef<Record<string, Array<TimelineItem>>> = compute
</Card> </Card>
</div> </div>
</div> </div>
<TicketTimelineInput :ticket="props.ticket" class="w-full" /> <TicketTimelineInput :ticket="ticket" class="w-full" />
</div> </div>
</template> </template>

View File

@@ -47,7 +47,7 @@ onMounted(() => {
<template> <template>
<TwoColumn> <TwoColumn>
<div class="flex h-screen flex-1 flex-col"> <div class="flex h-screen flex-1 flex-col">
<div class="flex h-14 items-center bg-background px-4 py-2"> <div class="flex h-14 min-h-14 items-center bg-background px-4 py-2">
<h1 class="text-xl font-bold">Dashboard</h1> <h1 class="text-xl font-bold">Dashboard</h1>
</div> </div>
<Separator class="shrink-0" /> <Separator class="shrink-0" />

View File

@@ -2,6 +2,7 @@
import ThreeColumn from '@/components/layout/ThreeColumn.vue' import ThreeColumn from '@/components/layout/ThreeColumn.vue'
import ReactionDisplay from '@/components/reaction/ReactionDisplay.vue' import ReactionDisplay from '@/components/reaction/ReactionDisplay.vue'
import ReactionList from '@/components/reaction/ReactionList.vue' import ReactionList from '@/components/reaction/ReactionList.vue'
import ReactionNew from '@/components/reaction/ReactionNew.vue'
import { computed, onMounted } from 'vue' import { computed, onMounted } from 'vue'
import { useRoute, useRouter } from 'vue-router' import { useRoute, useRouter } from 'vue-router'
@@ -29,6 +30,7 @@ onMounted(() => {
<div v-if="!id" class="flex h-full w-full items-center justify-center text-lg text-gray-500"> <div v-if="!id" class="flex h-full w-full items-center justify-center text-lg text-gray-500">
No reaction selected No reaction selected
</div> </div>
<ReactionNew v-else-if="id === 'new'" key="new" />
<ReactionDisplay v-else :key="id" :id="id" /> <ReactionDisplay v-else :key="id" :id="id" />
</template> </template>
</ThreeColumn> </ThreeColumn>

View File

@@ -1,4 +1,4 @@
package ui package ui_test
import ( import (
"io/fs" "io/fs"
@@ -6,9 +6,13 @@ import (
"github.com/stretchr/testify/assert" "github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require" "github.com/stretchr/testify/require"
"github.com/SecurityBrewery/catalyst/ui"
) )
func TestUI(t *testing.T) { func TestUI(t *testing.T) {
t.Parallel()
tests := []struct { tests := []struct {
name string name string
wantFiles []string wantFiles []string
@@ -22,7 +26,9 @@ func TestUI(t *testing.T) {
} }
for _, tt := range tests { for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) { t.Run(tt.name, func(t *testing.T) {
got := UI() t.Parallel()
got := ui.UI()
var gotFiles []string var gotFiles []string