mirror of
https://github.com/SecurityBrewery/catalyst.git
synced 2025-12-06 07:12:46 +01:00
refactor: remove pocketbase (#1138)
This commit is contained in:
72
app/reaction/action/action.go
Normal file
72
app/reaction/action/action.go
Normal file
@@ -0,0 +1,72 @@
|
||||
package action
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"time"
|
||||
|
||||
"github.com/SecurityBrewery/catalyst/app/auth"
|
||||
"github.com/SecurityBrewery/catalyst/app/database/sqlc"
|
||||
"github.com/SecurityBrewery/catalyst/app/reaction/action/python"
|
||||
"github.com/SecurityBrewery/catalyst/app/reaction/action/webhook"
|
||||
)
|
||||
|
||||
func Run(ctx context.Context, url string, queries *sqlc.Queries, actionName string, actionData, payload json.RawMessage) ([]byte, error) {
|
||||
action, err := decode(actionName, actionData)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if a, ok := action.(authenticatedAction); ok {
|
||||
token, err := systemToken(ctx, queries)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to get system token: %w", err)
|
||||
}
|
||||
|
||||
a.SetEnv([]string{
|
||||
"CATALYST_APP_URL=" + url,
|
||||
"CATALYST_TOKEN=" + token,
|
||||
})
|
||||
}
|
||||
|
||||
return action.Run(ctx, payload)
|
||||
}
|
||||
|
||||
type action interface {
|
||||
Run(ctx context.Context, payload json.RawMessage) ([]byte, error)
|
||||
}
|
||||
|
||||
type authenticatedAction interface {
|
||||
SetEnv(env []string)
|
||||
}
|
||||
|
||||
func decode(actionName string, actionData json.RawMessage) (action, error) {
|
||||
switch actionName {
|
||||
case "python":
|
||||
var reaction python.Python
|
||||
if err := json.Unmarshal(actionData, &reaction); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return &reaction, nil
|
||||
case "webhook":
|
||||
var reaction webhook.Webhook
|
||||
if err := json.Unmarshal(actionData, &reaction); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return &reaction, nil
|
||||
default:
|
||||
return nil, fmt.Errorf("action %q not found", actionName)
|
||||
}
|
||||
}
|
||||
|
||||
func systemToken(ctx context.Context, queries *sqlc.Queries) (string, error) {
|
||||
user, err := queries.SystemUser(ctx)
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("failed to find system auth record: %w", err)
|
||||
}
|
||||
|
||||
return auth.CreateAccessToken(ctx, &user, auth.All(), time.Hour, queries)
|
||||
}
|
||||
118
app/reaction/action/python/python.go
Normal file
118
app/reaction/action/python/python.go
Normal file
@@ -0,0 +1,118 @@
|
||||
package python
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"fmt"
|
||||
"os"
|
||||
"os/exec"
|
||||
"strings"
|
||||
)
|
||||
|
||||
type Python struct {
|
||||
Requirements string `json:"requirements"`
|
||||
Script string `json:"script"`
|
||||
|
||||
env []string
|
||||
}
|
||||
|
||||
func (a *Python) SetEnv(env []string) {
|
||||
a.env = env
|
||||
}
|
||||
|
||||
func (a *Python) Run(ctx context.Context, payload json.RawMessage) ([]byte, error) {
|
||||
tempDir, err := os.MkdirTemp("", "catalyst_action")
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
defer os.RemoveAll(tempDir)
|
||||
|
||||
b, err := pythonSetup(ctx, tempDir)
|
||||
if err != nil {
|
||||
var ee *exec.ExitError
|
||||
if errors.As(err, &ee) {
|
||||
b = append(b, ee.Stderr...)
|
||||
}
|
||||
|
||||
return nil, fmt.Errorf("failed to setup python, %w: %s", err, string(b))
|
||||
}
|
||||
|
||||
b, err = a.pythonInstallRequirements(ctx, tempDir)
|
||||
if err != nil {
|
||||
var ee *exec.ExitError
|
||||
if errors.As(err, &ee) {
|
||||
b = append(b, ee.Stderr...)
|
||||
}
|
||||
|
||||
return nil, fmt.Errorf("failed to run install requirements, %w: %s", err, string(b))
|
||||
}
|
||||
|
||||
b, err = a.pythonRunScript(ctx, tempDir, string(payload))
|
||||
if err != nil {
|
||||
var ee *exec.ExitError
|
||||
if errors.As(err, &ee) {
|
||||
b = append(b, ee.Stderr...)
|
||||
}
|
||||
|
||||
return nil, fmt.Errorf("failed to run script, %w: %s", err, string(b))
|
||||
}
|
||||
|
||||
return b, nil
|
||||
}
|
||||
|
||||
func pythonSetup(ctx context.Context, tempDir string) ([]byte, error) {
|
||||
pythonPath, err := findExec("python3", "python")
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("python or python3 binary not found, %w", err)
|
||||
}
|
||||
|
||||
// setup virtual environment
|
||||
return exec.CommandContext(ctx, pythonPath, "-m", "venv", tempDir+"/venv").Output()
|
||||
}
|
||||
|
||||
func (a *Python) pythonInstallRequirements(ctx context.Context, tempDir string) ([]byte, error) {
|
||||
hasRequirements := len(strings.TrimSpace(a.Requirements)) > 0
|
||||
|
||||
if !hasRequirements {
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
requirementsPath := tempDir + "/requirements.txt"
|
||||
|
||||
if err := os.WriteFile(requirementsPath, []byte(a.Requirements), 0o600); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// install dependencies
|
||||
pipPath := tempDir + "/venv/bin/pip"
|
||||
|
||||
return exec.CommandContext(ctx, pipPath, "install", "-r", requirementsPath).Output()
|
||||
}
|
||||
|
||||
func (a *Python) pythonRunScript(ctx context.Context, tempDir, payload string) ([]byte, error) {
|
||||
scriptPath := tempDir + "/script.py"
|
||||
|
||||
if err := os.WriteFile(scriptPath, []byte(a.Script), 0o600); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
pythonPath := tempDir + "/venv/bin/python"
|
||||
|
||||
cmd := exec.CommandContext(ctx, pythonPath, scriptPath, payload)
|
||||
|
||||
cmd.Env = a.env
|
||||
|
||||
return cmd.Output()
|
||||
}
|
||||
|
||||
func findExec(name ...string) (string, error) {
|
||||
for _, n := range name {
|
||||
if p, err := exec.LookPath(n); err == nil {
|
||||
return p, nil
|
||||
}
|
||||
}
|
||||
|
||||
return "", errors.New("no executable found")
|
||||
}
|
||||
104
app/reaction/action/python/python_test.go
Normal file
104
app/reaction/action/python/python_test.go
Normal file
@@ -0,0 +1,104 @@
|
||||
package python_test
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
|
||||
"github.com/SecurityBrewery/catalyst/app/reaction/action/python"
|
||||
)
|
||||
|
||||
func TestPython_Run(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
type fields struct {
|
||||
Requirements string
|
||||
Script string
|
||||
}
|
||||
|
||||
type args struct {
|
||||
payload string
|
||||
}
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
fields fields
|
||||
args args
|
||||
want []byte
|
||||
wantErr assert.ErrorAssertionFunc
|
||||
}{
|
||||
{
|
||||
name: "empty",
|
||||
fields: fields{
|
||||
Script: "pass",
|
||||
},
|
||||
args: args{
|
||||
payload: "test",
|
||||
},
|
||||
want: []byte(""),
|
||||
wantErr: assert.NoError,
|
||||
},
|
||||
{
|
||||
name: "hello world",
|
||||
fields: fields{
|
||||
Script: "print('hello world')",
|
||||
},
|
||||
args: args{
|
||||
payload: "test",
|
||||
},
|
||||
want: []byte("hello world\n"),
|
||||
wantErr: assert.NoError,
|
||||
},
|
||||
{
|
||||
name: "echo",
|
||||
fields: fields{
|
||||
Script: "import sys; print(sys.argv[1])",
|
||||
},
|
||||
args: args{
|
||||
payload: "test",
|
||||
},
|
||||
want: []byte("test\n"),
|
||||
wantErr: assert.NoError,
|
||||
},
|
||||
{
|
||||
name: "error",
|
||||
fields: fields{
|
||||
Script: "import sys; sys.exit(1)",
|
||||
},
|
||||
args: args{
|
||||
payload: "test",
|
||||
},
|
||||
want: nil,
|
||||
wantErr: assert.Error,
|
||||
},
|
||||
{
|
||||
name: "requests",
|
||||
fields: fields{
|
||||
Requirements: "requests",
|
||||
Script: "import requests\nprint(requests.get('https://xkcd.com/2961/info.0.json').text)",
|
||||
},
|
||||
args: args{
|
||||
payload: "test",
|
||||
},
|
||||
want: []byte("{\"month\": \"7\", \"num\": 2961, \"link\": \"\", \"year\": \"2024\", \"news\": \"\", \"safe_title\": \"CrowdStrike\", \"transcript\": \"\", \"alt\": \"We were going to try swordfighting, but all my compiling is on hold.\", \"img\": \"https://imgs.xkcd.com/comics/crowdstrike.png\", \"title\": \"CrowdStrike\", \"day\": \"19\"}\n"),
|
||||
wantErr: assert.NoError,
|
||||
},
|
||||
}
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
ctx := t.Context()
|
||||
|
||||
a := &python.Python{
|
||||
Requirements: tt.fields.Requirements,
|
||||
Script: tt.fields.Script,
|
||||
}
|
||||
got, err := a.Run(ctx, json.RawMessage(tt.args.payload))
|
||||
tt.wantErr(t, err)
|
||||
|
||||
assert.Equal(t, tt.want, got)
|
||||
})
|
||||
}
|
||||
}
|
||||
20
app/reaction/action/webhook/payload.go
Normal file
20
app/reaction/action/webhook/payload.go
Normal file
@@ -0,0 +1,20 @@
|
||||
package webhook
|
||||
|
||||
import (
|
||||
"encoding/base64"
|
||||
"io"
|
||||
"unicode/utf8"
|
||||
)
|
||||
|
||||
func EncodeBody(requestBody io.Reader) (string, bool) {
|
||||
body, err := io.ReadAll(requestBody)
|
||||
if err != nil {
|
||||
return "", false
|
||||
}
|
||||
|
||||
if utf8.Valid(body) {
|
||||
return string(body), false
|
||||
}
|
||||
|
||||
return base64.StdEncoding.EncodeToString(body), true
|
||||
}
|
||||
55
app/reaction/action/webhook/payload_test.go
Normal file
55
app/reaction/action/webhook/payload_test.go
Normal file
@@ -0,0 +1,55 @@
|
||||
package webhook_test
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"io"
|
||||
"testing"
|
||||
|
||||
"github.com/SecurityBrewery/catalyst/app/reaction/action/webhook"
|
||||
)
|
||||
|
||||
func TestEncodeBody(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
type args struct {
|
||||
requestBody io.Reader
|
||||
}
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
args args
|
||||
want string
|
||||
want1 bool
|
||||
}{
|
||||
{
|
||||
name: "utf8",
|
||||
args: args{
|
||||
requestBody: bytes.NewBufferString("body"),
|
||||
},
|
||||
want: "body",
|
||||
want1: false,
|
||||
},
|
||||
{
|
||||
name: "non-utf8",
|
||||
args: args{
|
||||
requestBody: bytes.NewBufferString("body\xe0"),
|
||||
},
|
||||
want: "Ym9keeA=",
|
||||
want1: true,
|
||||
},
|
||||
}
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
got, got1 := webhook.EncodeBody(tt.args.requestBody)
|
||||
if got != tt.want {
|
||||
t.Errorf("EncodeBody() got = %v, want %v", got, tt.want)
|
||||
}
|
||||
|
||||
if got1 != tt.want1 {
|
||||
t.Errorf("EncodeBody() got1 = %v, want %v", got1, tt.want1)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
12
app/reaction/action/webhook/response.go
Normal file
12
app/reaction/action/webhook/response.go
Normal file
@@ -0,0 +1,12 @@
|
||||
package webhook
|
||||
|
||||
import (
|
||||
"net/http"
|
||||
)
|
||||
|
||||
type Response struct {
|
||||
StatusCode int `json:"statusCode"`
|
||||
Headers http.Header `json:"headers"`
|
||||
Body string `json:"body"`
|
||||
IsBase64Encoded bool `json:"isBase64Encoded"`
|
||||
}
|
||||
39
app/reaction/action/webhook/webhook.go
Normal file
39
app/reaction/action/webhook/webhook.go
Normal file
@@ -0,0 +1,39 @@
|
||||
package webhook
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"context"
|
||||
"encoding/json"
|
||||
"net/http"
|
||||
)
|
||||
|
||||
type Webhook struct {
|
||||
Headers map[string]string `json:"headers"`
|
||||
URL string `json:"url"`
|
||||
}
|
||||
|
||||
func (a *Webhook) Run(ctx context.Context, payload json.RawMessage) ([]byte, error) {
|
||||
req, err := http.NewRequestWithContext(ctx, http.MethodPost, a.URL, bytes.NewReader(payload))
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
for key, value := range a.Headers {
|
||||
req.Header.Set(key, value)
|
||||
}
|
||||
|
||||
res, err := http.DefaultClient.Do(req)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
defer res.Body.Close()
|
||||
|
||||
body, isBase64Encoded := EncodeBody(res.Body)
|
||||
|
||||
return json.Marshal(Response{
|
||||
StatusCode: res.StatusCode,
|
||||
Headers: res.Header,
|
||||
Body: body,
|
||||
IsBase64Encoded: isBase64Encoded,
|
||||
})
|
||||
}
|
||||
85
app/reaction/action/webhook/webhook_test.go
Normal file
85
app/reaction/action/webhook/webhook_test.go
Normal file
@@ -0,0 +1,85 @@
|
||||
package webhook_test
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"net/http"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
"github.com/tidwall/sjson"
|
||||
|
||||
"github.com/SecurityBrewery/catalyst/app/reaction/action/webhook"
|
||||
testing2 "github.com/SecurityBrewery/catalyst/testing"
|
||||
)
|
||||
|
||||
func TestWebhook_Run(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
server := testing2.NewRecordingServer()
|
||||
|
||||
go http.ListenAndServe("127.0.0.1:12347", server) //nolint:gosec,errcheck
|
||||
|
||||
if err := testing2.WaitForStatus("http://127.0.0.1:12347/health", http.StatusOK, 5*time.Second); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
type fields struct {
|
||||
Headers map[string]string
|
||||
URL string
|
||||
}
|
||||
|
||||
type args struct {
|
||||
payload string
|
||||
}
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
fields fields
|
||||
args args
|
||||
want map[string]any
|
||||
wantErr assert.ErrorAssertionFunc
|
||||
}{
|
||||
{
|
||||
name: "",
|
||||
fields: fields{
|
||||
Headers: map[string]string{},
|
||||
URL: "http://127.0.0.1:12347/foo",
|
||||
},
|
||||
args: args{
|
||||
payload: "test",
|
||||
},
|
||||
want: map[string]any{
|
||||
"statusCode": 200,
|
||||
"headers": map[string]any{
|
||||
"Content-Length": []any{"13"},
|
||||
"Content-Type": []any{"application/json; charset=UTF-8"},
|
||||
},
|
||||
"body": "{\"test\":true}",
|
||||
"isBase64Encoded": false,
|
||||
},
|
||||
wantErr: assert.NoError,
|
||||
},
|
||||
}
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
a := &webhook.Webhook{
|
||||
Headers: tt.fields.Headers,
|
||||
URL: tt.fields.URL,
|
||||
}
|
||||
got, err := a.Run(t.Context(), json.RawMessage(tt.args.payload))
|
||||
tt.wantErr(t, err)
|
||||
|
||||
want, err := json.Marshal(tt.want)
|
||||
require.NoError(t, err)
|
||||
|
||||
got, err = sjson.DeleteBytes(got, "headers.Date")
|
||||
require.NoError(t, err)
|
||||
|
||||
assert.JSONEq(t, string(want), string(got))
|
||||
})
|
||||
}
|
||||
}
|
||||
113
app/reaction/schedule/schedule.go
Normal file
113
app/reaction/schedule/schedule.go
Normal file
@@ -0,0 +1,113 @@
|
||||
package schedule
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"fmt"
|
||||
"log/slog"
|
||||
|
||||
"github.com/go-co-op/gocron/v2"
|
||||
|
||||
"github.com/SecurityBrewery/catalyst/app/database"
|
||||
"github.com/SecurityBrewery/catalyst/app/database/sqlc"
|
||||
"github.com/SecurityBrewery/catalyst/app/reaction/action"
|
||||
"github.com/SecurityBrewery/catalyst/app/settings"
|
||||
)
|
||||
|
||||
type Scheduler struct {
|
||||
scheduler gocron.Scheduler
|
||||
queries *sqlc.Queries
|
||||
}
|
||||
|
||||
type Schedule struct {
|
||||
Expression string `json:"expression"`
|
||||
}
|
||||
|
||||
func New(ctx context.Context, queries *sqlc.Queries) (*Scheduler, error) {
|
||||
innerScheduler, err := gocron.NewScheduler()
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to create scheduler: %w", err)
|
||||
}
|
||||
|
||||
scheduler := &Scheduler{
|
||||
scheduler: innerScheduler,
|
||||
queries: queries,
|
||||
}
|
||||
|
||||
if err := scheduler.loadJobs(ctx); err != nil {
|
||||
return nil, fmt.Errorf("failed to load jobs: %w", err)
|
||||
}
|
||||
|
||||
innerScheduler.Start()
|
||||
|
||||
return scheduler, nil
|
||||
}
|
||||
|
||||
func (s *Scheduler) AddReaction(reaction *sqlc.Reaction) error {
|
||||
var schedule Schedule
|
||||
if err := json.Unmarshal(reaction.Triggerdata, &schedule); err != nil {
|
||||
return fmt.Errorf("failed to unmarshal schedule data: %w", err)
|
||||
}
|
||||
|
||||
_, err := s.scheduler.NewJob(
|
||||
gocron.CronJob(schedule.Expression, false),
|
||||
gocron.NewTask(
|
||||
func(ctx context.Context) {
|
||||
settings, err := settings.Load(ctx, s.queries)
|
||||
if err != nil {
|
||||
slog.ErrorContext(ctx, "Failed to load settings", "error", err)
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
_, err = action.Run(ctx, settings.Meta.AppURL, s.queries, reaction.Action, reaction.Actiondata, json.RawMessage("{}"))
|
||||
if err != nil {
|
||||
slog.ErrorContext(ctx, "Failed to run schedule reaction", "error", err, "reaction_id", reaction.ID)
|
||||
}
|
||||
},
|
||||
),
|
||||
gocron.WithTags(reaction.ID),
|
||||
)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to create new job for reaction %s: %w", reaction.ID, err)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *Scheduler) RemoveReaction(id string) {
|
||||
s.scheduler.RemoveByTags(id)
|
||||
}
|
||||
|
||||
func (s *Scheduler) loadJobs(ctx context.Context) error {
|
||||
reactions, err := database.PaginateItems(ctx, func(ctx context.Context, offset, limit int64) ([]sqlc.ListReactionsByTriggerRow, error) {
|
||||
return s.queries.ListReactionsByTrigger(ctx, sqlc.ListReactionsByTriggerParams{Trigger: "schedule", Limit: limit, Offset: offset})
|
||||
})
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to find schedule reaction: %w", err)
|
||||
}
|
||||
|
||||
if len(reactions) == 0 {
|
||||
return nil
|
||||
}
|
||||
|
||||
var errs []error
|
||||
|
||||
for _, reaction := range reactions {
|
||||
if err := s.AddReaction(&sqlc.Reaction{
|
||||
Action: reaction.Action,
|
||||
Actiondata: reaction.Actiondata,
|
||||
Created: reaction.Created,
|
||||
ID: reaction.ID,
|
||||
Name: reaction.Name,
|
||||
Trigger: reaction.Trigger,
|
||||
Triggerdata: reaction.Triggerdata,
|
||||
Updated: reaction.Updated,
|
||||
}); err != nil {
|
||||
errs = append(errs, fmt.Errorf("failed to add reaction %s: %w", reaction.ID, err))
|
||||
}
|
||||
}
|
||||
|
||||
return errors.Join(errs...)
|
||||
}
|
||||
17
app/reaction/trigger.go
Normal file
17
app/reaction/trigger.go
Normal file
@@ -0,0 +1,17 @@
|
||||
package reaction
|
||||
|
||||
import (
|
||||
"github.com/go-chi/chi/v5"
|
||||
|
||||
"github.com/SecurityBrewery/catalyst/app/database/sqlc"
|
||||
"github.com/SecurityBrewery/catalyst/app/hook"
|
||||
reactionHook "github.com/SecurityBrewery/catalyst/app/reaction/trigger/hook"
|
||||
"github.com/SecurityBrewery/catalyst/app/reaction/trigger/webhook"
|
||||
)
|
||||
|
||||
func BindHooks(hooks *hook.Hooks, router chi.Router, queries *sqlc.Queries, test bool) error {
|
||||
reactionHook.BindHooks(hooks, queries, test)
|
||||
webhook.BindHooks(router, queries)
|
||||
|
||||
return nil
|
||||
}
|
||||
122
app/reaction/trigger/hook/hook.go
Normal file
122
app/reaction/trigger/hook/hook.go
Normal file
@@ -0,0 +1,122 @@
|
||||
package hook
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"fmt"
|
||||
"log/slog"
|
||||
"slices"
|
||||
|
||||
"github.com/SecurityBrewery/catalyst/app/auth/usercontext"
|
||||
"github.com/SecurityBrewery/catalyst/app/database"
|
||||
"github.com/SecurityBrewery/catalyst/app/database/sqlc"
|
||||
"github.com/SecurityBrewery/catalyst/app/hook"
|
||||
"github.com/SecurityBrewery/catalyst/app/reaction/action"
|
||||
"github.com/SecurityBrewery/catalyst/app/settings"
|
||||
"github.com/SecurityBrewery/catalyst/app/webhook"
|
||||
)
|
||||
|
||||
type Hook struct {
|
||||
Collections []string `json:"collections"`
|
||||
Events []string `json:"events"`
|
||||
}
|
||||
|
||||
func BindHooks(hooks *hook.Hooks, queries *sqlc.Queries, test bool) {
|
||||
hooks.OnRecordAfterCreateRequest.Subscribe(func(ctx context.Context, table string, record any) {
|
||||
bindHook(ctx, queries, database.CreateAction, table, record, test)
|
||||
})
|
||||
hooks.OnRecordAfterUpdateRequest.Subscribe(func(ctx context.Context, table string, record any) {
|
||||
bindHook(ctx, queries, database.UpdateAction, table, record, test)
|
||||
})
|
||||
hooks.OnRecordAfterDeleteRequest.Subscribe(func(ctx context.Context, table string, record any) {
|
||||
bindHook(ctx, queries, database.DeleteAction, table, record, test)
|
||||
})
|
||||
}
|
||||
|
||||
func bindHook(ctx context.Context, queries *sqlc.Queries, event, collection string, record any, test bool) {
|
||||
user, ok := usercontext.UserFromContext(ctx)
|
||||
if !ok {
|
||||
slog.ErrorContext(ctx, "failed to get user from session")
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
if !test {
|
||||
go mustRunHook(context.Background(), queries, collection, event, record, user) //nolint:contextcheck
|
||||
} else {
|
||||
mustRunHook(ctx, queries, collection, event, record, user)
|
||||
}
|
||||
}
|
||||
|
||||
func mustRunHook(ctx context.Context, queries *sqlc.Queries, collection, event string, record any, auth *sqlc.User) {
|
||||
if err := runHook(ctx, queries, collection, event, record, auth); err != nil {
|
||||
slog.ErrorContext(ctx, fmt.Sprintf("failed to run hook reaction: %v", err))
|
||||
}
|
||||
}
|
||||
|
||||
func runHook(ctx context.Context, queries *sqlc.Queries, collection, event string, record any, auth *sqlc.User) error {
|
||||
payload, err := json.Marshal(&webhook.Payload{
|
||||
Action: event,
|
||||
Collection: collection,
|
||||
Record: record,
|
||||
Auth: auth,
|
||||
Admin: nil,
|
||||
})
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to marshal webhook payload: %w", err)
|
||||
}
|
||||
|
||||
hooks, err := findByHookTrigger(ctx, queries, collection, event)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to find hook by trigger: %w", err)
|
||||
}
|
||||
|
||||
if len(hooks) == 0 {
|
||||
return nil
|
||||
}
|
||||
|
||||
settings, err := settings.Load(ctx, queries)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to load settings: %w", err)
|
||||
}
|
||||
|
||||
var errs []error
|
||||
|
||||
for _, hook := range hooks {
|
||||
_, err = action.Run(ctx, settings.Meta.AppURL, queries, hook.Action, hook.Actiondata, payload)
|
||||
if err != nil {
|
||||
errs = append(errs, fmt.Errorf("failed to run hook reaction: %w", err))
|
||||
}
|
||||
}
|
||||
|
||||
return errors.Join(errs...)
|
||||
}
|
||||
|
||||
func findByHookTrigger(ctx context.Context, queries *sqlc.Queries, collection, event string) ([]*sqlc.ListReactionsByTriggerRow, error) {
|
||||
reactions, err := database.PaginateItems(ctx, func(ctx context.Context, offset, limit int64) ([]sqlc.ListReactionsByTriggerRow, error) {
|
||||
return queries.ListReactionsByTrigger(ctx, sqlc.ListReactionsByTriggerParams{Trigger: "hook", Limit: limit, Offset: offset})
|
||||
})
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to find hook reaction: %w", err)
|
||||
}
|
||||
|
||||
if len(reactions) == 0 {
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
var matchedRecords []*sqlc.ListReactionsByTriggerRow
|
||||
|
||||
for _, reaction := range reactions {
|
||||
var hook Hook
|
||||
if err := json.Unmarshal(reaction.Triggerdata, &hook); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if slices.Contains(hook.Collections, collection) && slices.Contains(hook.Events, event) {
|
||||
matchedRecords = append(matchedRecords, &reaction)
|
||||
}
|
||||
}
|
||||
|
||||
return matchedRecords, nil
|
||||
}
|
||||
15
app/reaction/trigger/webhook/request.go
Normal file
15
app/reaction/trigger/webhook/request.go
Normal file
@@ -0,0 +1,15 @@
|
||||
package webhook
|
||||
|
||||
import (
|
||||
"net/http"
|
||||
"net/url"
|
||||
)
|
||||
|
||||
type Request struct {
|
||||
Method string `json:"method"`
|
||||
Path string `json:"path"`
|
||||
Headers http.Header `json:"headers"`
|
||||
Query url.Values `json:"query"`
|
||||
Body string `json:"body"`
|
||||
IsBase64Encoded bool `json:"isBase64Encoded"`
|
||||
}
|
||||
162
app/reaction/trigger/webhook/webhook.go
Normal file
162
app/reaction/trigger/webhook/webhook.go
Normal file
@@ -0,0 +1,162 @@
|
||||
package webhook
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/base64"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"log/slog"
|
||||
"net/http"
|
||||
"strings"
|
||||
|
||||
"github.com/go-chi/chi/v5"
|
||||
|
||||
"github.com/SecurityBrewery/catalyst/app/database"
|
||||
"github.com/SecurityBrewery/catalyst/app/database/sqlc"
|
||||
"github.com/SecurityBrewery/catalyst/app/reaction/action"
|
||||
"github.com/SecurityBrewery/catalyst/app/reaction/action/webhook"
|
||||
"github.com/SecurityBrewery/catalyst/app/settings"
|
||||
)
|
||||
|
||||
type Webhook struct {
|
||||
Token string `json:"token"`
|
||||
Path string `json:"path"`
|
||||
}
|
||||
|
||||
const prefix = "/reaction/"
|
||||
|
||||
func BindHooks(router chi.Router, queries *sqlc.Queries) {
|
||||
router.HandleFunc(prefix+"*", handle(queries))
|
||||
}
|
||||
|
||||
func handle(queries *sqlc.Queries) http.HandlerFunc {
|
||||
return func(w http.ResponseWriter, r *http.Request) {
|
||||
reaction, payload, status, err := parseRequest(queries, r)
|
||||
if err != nil {
|
||||
http.Error(w, err.Error(), status)
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
settings, err := settings.Load(r.Context(), queries)
|
||||
if err != nil {
|
||||
http.Error(w, "failed to load settings: "+err.Error(), http.StatusInternalServerError)
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
output, err := action.Run(r.Context(), settings.Meta.AppURL, queries, reaction.Action, reaction.Actiondata, payload)
|
||||
if err != nil {
|
||||
http.Error(w, err.Error(), http.StatusInternalServerError)
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
if err := writeOutput(w, output); err != nil {
|
||||
slog.ErrorContext(r.Context(), "failed to write output", "error", err.Error())
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func parseRequest(queries *sqlc.Queries, r *http.Request) (*sqlc.ListReactionsByTriggerRow, []byte, int, error) {
|
||||
if !strings.HasPrefix(r.URL.Path, prefix) {
|
||||
return nil, nil, http.StatusNotFound, fmt.Errorf("wrong prefix")
|
||||
}
|
||||
|
||||
reactionName := strings.TrimPrefix(r.URL.Path, prefix)
|
||||
|
||||
reaction, trigger, found, err := findByWebhookTrigger(r.Context(), queries, reactionName)
|
||||
if err != nil {
|
||||
return nil, nil, http.StatusNotFound, err
|
||||
}
|
||||
|
||||
if !found {
|
||||
return nil, nil, http.StatusNotFound, fmt.Errorf("reaction not found")
|
||||
}
|
||||
|
||||
if trigger.Token != "" {
|
||||
auth := r.Header.Get("Authorization")
|
||||
|
||||
if !strings.HasPrefix(auth, "Bearer ") {
|
||||
return nil, nil, http.StatusUnauthorized, fmt.Errorf("missing token")
|
||||
}
|
||||
|
||||
if trigger.Token != strings.TrimPrefix(auth, "Bearer ") {
|
||||
return nil, nil, http.StatusUnauthorized, fmt.Errorf("invalid token")
|
||||
}
|
||||
}
|
||||
|
||||
body, isBase64Encoded := webhook.EncodeBody(r.Body)
|
||||
|
||||
payload, err := json.Marshal(&Request{
|
||||
Method: r.Method,
|
||||
Path: r.URL.EscapedPath(),
|
||||
Headers: r.Header,
|
||||
Query: r.URL.Query(),
|
||||
Body: body,
|
||||
IsBase64Encoded: isBase64Encoded,
|
||||
})
|
||||
if err != nil {
|
||||
return nil, nil, http.StatusInternalServerError, err
|
||||
}
|
||||
|
||||
return reaction, payload, http.StatusOK, nil
|
||||
}
|
||||
|
||||
func findByWebhookTrigger(ctx context.Context, queries *sqlc.Queries, path string) (*sqlc.ListReactionsByTriggerRow, *Webhook, bool, error) {
|
||||
reactions, err := database.PaginateItems(ctx, func(ctx context.Context, offset, limit int64) ([]sqlc.ListReactionsByTriggerRow, error) {
|
||||
return queries.ListReactionsByTrigger(ctx, sqlc.ListReactionsByTriggerParams{Trigger: "webhook", Limit: limit, Offset: offset})
|
||||
})
|
||||
if err != nil {
|
||||
return nil, nil, false, err
|
||||
}
|
||||
|
||||
if len(reactions) == 0 {
|
||||
return nil, nil, false, nil
|
||||
}
|
||||
|
||||
for _, reaction := range reactions {
|
||||
var webhook Webhook
|
||||
if err := json.Unmarshal(reaction.Triggerdata, &webhook); err != nil {
|
||||
return nil, nil, false, err
|
||||
}
|
||||
|
||||
if webhook.Path == path {
|
||||
return &reaction, &webhook, true, nil
|
||||
}
|
||||
}
|
||||
|
||||
return nil, nil, false, nil
|
||||
}
|
||||
|
||||
func writeOutput(w http.ResponseWriter, output []byte) error {
|
||||
var catalystResponse webhook.Response
|
||||
if err := json.Unmarshal(output, &catalystResponse); err == nil && catalystResponse.StatusCode != 0 {
|
||||
for key, values := range catalystResponse.Headers {
|
||||
for _, value := range values {
|
||||
w.Header().Add(key, value)
|
||||
}
|
||||
}
|
||||
|
||||
if catalystResponse.IsBase64Encoded {
|
||||
output, err = base64.StdEncoding.DecodeString(catalystResponse.Body)
|
||||
if err != nil {
|
||||
return fmt.Errorf("error decoding base64 body: %w", err)
|
||||
}
|
||||
} else {
|
||||
output = []byte(catalystResponse.Body)
|
||||
}
|
||||
}
|
||||
|
||||
if json.Valid(output) {
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
w.WriteHeader(http.StatusOK)
|
||||
_, _ = w.Write(output)
|
||||
} else {
|
||||
w.Header().Set("Content-Type", "text/plain")
|
||||
w.WriteHeader(http.StatusOK)
|
||||
_, _ = w.Write(output)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
Reference in New Issue
Block a user