mirror of
https://github.com/SecurityBrewery/catalyst.git
synced 2026-02-01 10:53:30 +01:00
feat: improve python actions (#1083)
This commit is contained in:
@@ -4,17 +4,33 @@ import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"time"
|
||||
|
||||
"github.com/golang-jwt/jwt/v4"
|
||||
"github.com/pocketbase/pocketbase/core"
|
||||
"github.com/pocketbase/pocketbase/tokens"
|
||||
"github.com/pocketbase/pocketbase/tools/security"
|
||||
|
||||
"github.com/SecurityBrewery/catalyst/migrations"
|
||||
"github.com/SecurityBrewery/catalyst/reaction/action/python"
|
||||
"github.com/SecurityBrewery/catalyst/reaction/action/webhook"
|
||||
)
|
||||
|
||||
func Run(ctx context.Context, actionName, actionData, payload string) ([]byte, error) {
|
||||
func Run(ctx context.Context, app core.App, actionName, actionData, payload string) ([]byte, error) {
|
||||
action, err := decode(actionName, actionData)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if a, ok := action.(authenticatedAction); ok {
|
||||
token, err := systemToken(app)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to get system token: %w", err)
|
||||
}
|
||||
|
||||
a.SetToken(token)
|
||||
}
|
||||
|
||||
return action.Run(ctx, payload)
|
||||
}
|
||||
|
||||
@@ -22,6 +38,10 @@ type action interface {
|
||||
Run(ctx context.Context, payload string) ([]byte, error)
|
||||
}
|
||||
|
||||
type authenticatedAction interface {
|
||||
SetToken(token string)
|
||||
}
|
||||
|
||||
func decode(actionName, actionData string) (action, error) {
|
||||
switch actionName {
|
||||
case "python":
|
||||
@@ -42,3 +62,20 @@ func decode(actionName, actionData string) (action, error) {
|
||||
return nil, fmt.Errorf("action %q not found", actionName)
|
||||
}
|
||||
}
|
||||
|
||||
func systemToken(app core.App) (string, error) {
|
||||
authRecord, err := app.Dao().FindAuthRecordByUsername(migrations.UserCollectionName, migrations.SystemUserID)
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("failed to find system auth record: %w", err)
|
||||
}
|
||||
|
||||
return security.NewJWT(
|
||||
jwt.MapClaims{
|
||||
"id": authRecord.Id,
|
||||
"type": tokens.TypeAuthRecord,
|
||||
"collectionId": authRecord.Collection().Id,
|
||||
},
|
||||
authRecord.TokenKey()+app.Settings().RecordAuthToken.Secret,
|
||||
int64(time.Second*60),
|
||||
)
|
||||
}
|
||||
|
||||
@@ -10,8 +10,14 @@ import (
|
||||
)
|
||||
|
||||
type Python struct {
|
||||
Bootstrap string `json:"bootstrap"`
|
||||
Script string `json:"script"`
|
||||
Requirements string `json:"requirements"`
|
||||
Script string `json:"script"`
|
||||
|
||||
token string
|
||||
}
|
||||
|
||||
func (a *Python) SetToken(token string) {
|
||||
a.token = token
|
||||
}
|
||||
|
||||
func (a *Python) Run(ctx context.Context, payload string) ([]byte, error) {
|
||||
@@ -22,7 +28,8 @@ func (a *Python) Run(ctx context.Context, payload string) ([]byte, error) {
|
||||
|
||||
defer os.RemoveAll(tempDir)
|
||||
|
||||
if b, err := pythonSetup(ctx, tempDir); err != nil {
|
||||
b, err := pythonSetup(ctx, tempDir)
|
||||
if err != nil {
|
||||
var ee *exec.ExitError
|
||||
if errors.As(err, &ee) {
|
||||
b = append(b, ee.Stderr...)
|
||||
@@ -31,16 +38,17 @@ func (a *Python) Run(ctx context.Context, payload string) ([]byte, error) {
|
||||
return nil, fmt.Errorf("failed to setup python, %w: %s", err, string(b))
|
||||
}
|
||||
|
||||
if b, err := pythonRunBootstrap(ctx, tempDir, a.Bootstrap); err != nil {
|
||||
b, err = a.pythonInstallRequirements(ctx, tempDir)
|
||||
if err != nil {
|
||||
var ee *exec.ExitError
|
||||
if errors.As(err, &ee) {
|
||||
b = append(b, ee.Stderr...)
|
||||
}
|
||||
|
||||
return nil, fmt.Errorf("failed to run bootstrap, %w: %s", err, string(b))
|
||||
return nil, fmt.Errorf("failed to run install requirements, %w: %s", err, string(b))
|
||||
}
|
||||
|
||||
b, err := pythonRunScript(ctx, tempDir, a.Script, payload)
|
||||
b, err = a.pythonRunScript(ctx, tempDir, payload)
|
||||
if err != nil {
|
||||
var ee *exec.ExitError
|
||||
if errors.As(err, &ee) {
|
||||
@@ -63,35 +71,42 @@ func pythonSetup(ctx context.Context, tempDir string) ([]byte, error) {
|
||||
return exec.CommandContext(ctx, pythonPath, "-m", "venv", tempDir+"/venv").Output()
|
||||
}
|
||||
|
||||
func pythonRunBootstrap(ctx context.Context, tempDir, bootstrap string) ([]byte, error) {
|
||||
hasBootstrap := len(strings.TrimSpace(bootstrap)) > 0
|
||||
func (a *Python) pythonInstallRequirements(ctx context.Context, tempDir string) ([]byte, error) {
|
||||
hasRequirements := len(strings.TrimSpace(a.Requirements)) > 0
|
||||
|
||||
if !hasBootstrap {
|
||||
if !hasRequirements {
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
bootstrapPath := tempDir + "/requirements.txt"
|
||||
requirementsPath := tempDir + "/requirements.txt"
|
||||
|
||||
if err := os.WriteFile(bootstrapPath, []byte(bootstrap), 0o600); err != nil {
|
||||
if err := os.WriteFile(requirementsPath, []byte(a.Requirements), 0o600); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// install dependencies
|
||||
pipPath := tempDir + "/venv/bin/pip"
|
||||
|
||||
return exec.CommandContext(ctx, pipPath, "install", "-r", bootstrapPath).Output()
|
||||
return exec.CommandContext(ctx, pipPath, "install", "-r", requirementsPath).Output()
|
||||
}
|
||||
|
||||
func pythonRunScript(ctx context.Context, tempDir, script, payload string) ([]byte, error) {
|
||||
func (a *Python) pythonRunScript(ctx context.Context, tempDir, payload string) ([]byte, error) {
|
||||
scriptPath := tempDir + "/script.py"
|
||||
|
||||
if err := os.WriteFile(scriptPath, []byte(script), 0o600); err != nil {
|
||||
if err := os.WriteFile(scriptPath, []byte(a.Script), 0o600); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
pythonPath := tempDir + "/venv/bin/python"
|
||||
|
||||
return exec.CommandContext(ctx, pythonPath, scriptPath, payload).Output()
|
||||
cmd := exec.CommandContext(ctx, pythonPath, scriptPath, payload)
|
||||
|
||||
cmd.Env = []string{}
|
||||
if a.token != "" {
|
||||
cmd.Env = append(cmd.Env, "CATALYST_TOKEN="+a.token)
|
||||
}
|
||||
|
||||
return cmd.Output()
|
||||
}
|
||||
|
||||
func findExec(name ...string) (string, error) {
|
||||
|
||||
@@ -1,16 +1,20 @@
|
||||
package python
|
||||
package python_test
|
||||
|
||||
import (
|
||||
"context"
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
|
||||
"github.com/SecurityBrewery/catalyst/reaction/action/python"
|
||||
)
|
||||
|
||||
func TestPython_Run(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
type fields struct {
|
||||
Bootstrap string
|
||||
Script string
|
||||
Requirements string
|
||||
Script string
|
||||
}
|
||||
|
||||
type args struct {
|
||||
@@ -68,14 +72,28 @@ func TestPython_Run(t *testing.T) {
|
||||
want: nil,
|
||||
wantErr: assert.Error,
|
||||
},
|
||||
{
|
||||
name: "requests",
|
||||
fields: fields{
|
||||
Requirements: "requests",
|
||||
Script: "import requests\nprint(requests.get('https://xkcd.com/2961/info.0.json').text)",
|
||||
},
|
||||
args: args{
|
||||
payload: "test",
|
||||
},
|
||||
want: []byte("{\"month\": \"7\", \"num\": 2961, \"link\": \"\", \"year\": \"2024\", \"news\": \"\", \"safe_title\": \"CrowdStrike\", \"transcript\": \"\", \"alt\": \"We were going to try swordfighting, but all my compiling is on hold.\", \"img\": \"https://imgs.xkcd.com/comics/crowdstrike.png\", \"title\": \"CrowdStrike\", \"day\": \"19\"}\n"),
|
||||
wantErr: assert.NoError,
|
||||
},
|
||||
}
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
ctx := context.Background()
|
||||
|
||||
a := &Python{
|
||||
Bootstrap: tt.fields.Bootstrap,
|
||||
Script: tt.fields.Script,
|
||||
a := &python.Python{
|
||||
Requirements: tt.fields.Requirements,
|
||||
Script: tt.fields.Script,
|
||||
}
|
||||
got, err := a.Run(ctx, tt.args.payload)
|
||||
tt.wantErr(t, err)
|
||||
|
||||
@@ -1,12 +1,16 @@
|
||||
package webhook
|
||||
package webhook_test
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"io"
|
||||
"testing"
|
||||
|
||||
"github.com/SecurityBrewery/catalyst/reaction/action/webhook"
|
||||
)
|
||||
|
||||
func TestEncodeBody(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
type args struct {
|
||||
requestBody io.Reader
|
||||
}
|
||||
@@ -36,7 +40,9 @@ func TestEncodeBody(t *testing.T) {
|
||||
}
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
got, got1 := EncodeBody(tt.args.requestBody)
|
||||
t.Parallel()
|
||||
|
||||
got, got1 := webhook.EncodeBody(tt.args.requestBody)
|
||||
if got != tt.want {
|
||||
t.Errorf("EncodeBody() got = %v, want %v", got, tt.want)
|
||||
}
|
||||
|
||||
@@ -16,11 +16,15 @@ import (
|
||||
)
|
||||
|
||||
func TestWebhook_Run(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
server := catalystTesting.NewRecordingServer()
|
||||
|
||||
go http.ListenAndServe("127.0.0.1:12347", server) //nolint:gosec,errcheck
|
||||
|
||||
time.Sleep(1 * time.Second)
|
||||
if err := catalystTesting.WaitForStatus("http://127.0.0.1:12347/health", http.StatusOK, 5*time.Second); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
type fields struct {
|
||||
Headers map[string]string
|
||||
@@ -50,10 +54,10 @@ func TestWebhook_Run(t *testing.T) {
|
||||
want: map[string]any{
|
||||
"statusCode": 200,
|
||||
"headers": map[string]any{
|
||||
"Content-Length": []any{"13"},
|
||||
"Content-Type": []any{"text/plain; charset=utf-8"},
|
||||
"Content-Length": []any{"14"},
|
||||
"Content-Type": []any{"application/json; charset=UTF-8"},
|
||||
},
|
||||
"body": `{"test":true}`,
|
||||
"body": "{\"test\":true}\n",
|
||||
"isBase64Encoded": false,
|
||||
},
|
||||
wantErr: assert.NoError,
|
||||
@@ -61,6 +65,8 @@ func TestWebhook_Run(t *testing.T) {
|
||||
}
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
ctx := context.Background()
|
||||
|
||||
a := &webhook.Webhook{
|
||||
|
||||
@@ -1,13 +1,13 @@
|
||||
package reaction
|
||||
|
||||
import (
|
||||
"github.com/pocketbase/pocketbase/core"
|
||||
"github.com/pocketbase/pocketbase"
|
||||
|
||||
"github.com/SecurityBrewery/catalyst/reaction/trigger/hook"
|
||||
"github.com/SecurityBrewery/catalyst/reaction/trigger/webhook"
|
||||
)
|
||||
|
||||
func BindHooks(app core.App) {
|
||||
hook.BindHooks(app)
|
||||
webhook.BindHooks(app)
|
||||
func BindHooks(pb *pocketbase.PocketBase, test bool) {
|
||||
hook.BindHooks(pb, test)
|
||||
webhook.BindHooks(pb)
|
||||
}
|
||||
|
||||
@@ -1,12 +1,15 @@
|
||||
package hook
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"log/slog"
|
||||
"slices"
|
||||
|
||||
"github.com/labstack/echo/v5"
|
||||
"github.com/pocketbase/dbx"
|
||||
"github.com/pocketbase/pocketbase"
|
||||
"github.com/pocketbase/pocketbase/apis"
|
||||
"github.com/pocketbase/pocketbase/core"
|
||||
"github.com/pocketbase/pocketbase/daos"
|
||||
@@ -22,43 +25,40 @@ type Hook struct {
|
||||
Events []string `json:"events"`
|
||||
}
|
||||
|
||||
func BindHooks(app core.App) {
|
||||
app.OnRecordAfterCreateRequest().Add(func(e *core.RecordCreateEvent) error {
|
||||
if err := hook(app.Dao(), "create", e.Collection.Name, e.Record, e.HttpContext); err != nil {
|
||||
app.Logger().Error("failed to find hook reaction", "error", err.Error())
|
||||
}
|
||||
|
||||
return nil
|
||||
func BindHooks(pb *pocketbase.PocketBase, test bool) {
|
||||
pb.App.OnRecordAfterCreateRequest().Add(func(e *core.RecordCreateEvent) error {
|
||||
return hook(e.HttpContext, pb.App, "create", e.Collection.Name, e.Record, test)
|
||||
})
|
||||
app.OnRecordAfterUpdateRequest().Add(func(e *core.RecordUpdateEvent) error {
|
||||
if err := hook(app.Dao(), "update", e.Collection.Name, e.Record, e.HttpContext); err != nil {
|
||||
app.Logger().Error("failed to find hook reaction", "error", err.Error())
|
||||
}
|
||||
|
||||
return nil
|
||||
pb.App.OnRecordAfterUpdateRequest().Add(func(e *core.RecordUpdateEvent) error {
|
||||
return hook(e.HttpContext, pb.App, "update", e.Collection.Name, e.Record, test)
|
||||
})
|
||||
app.OnRecordAfterDeleteRequest().Add(func(e *core.RecordDeleteEvent) error {
|
||||
if err := hook(app.Dao(), "delete", e.Collection.Name, e.Record, e.HttpContext); err != nil {
|
||||
app.Logger().Error("failed to find hook reaction", "error", err.Error())
|
||||
}
|
||||
|
||||
return nil
|
||||
pb.App.OnRecordAfterDeleteRequest().Add(func(e *core.RecordDeleteEvent) error {
|
||||
return hook(e.HttpContext, pb.App, "delete", e.Collection.Name, e.Record, test)
|
||||
})
|
||||
}
|
||||
|
||||
func hook(dao *daos.Dao, event, collection string, record *models.Record, ctx echo.Context) error {
|
||||
func hook(ctx echo.Context, app core.App, event, collection string, record *models.Record, test bool) error {
|
||||
auth, _ := ctx.Get(apis.ContextAuthRecordKey).(*models.Record)
|
||||
admin, _ := ctx.Get(apis.ContextAdminKey).(*models.Admin)
|
||||
|
||||
hook, found, err := findByHookTrigger(dao, collection, event)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to find hook reaction: %w", err)
|
||||
if !test {
|
||||
go mustRunHook(app, collection, event, record, auth, admin)
|
||||
} else {
|
||||
mustRunHook(app, collection, event, record, auth, admin)
|
||||
}
|
||||
|
||||
if !found {
|
||||
return nil
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func mustRunHook(app core.App, collection, event string, record, auth *models.Record, admin *models.Admin) {
|
||||
ctx := context.Background()
|
||||
|
||||
if err := runHook(ctx, app, collection, event, record, auth, admin); err != nil {
|
||||
slog.ErrorContext(ctx, fmt.Sprintf("failed to run hook reaction: %v", err))
|
||||
}
|
||||
}
|
||||
|
||||
func runHook(ctx context.Context, app core.App, collection, event string, record, auth *models.Record, admin *models.Admin) error {
|
||||
payload, err := json.Marshal(&webhook.Payload{
|
||||
Action: event,
|
||||
Collection: collection,
|
||||
@@ -67,10 +67,19 @@ func hook(dao *daos.Dao, event, collection string, record *models.Record, ctx ec
|
||||
Admin: admin,
|
||||
})
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to marshal payload: %w", err)
|
||||
return fmt.Errorf("failed to marshal webhook payload: %w", err)
|
||||
}
|
||||
|
||||
_, err = action.Run(ctx.Request().Context(), hook.GetString("action"), hook.GetString("actiondata"), string(payload))
|
||||
hook, found, err := findByHookTrigger(app.Dao(), collection, event)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to find hook by trigger: %w", err)
|
||||
}
|
||||
|
||||
if !found {
|
||||
return nil
|
||||
}
|
||||
|
||||
_, err = action.Run(ctx, app, hook.GetString("action"), hook.GetString("actiondata"), string(payload))
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to run hook reaction: %w", err)
|
||||
}
|
||||
@@ -81,7 +90,7 @@ func hook(dao *daos.Dao, event, collection string, record *models.Record, ctx ec
|
||||
func findByHookTrigger(dao *daos.Dao, collection, event string) (*models.Record, bool, error) {
|
||||
records, err := dao.FindRecordsByExpr(migrations.ReactionCollectionName, dbx.HashExp{"trigger": "hook"})
|
||||
if err != nil {
|
||||
return nil, false, err
|
||||
return nil, false, fmt.Errorf("failed to find hook reaction: %w", err)
|
||||
}
|
||||
|
||||
if len(records) == 0 {
|
||||
|
||||
@@ -15,8 +15,8 @@ type Request struct {
|
||||
IsBase64Encoded bool `json:"isBase64Encoded"`
|
||||
}
|
||||
|
||||
// isJSON checks if the data is JSON.
|
||||
func isJSON(data []byte) bool {
|
||||
// IsJSON checks if the data is JSON.
|
||||
func IsJSON(data []byte) bool {
|
||||
var msg json.RawMessage
|
||||
|
||||
return json.Unmarshal(data, &msg) == nil
|
||||
|
||||
@@ -1,8 +1,14 @@
|
||||
package webhook
|
||||
package webhook_test
|
||||
|
||||
import "testing"
|
||||
import (
|
||||
"testing"
|
||||
|
||||
"github.com/SecurityBrewery/catalyst/reaction/trigger/webhook"
|
||||
)
|
||||
|
||||
func Test_isJSON(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
type args struct {
|
||||
data []byte
|
||||
}
|
||||
@@ -29,7 +35,9 @@ func Test_isJSON(t *testing.T) {
|
||||
}
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
if got := isJSON(tt.args.data); got != tt.want {
|
||||
t.Parallel()
|
||||
|
||||
if got := webhook.IsJSON(tt.args.data); got != tt.want {
|
||||
t.Errorf("isJSON() = %v, want %v", got, tt.want)
|
||||
}
|
||||
})
|
||||
|
||||
@@ -9,6 +9,7 @@ import (
|
||||
|
||||
"github.com/labstack/echo/v5"
|
||||
"github.com/pocketbase/dbx"
|
||||
"github.com/pocketbase/pocketbase"
|
||||
"github.com/pocketbase/pocketbase/apis"
|
||||
"github.com/pocketbase/pocketbase/core"
|
||||
"github.com/pocketbase/pocketbase/daos"
|
||||
@@ -26,22 +27,22 @@ type Webhook struct {
|
||||
|
||||
const prefix = "/reaction/"
|
||||
|
||||
func BindHooks(app core.App) {
|
||||
app.OnBeforeServe().Add(func(e *core.ServeEvent) error {
|
||||
e.Router.Any(prefix+"*", handle(e.App.Dao()))
|
||||
func BindHooks(pb *pocketbase.PocketBase) {
|
||||
pb.OnBeforeServe().Add(func(e *core.ServeEvent) error {
|
||||
e.Router.Any(prefix+"*", handle(e.App))
|
||||
|
||||
return nil
|
||||
})
|
||||
}
|
||||
|
||||
func handle(dao *daos.Dao) func(c echo.Context) error {
|
||||
func handle(app core.App) func(c echo.Context) error {
|
||||
return func(c echo.Context) error {
|
||||
record, payload, apiErr := parseRequest(dao, c.Request())
|
||||
record, payload, apiErr := parseRequest(app.Dao(), c.Request())
|
||||
if apiErr != nil {
|
||||
return apiErr
|
||||
}
|
||||
|
||||
output, err := action.Run(c.Request().Context(), record.GetString("action"), record.GetString("actiondata"), string(payload))
|
||||
output, err := action.Run(c.Request().Context(), app, record.GetString("action"), record.GetString("actiondata"), string(payload))
|
||||
if err != nil {
|
||||
return apis.NewApiError(http.StatusInternalServerError, err.Error(), nil)
|
||||
}
|
||||
@@ -138,7 +139,7 @@ func writeOutput(c echo.Context, output []byte) error {
|
||||
}
|
||||
}
|
||||
|
||||
if isJSON(output) {
|
||||
if IsJSON(output) {
|
||||
return c.JSON(http.StatusOK, json.RawMessage(output))
|
||||
}
|
||||
|
||||
|
||||
Reference in New Issue
Block a user