feat: add reactions (#1074)

This commit is contained in:
Jonas Plum
2024-07-20 06:39:02 +02:00
committed by GitHub
parent 82ad50d228
commit e2c8f1d223
78 changed files with 3270 additions and 257 deletions

44
reaction/action/action.go Normal file
View File

@@ -0,0 +1,44 @@
package action
import (
"context"
"encoding/json"
"fmt"
"github.com/SecurityBrewery/catalyst/reaction/action/python"
"github.com/SecurityBrewery/catalyst/reaction/action/webhook"
)
func Run(ctx context.Context, actionName, actionData, payload string) ([]byte, error) {
action, err := decode(actionName, actionData)
if err != nil {
return nil, err
}
return action.Run(ctx, payload)
}
type action interface {
Run(ctx context.Context, payload string) ([]byte, error)
}
func decode(actionName, actionData string) (action, error) {
switch actionName {
case "python":
var reaction python.Python
if err := json.Unmarshal([]byte(actionData), &reaction); err != nil {
return nil, err
}
return &reaction, nil
case "webhook":
var reaction webhook.Webhook
if err := json.Unmarshal([]byte(actionData), &reaction); err != nil {
return nil, err
}
return &reaction, nil
default:
return nil, fmt.Errorf("action %q not found", actionName)
}
}

View File

@@ -0,0 +1,105 @@
package python
import (
"context"
"errors"
"fmt"
"os"
"os/exec"
"strings"
)
type Python struct {
Bootstrap string `json:"bootstrap"`
Script string `json:"script"`
}
func (a *Python) Run(ctx context.Context, payload string) ([]byte, error) {
tempDir, err := os.MkdirTemp("", "catalyst_action")
if err != nil {
return nil, err
}
defer os.RemoveAll(tempDir)
if b, err := pythonSetup(ctx, tempDir); err != nil {
var ee *exec.ExitError
if errors.As(err, &ee) {
b = append(b, ee.Stderr...)
}
return nil, fmt.Errorf("failed to setup python, %w: %s", err, string(b))
}
if b, err := pythonRunBootstrap(ctx, tempDir, a.Bootstrap); err != nil {
var ee *exec.ExitError
if errors.As(err, &ee) {
b = append(b, ee.Stderr...)
}
return nil, fmt.Errorf("failed to run bootstrap, %w: %s", err, string(b))
}
b, err := pythonRunScript(ctx, tempDir, a.Script, payload)
if err != nil {
var ee *exec.ExitError
if errors.As(err, &ee) {
b = append(b, ee.Stderr...)
}
return nil, fmt.Errorf("failed to run script, %w: %s", err, string(b))
}
return b, nil
}
func pythonSetup(ctx context.Context, tempDir string) ([]byte, error) {
pythonPath, err := findExec("python", "python3")
if err != nil {
return nil, fmt.Errorf("python or python3 binary not found, %w", err)
}
// setup virtual environment
return exec.CommandContext(ctx, pythonPath, "-m", "venv", tempDir+"/venv").Output()
}
func pythonRunBootstrap(ctx context.Context, tempDir, bootstrap string) ([]byte, error) {
hasBootstrap := len(strings.TrimSpace(bootstrap)) > 0
if !hasBootstrap {
return nil, nil
}
bootstrapPath := tempDir + "/requirements.txt"
if err := os.WriteFile(bootstrapPath, []byte(bootstrap), 0o600); err != nil {
return nil, err
}
// install dependencies
pipPath := tempDir + "/venv/bin/pip"
return exec.CommandContext(ctx, pipPath, "install", "-r", bootstrapPath).Output()
}
func pythonRunScript(ctx context.Context, tempDir, script, payload string) ([]byte, error) {
scriptPath := tempDir + "/script.py"
if err := os.WriteFile(scriptPath, []byte(script), 0o600); err != nil {
return nil, err
}
pythonPath := tempDir + "/venv/bin/python"
return exec.CommandContext(ctx, pythonPath, scriptPath, payload).Output()
}
func findExec(name ...string) (string, error) {
for _, n := range name {
if p, err := exec.LookPath(n); err == nil {
return p, nil
}
}
return "", errors.New("no executable found")
}

View File

@@ -0,0 +1,20 @@
package webhook
import (
"encoding/base64"
"io"
"unicode/utf8"
)
func EncodeBody(requestBody io.Reader) (string, bool) {
body, err := io.ReadAll(requestBody)
if err != nil {
return "", false
}
if utf8.Valid(body) {
return string(body), false
}
return base64.StdEncoding.EncodeToString(body), true
}

View File

@@ -0,0 +1,49 @@
package webhook
import (
"bytes"
"io"
"testing"
)
func TestEncodeBody(t *testing.T) {
type args struct {
requestBody io.Reader
}
tests := []struct {
name string
args args
want string
want1 bool
}{
{
name: "utf8",
args: args{
requestBody: bytes.NewBufferString("body"),
},
want: "body",
want1: false,
},
{
name: "non-utf8",
args: args{
requestBody: bytes.NewBufferString("body\xe0"),
},
want: "Ym9keeA=",
want1: true,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
got, got1 := EncodeBody(tt.args.requestBody)
if got != tt.want {
t.Errorf("EncodeBody() got = %v, want %v", got, tt.want)
}
if got1 != tt.want1 {
t.Errorf("EncodeBody() got1 = %v, want %v", got1, tt.want1)
}
})
}
}

View File

@@ -0,0 +1,12 @@
package webhook
import (
"net/http"
)
type Response struct {
StatusCode int `json:"statusCode"`
Headers http.Header `json:"headers"`
Body string `json:"body"`
IsBase64Encoded bool `json:"isBase64Encoded"`
}

View File

@@ -0,0 +1,39 @@
package webhook
import (
"context"
"encoding/json"
"net/http"
"strings"
)
type Webhook struct {
Headers map[string]string `json:"headers"`
URL string `json:"url"`
}
func (a *Webhook) Run(ctx context.Context, payload string) ([]byte, error) {
req, err := http.NewRequestWithContext(ctx, http.MethodPost, a.URL, strings.NewReader(payload))
if err != nil {
return nil, err
}
for key, value := range a.Headers {
req.Header.Set(key, value)
}
res, err := http.DefaultClient.Do(req)
if err != nil {
return nil, err
}
defer res.Body.Close()
body, isBase64Encoded := EncodeBody(res.Body)
return json.Marshal(Response{
StatusCode: res.StatusCode,
Headers: res.Header,
Body: body,
IsBase64Encoded: isBase64Encoded,
})
}

13
reaction/trigger.go Normal file
View File

@@ -0,0 +1,13 @@
package reaction
import (
"github.com/pocketbase/pocketbase/core"
"github.com/SecurityBrewery/catalyst/reaction/trigger/hook"
"github.com/SecurityBrewery/catalyst/reaction/trigger/webhook"
)
func BindHooks(app core.App) {
hook.BindHooks(app)
webhook.BindHooks(app)
}

View File

@@ -0,0 +1,103 @@
package hook
import (
"encoding/json"
"fmt"
"slices"
"github.com/labstack/echo/v5"
"github.com/pocketbase/dbx"
"github.com/pocketbase/pocketbase/apis"
"github.com/pocketbase/pocketbase/core"
"github.com/pocketbase/pocketbase/daos"
"github.com/pocketbase/pocketbase/models"
"github.com/SecurityBrewery/catalyst/migrations"
"github.com/SecurityBrewery/catalyst/reaction/action"
"github.com/SecurityBrewery/catalyst/webhook"
)
type Hook struct {
Collections []string `json:"collections"`
Events []string `json:"events"`
}
func BindHooks(app core.App) {
app.OnRecordAfterCreateRequest().Add(func(e *core.RecordCreateEvent) error {
if err := hook(app.Dao(), "create", e.Collection.Name, e.Record, e.HttpContext); err != nil {
app.Logger().Error("failed to find hook reaction", "error", err.Error())
}
return nil
})
app.OnRecordAfterUpdateRequest().Add(func(e *core.RecordUpdateEvent) error {
if err := hook(app.Dao(), "update", e.Collection.Name, e.Record, e.HttpContext); err != nil {
app.Logger().Error("failed to find hook reaction", "error", err.Error())
}
return nil
})
app.OnRecordAfterDeleteRequest().Add(func(e *core.RecordDeleteEvent) error {
if err := hook(app.Dao(), "delete", e.Collection.Name, e.Record, e.HttpContext); err != nil {
app.Logger().Error("failed to find hook reaction", "error", err.Error())
}
return nil
})
}
func hook(dao *daos.Dao, event, collection string, record *models.Record, ctx echo.Context) error {
auth, _ := ctx.Get(apis.ContextAuthRecordKey).(*models.Record)
admin, _ := ctx.Get(apis.ContextAdminKey).(*models.Admin)
hook, found, err := findByHookTrigger(dao, collection, event)
if err != nil {
return fmt.Errorf("failed to find hook reaction: %w", err)
}
if !found {
return nil
}
payload, err := json.Marshal(&webhook.Payload{
Action: event,
Collection: collection,
Record: record,
Auth: auth,
Admin: admin,
})
if err != nil {
return fmt.Errorf("failed to marshal payload: %w", err)
}
_, err = action.Run(ctx.Request().Context(), hook.GetString("action"), hook.GetString("actiondata"), string(payload))
if err != nil {
return fmt.Errorf("failed to run hook reaction: %w", err)
}
return nil
}
func findByHookTrigger(dao *daos.Dao, collection, event string) (*models.Record, bool, error) {
records, err := dao.FindRecordsByExpr(migrations.ReactionCollectionName, dbx.HashExp{"trigger": "hook"})
if err != nil {
return nil, false, err
}
if len(records) == 0 {
return nil, false, nil
}
for _, record := range records {
var hook Hook
if err := json.Unmarshal([]byte(record.GetString("triggerdata")), &hook); err != nil {
return nil, false, err
}
if slices.Contains(hook.Collections, collection) && slices.Contains(hook.Events, event) {
return record, true, nil
}
}
return nil, false, nil
}

View File

@@ -0,0 +1,23 @@
package webhook
import (
"encoding/json"
"net/http"
"net/url"
)
type Request struct {
Method string `json:"method"`
Path string `json:"path"`
Headers http.Header `json:"headers"`
Query url.Values `json:"query"`
Body string `json:"body"`
IsBase64Encoded bool `json:"isBase64Encoded"`
}
// isJSON checks if the data is JSON.
func isJSON(data []byte) bool {
var msg json.RawMessage
return json.Unmarshal(data, &msg) == nil
}

View File

@@ -0,0 +1,37 @@
package webhook
import "testing"
func Test_isJSON(t *testing.T) {
type args struct {
data []byte
}
tests := []struct {
name string
args args
want bool
}{
{
name: "valid JSON",
args: args{
data: []byte(`{"key": "value"}`),
},
want: true,
},
{
name: "invalid JSON",
args: args{
data: []byte(`{"key": "value"`),
},
want: false,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
if got := isJSON(tt.args.data); got != tt.want {
t.Errorf("isJSON() = %v, want %v", got, tt.want)
}
})
}
}

View File

@@ -0,0 +1,146 @@
package webhook
import (
"encoding/base64"
"encoding/json"
"fmt"
"net/http"
"strings"
"github.com/labstack/echo/v5"
"github.com/pocketbase/dbx"
"github.com/pocketbase/pocketbase/apis"
"github.com/pocketbase/pocketbase/core"
"github.com/pocketbase/pocketbase/daos"
"github.com/pocketbase/pocketbase/models"
"github.com/SecurityBrewery/catalyst/migrations"
"github.com/SecurityBrewery/catalyst/reaction/action"
"github.com/SecurityBrewery/catalyst/reaction/action/webhook"
)
type Webhook struct {
Token string `json:"token"`
Path string `json:"path"`
}
const prefix = "/reaction/"
func BindHooks(app core.App) {
app.OnBeforeServe().Add(func(e *core.ServeEvent) error {
e.Router.Any(prefix+"*", handle(e.App.Dao()))
return nil
})
}
func handle(dao *daos.Dao) func(c echo.Context) error {
return func(c echo.Context) error {
record, payload, apiErr := parseRequest(dao, c.Request())
if apiErr != nil {
return apiErr
}
output, err := action.Run(c.Request().Context(), record.GetString("action"), record.GetString("actiondata"), string(payload))
if err != nil {
return apis.NewApiError(http.StatusInternalServerError, err.Error(), nil)
}
return writeOutput(c, output)
}
}
func parseRequest(dao *daos.Dao, r *http.Request) (*models.Record, []byte, *apis.ApiError) {
if !strings.HasPrefix(r.URL.Path, prefix) {
return nil, nil, apis.NewApiError(http.StatusNotFound, "wrong prefix", nil)
}
reactionName := strings.TrimPrefix(r.URL.Path, prefix)
record, trigger, found, err := findByWebhookTrigger(dao, reactionName)
if err != nil {
return nil, nil, apis.NewNotFoundError(err.Error(), nil)
}
if !found {
return nil, nil, apis.NewNotFoundError("reaction not found", nil)
}
if trigger.Token != "" {
auth := r.Header.Get("Authorization")
if !strings.HasPrefix(auth, "Bearer ") {
return nil, nil, apis.NewUnauthorizedError("missing token", nil)
}
if trigger.Token != strings.TrimPrefix(auth, "Bearer ") {
return nil, nil, apis.NewUnauthorizedError("invalid token", nil)
}
}
body, isBase64Encoded := webhook.EncodeBody(r.Body)
payload, err := json.Marshal(&Request{
Method: r.Method,
Path: r.URL.EscapedPath(),
Headers: r.Header,
Query: r.URL.Query(),
Body: body,
IsBase64Encoded: isBase64Encoded,
})
if err != nil {
return nil, nil, apis.NewApiError(http.StatusInternalServerError, err.Error(), nil)
}
return record, payload, nil
}
func findByWebhookTrigger(dao *daos.Dao, path string) (*models.Record, *Webhook, bool, error) {
records, err := dao.FindRecordsByExpr(migrations.ReactionCollectionName, dbx.HashExp{"trigger": "webhook"})
if err != nil {
return nil, nil, false, err
}
if len(records) == 0 {
return nil, nil, false, nil
}
for _, record := range records {
var webhook Webhook
if err := json.Unmarshal([]byte(record.GetString("triggerdata")), &webhook); err != nil {
return nil, nil, false, err
}
if webhook.Path == path {
return record, &webhook, true, nil
}
}
return nil, nil, false, nil
}
func writeOutput(c echo.Context, output []byte) error {
var catalystResponse webhook.Response
if err := json.Unmarshal(output, &catalystResponse); err == nil && catalystResponse.StatusCode != 0 {
for key, values := range catalystResponse.Headers {
for _, value := range values {
c.Response().Header().Add(key, value)
}
}
if catalystResponse.IsBase64Encoded {
output, err = base64.StdEncoding.DecodeString(catalystResponse.Body)
if err != nil {
return fmt.Errorf("error decoding base64 body: %w", err)
}
} else {
output = []byte(catalystResponse.Body)
}
}
if isJSON(output) {
return c.JSON(http.StatusOK, json.RawMessage(output))
}
return c.String(http.StatusOK, string(output))
}