feat: improve python actions (#1083)

This commit is contained in:
Jonas Plum
2024-07-21 02:56:43 +02:00
committed by GitHub
parent 81bfbb2072
commit 91429effe2
55 changed files with 1143 additions and 585 deletions

View File

@@ -4,17 +4,33 @@ import (
"context"
"encoding/json"
"fmt"
"time"
"github.com/golang-jwt/jwt/v4"
"github.com/pocketbase/pocketbase/core"
"github.com/pocketbase/pocketbase/tokens"
"github.com/pocketbase/pocketbase/tools/security"
"github.com/SecurityBrewery/catalyst/migrations"
"github.com/SecurityBrewery/catalyst/reaction/action/python"
"github.com/SecurityBrewery/catalyst/reaction/action/webhook"
)
func Run(ctx context.Context, actionName, actionData, payload string) ([]byte, error) {
func Run(ctx context.Context, app core.App, actionName, actionData, payload string) ([]byte, error) {
action, err := decode(actionName, actionData)
if err != nil {
return nil, err
}
if a, ok := action.(authenticatedAction); ok {
token, err := systemToken(app)
if err != nil {
return nil, fmt.Errorf("failed to get system token: %w", err)
}
a.SetToken(token)
}
return action.Run(ctx, payload)
}
@@ -22,6 +38,10 @@ type action interface {
Run(ctx context.Context, payload string) ([]byte, error)
}
type authenticatedAction interface {
SetToken(token string)
}
func decode(actionName, actionData string) (action, error) {
switch actionName {
case "python":
@@ -42,3 +62,20 @@ func decode(actionName, actionData string) (action, error) {
return nil, fmt.Errorf("action %q not found", actionName)
}
}
func systemToken(app core.App) (string, error) {
authRecord, err := app.Dao().FindAuthRecordByUsername(migrations.UserCollectionName, migrations.SystemUserID)
if err != nil {
return "", fmt.Errorf("failed to find system auth record: %w", err)
}
return security.NewJWT(
jwt.MapClaims{
"id": authRecord.Id,
"type": tokens.TypeAuthRecord,
"collectionId": authRecord.Collection().Id,
},
authRecord.TokenKey()+app.Settings().RecordAuthToken.Secret,
int64(time.Second*60),
)
}

View File

@@ -10,8 +10,14 @@ import (
)
type Python struct {
Bootstrap string `json:"bootstrap"`
Script string `json:"script"`
Requirements string `json:"requirements"`
Script string `json:"script"`
token string
}
func (a *Python) SetToken(token string) {
a.token = token
}
func (a *Python) Run(ctx context.Context, payload string) ([]byte, error) {
@@ -22,7 +28,8 @@ func (a *Python) Run(ctx context.Context, payload string) ([]byte, error) {
defer os.RemoveAll(tempDir)
if b, err := pythonSetup(ctx, tempDir); err != nil {
b, err := pythonSetup(ctx, tempDir)
if err != nil {
var ee *exec.ExitError
if errors.As(err, &ee) {
b = append(b, ee.Stderr...)
@@ -31,16 +38,17 @@ func (a *Python) Run(ctx context.Context, payload string) ([]byte, error) {
return nil, fmt.Errorf("failed to setup python, %w: %s", err, string(b))
}
if b, err := pythonRunBootstrap(ctx, tempDir, a.Bootstrap); err != nil {
b, err = a.pythonInstallRequirements(ctx, tempDir)
if err != nil {
var ee *exec.ExitError
if errors.As(err, &ee) {
b = append(b, ee.Stderr...)
}
return nil, fmt.Errorf("failed to run bootstrap, %w: %s", err, string(b))
return nil, fmt.Errorf("failed to run install requirements, %w: %s", err, string(b))
}
b, err := pythonRunScript(ctx, tempDir, a.Script, payload)
b, err = a.pythonRunScript(ctx, tempDir, payload)
if err != nil {
var ee *exec.ExitError
if errors.As(err, &ee) {
@@ -63,35 +71,42 @@ func pythonSetup(ctx context.Context, tempDir string) ([]byte, error) {
return exec.CommandContext(ctx, pythonPath, "-m", "venv", tempDir+"/venv").Output()
}
func pythonRunBootstrap(ctx context.Context, tempDir, bootstrap string) ([]byte, error) {
hasBootstrap := len(strings.TrimSpace(bootstrap)) > 0
func (a *Python) pythonInstallRequirements(ctx context.Context, tempDir string) ([]byte, error) {
hasRequirements := len(strings.TrimSpace(a.Requirements)) > 0
if !hasBootstrap {
if !hasRequirements {
return nil, nil
}
bootstrapPath := tempDir + "/requirements.txt"
requirementsPath := tempDir + "/requirements.txt"
if err := os.WriteFile(bootstrapPath, []byte(bootstrap), 0o600); err != nil {
if err := os.WriteFile(requirementsPath, []byte(a.Requirements), 0o600); err != nil {
return nil, err
}
// install dependencies
pipPath := tempDir + "/venv/bin/pip"
return exec.CommandContext(ctx, pipPath, "install", "-r", bootstrapPath).Output()
return exec.CommandContext(ctx, pipPath, "install", "-r", requirementsPath).Output()
}
func pythonRunScript(ctx context.Context, tempDir, script, payload string) ([]byte, error) {
func (a *Python) pythonRunScript(ctx context.Context, tempDir, payload string) ([]byte, error) {
scriptPath := tempDir + "/script.py"
if err := os.WriteFile(scriptPath, []byte(script), 0o600); err != nil {
if err := os.WriteFile(scriptPath, []byte(a.Script), 0o600); err != nil {
return nil, err
}
pythonPath := tempDir + "/venv/bin/python"
return exec.CommandContext(ctx, pythonPath, scriptPath, payload).Output()
cmd := exec.CommandContext(ctx, pythonPath, scriptPath, payload)
cmd.Env = []string{}
if a.token != "" {
cmd.Env = append(cmd.Env, "CATALYST_TOKEN="+a.token)
}
return cmd.Output()
}
func findExec(name ...string) (string, error) {

View File

@@ -1,16 +1,20 @@
package python
package python_test
import (
"context"
"testing"
"github.com/stretchr/testify/assert"
"github.com/SecurityBrewery/catalyst/reaction/action/python"
)
func TestPython_Run(t *testing.T) {
t.Parallel()
type fields struct {
Bootstrap string
Script string
Requirements string
Script string
}
type args struct {
@@ -68,14 +72,28 @@ func TestPython_Run(t *testing.T) {
want: nil,
wantErr: assert.Error,
},
{
name: "requests",
fields: fields{
Requirements: "requests",
Script: "import requests\nprint(requests.get('https://xkcd.com/2961/info.0.json').text)",
},
args: args{
payload: "test",
},
want: []byte("{\"month\": \"7\", \"num\": 2961, \"link\": \"\", \"year\": \"2024\", \"news\": \"\", \"safe_title\": \"CrowdStrike\", \"transcript\": \"\", \"alt\": \"We were going to try swordfighting, but all my compiling is on hold.\", \"img\": \"https://imgs.xkcd.com/comics/crowdstrike.png\", \"title\": \"CrowdStrike\", \"day\": \"19\"}\n"),
wantErr: assert.NoError,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
t.Parallel()
ctx := context.Background()
a := &Python{
Bootstrap: tt.fields.Bootstrap,
Script: tt.fields.Script,
a := &python.Python{
Requirements: tt.fields.Requirements,
Script: tt.fields.Script,
}
got, err := a.Run(ctx, tt.args.payload)
tt.wantErr(t, err)

View File

@@ -1,12 +1,16 @@
package webhook
package webhook_test
import (
"bytes"
"io"
"testing"
"github.com/SecurityBrewery/catalyst/reaction/action/webhook"
)
func TestEncodeBody(t *testing.T) {
t.Parallel()
type args struct {
requestBody io.Reader
}
@@ -36,7 +40,9 @@ func TestEncodeBody(t *testing.T) {
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
got, got1 := EncodeBody(tt.args.requestBody)
t.Parallel()
got, got1 := webhook.EncodeBody(tt.args.requestBody)
if got != tt.want {
t.Errorf("EncodeBody() got = %v, want %v", got, tt.want)
}

View File

@@ -16,11 +16,15 @@ import (
)
func TestWebhook_Run(t *testing.T) {
t.Parallel()
server := catalystTesting.NewRecordingServer()
go http.ListenAndServe("127.0.0.1:12347", server) //nolint:gosec,errcheck
time.Sleep(1 * time.Second)
if err := catalystTesting.WaitForStatus("http://127.0.0.1:12347/health", http.StatusOK, 5*time.Second); err != nil {
t.Fatal(err)
}
type fields struct {
Headers map[string]string
@@ -50,10 +54,10 @@ func TestWebhook_Run(t *testing.T) {
want: map[string]any{
"statusCode": 200,
"headers": map[string]any{
"Content-Length": []any{"13"},
"Content-Type": []any{"text/plain; charset=utf-8"},
"Content-Length": []any{"14"},
"Content-Type": []any{"application/json; charset=UTF-8"},
},
"body": `{"test":true}`,
"body": "{\"test\":true}\n",
"isBase64Encoded": false,
},
wantErr: assert.NoError,
@@ -61,6 +65,8 @@ func TestWebhook_Run(t *testing.T) {
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
t.Parallel()
ctx := context.Background()
a := &webhook.Webhook{

View File

@@ -1,13 +1,13 @@
package reaction
import (
"github.com/pocketbase/pocketbase/core"
"github.com/pocketbase/pocketbase"
"github.com/SecurityBrewery/catalyst/reaction/trigger/hook"
"github.com/SecurityBrewery/catalyst/reaction/trigger/webhook"
)
func BindHooks(app core.App) {
hook.BindHooks(app)
webhook.BindHooks(app)
func BindHooks(pb *pocketbase.PocketBase, test bool) {
hook.BindHooks(pb, test)
webhook.BindHooks(pb)
}

View File

@@ -1,12 +1,15 @@
package hook
import (
"context"
"encoding/json"
"fmt"
"log/slog"
"slices"
"github.com/labstack/echo/v5"
"github.com/pocketbase/dbx"
"github.com/pocketbase/pocketbase"
"github.com/pocketbase/pocketbase/apis"
"github.com/pocketbase/pocketbase/core"
"github.com/pocketbase/pocketbase/daos"
@@ -22,43 +25,40 @@ type Hook struct {
Events []string `json:"events"`
}
func BindHooks(app core.App) {
app.OnRecordAfterCreateRequest().Add(func(e *core.RecordCreateEvent) error {
if err := hook(app.Dao(), "create", e.Collection.Name, e.Record, e.HttpContext); err != nil {
app.Logger().Error("failed to find hook reaction", "error", err.Error())
}
return nil
func BindHooks(pb *pocketbase.PocketBase, test bool) {
pb.App.OnRecordAfterCreateRequest().Add(func(e *core.RecordCreateEvent) error {
return hook(e.HttpContext, pb.App, "create", e.Collection.Name, e.Record, test)
})
app.OnRecordAfterUpdateRequest().Add(func(e *core.RecordUpdateEvent) error {
if err := hook(app.Dao(), "update", e.Collection.Name, e.Record, e.HttpContext); err != nil {
app.Logger().Error("failed to find hook reaction", "error", err.Error())
}
return nil
pb.App.OnRecordAfterUpdateRequest().Add(func(e *core.RecordUpdateEvent) error {
return hook(e.HttpContext, pb.App, "update", e.Collection.Name, e.Record, test)
})
app.OnRecordAfterDeleteRequest().Add(func(e *core.RecordDeleteEvent) error {
if err := hook(app.Dao(), "delete", e.Collection.Name, e.Record, e.HttpContext); err != nil {
app.Logger().Error("failed to find hook reaction", "error", err.Error())
}
return nil
pb.App.OnRecordAfterDeleteRequest().Add(func(e *core.RecordDeleteEvent) error {
return hook(e.HttpContext, pb.App, "delete", e.Collection.Name, e.Record, test)
})
}
func hook(dao *daos.Dao, event, collection string, record *models.Record, ctx echo.Context) error {
func hook(ctx echo.Context, app core.App, event, collection string, record *models.Record, test bool) error {
auth, _ := ctx.Get(apis.ContextAuthRecordKey).(*models.Record)
admin, _ := ctx.Get(apis.ContextAdminKey).(*models.Admin)
hook, found, err := findByHookTrigger(dao, collection, event)
if err != nil {
return fmt.Errorf("failed to find hook reaction: %w", err)
if !test {
go mustRunHook(app, collection, event, record, auth, admin)
} else {
mustRunHook(app, collection, event, record, auth, admin)
}
if !found {
return nil
}
return nil
}
func mustRunHook(app core.App, collection, event string, record, auth *models.Record, admin *models.Admin) {
ctx := context.Background()
if err := runHook(ctx, app, collection, event, record, auth, admin); err != nil {
slog.ErrorContext(ctx, fmt.Sprintf("failed to run hook reaction: %v", err))
}
}
func runHook(ctx context.Context, app core.App, collection, event string, record, auth *models.Record, admin *models.Admin) error {
payload, err := json.Marshal(&webhook.Payload{
Action: event,
Collection: collection,
@@ -67,10 +67,19 @@ func hook(dao *daos.Dao, event, collection string, record *models.Record, ctx ec
Admin: admin,
})
if err != nil {
return fmt.Errorf("failed to marshal payload: %w", err)
return fmt.Errorf("failed to marshal webhook payload: %w", err)
}
_, err = action.Run(ctx.Request().Context(), hook.GetString("action"), hook.GetString("actiondata"), string(payload))
hook, found, err := findByHookTrigger(app.Dao(), collection, event)
if err != nil {
return fmt.Errorf("failed to find hook by trigger: %w", err)
}
if !found {
return nil
}
_, err = action.Run(ctx, app, hook.GetString("action"), hook.GetString("actiondata"), string(payload))
if err != nil {
return fmt.Errorf("failed to run hook reaction: %w", err)
}
@@ -81,7 +90,7 @@ func hook(dao *daos.Dao, event, collection string, record *models.Record, ctx ec
func findByHookTrigger(dao *daos.Dao, collection, event string) (*models.Record, bool, error) {
records, err := dao.FindRecordsByExpr(migrations.ReactionCollectionName, dbx.HashExp{"trigger": "hook"})
if err != nil {
return nil, false, err
return nil, false, fmt.Errorf("failed to find hook reaction: %w", err)
}
if len(records) == 0 {

View File

@@ -15,8 +15,8 @@ type Request struct {
IsBase64Encoded bool `json:"isBase64Encoded"`
}
// isJSON checks if the data is JSON.
func isJSON(data []byte) bool {
// IsJSON checks if the data is JSON.
func IsJSON(data []byte) bool {
var msg json.RawMessage
return json.Unmarshal(data, &msg) == nil

View File

@@ -1,8 +1,14 @@
package webhook
package webhook_test
import "testing"
import (
"testing"
"github.com/SecurityBrewery/catalyst/reaction/trigger/webhook"
)
func Test_isJSON(t *testing.T) {
t.Parallel()
type args struct {
data []byte
}
@@ -29,7 +35,9 @@ func Test_isJSON(t *testing.T) {
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
if got := isJSON(tt.args.data); got != tt.want {
t.Parallel()
if got := webhook.IsJSON(tt.args.data); got != tt.want {
t.Errorf("isJSON() = %v, want %v", got, tt.want)
}
})

View File

@@ -9,6 +9,7 @@ import (
"github.com/labstack/echo/v5"
"github.com/pocketbase/dbx"
"github.com/pocketbase/pocketbase"
"github.com/pocketbase/pocketbase/apis"
"github.com/pocketbase/pocketbase/core"
"github.com/pocketbase/pocketbase/daos"
@@ -26,22 +27,22 @@ type Webhook struct {
const prefix = "/reaction/"
func BindHooks(app core.App) {
app.OnBeforeServe().Add(func(e *core.ServeEvent) error {
e.Router.Any(prefix+"*", handle(e.App.Dao()))
func BindHooks(pb *pocketbase.PocketBase) {
pb.OnBeforeServe().Add(func(e *core.ServeEvent) error {
e.Router.Any(prefix+"*", handle(e.App))
return nil
})
}
func handle(dao *daos.Dao) func(c echo.Context) error {
func handle(app core.App) func(c echo.Context) error {
return func(c echo.Context) error {
record, payload, apiErr := parseRequest(dao, c.Request())
record, payload, apiErr := parseRequest(app.Dao(), c.Request())
if apiErr != nil {
return apiErr
}
output, err := action.Run(c.Request().Context(), record.GetString("action"), record.GetString("actiondata"), string(payload))
output, err := action.Run(c.Request().Context(), app, record.GetString("action"), record.GetString("actiondata"), string(payload))
if err != nil {
return apis.NewApiError(http.StatusInternalServerError, err.Error(), nil)
}
@@ -138,7 +139,7 @@ func writeOutput(c echo.Context, output []byte) error {
}
}
if isJSON(output) {
if IsJSON(output) {
return c.JSON(http.StatusOK, json.RawMessage(output))
}