mirror of
https://github.com/SecurityBrewery/catalyst.git
synced 2025-12-25 16:33:10 +01:00
refactor: remove pocketbase (#1138)
This commit is contained in:
122
app/reaction/trigger/hook/hook.go
Normal file
122
app/reaction/trigger/hook/hook.go
Normal file
@@ -0,0 +1,122 @@
|
||||
package hook
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"fmt"
|
||||
"log/slog"
|
||||
"slices"
|
||||
|
||||
"github.com/SecurityBrewery/catalyst/app/auth/usercontext"
|
||||
"github.com/SecurityBrewery/catalyst/app/database"
|
||||
"github.com/SecurityBrewery/catalyst/app/database/sqlc"
|
||||
"github.com/SecurityBrewery/catalyst/app/hook"
|
||||
"github.com/SecurityBrewery/catalyst/app/reaction/action"
|
||||
"github.com/SecurityBrewery/catalyst/app/settings"
|
||||
"github.com/SecurityBrewery/catalyst/app/webhook"
|
||||
)
|
||||
|
||||
type Hook struct {
|
||||
Collections []string `json:"collections"`
|
||||
Events []string `json:"events"`
|
||||
}
|
||||
|
||||
func BindHooks(hooks *hook.Hooks, queries *sqlc.Queries, test bool) {
|
||||
hooks.OnRecordAfterCreateRequest.Subscribe(func(ctx context.Context, table string, record any) {
|
||||
bindHook(ctx, queries, database.CreateAction, table, record, test)
|
||||
})
|
||||
hooks.OnRecordAfterUpdateRequest.Subscribe(func(ctx context.Context, table string, record any) {
|
||||
bindHook(ctx, queries, database.UpdateAction, table, record, test)
|
||||
})
|
||||
hooks.OnRecordAfterDeleteRequest.Subscribe(func(ctx context.Context, table string, record any) {
|
||||
bindHook(ctx, queries, database.DeleteAction, table, record, test)
|
||||
})
|
||||
}
|
||||
|
||||
func bindHook(ctx context.Context, queries *sqlc.Queries, event, collection string, record any, test bool) {
|
||||
user, ok := usercontext.UserFromContext(ctx)
|
||||
if !ok {
|
||||
slog.ErrorContext(ctx, "failed to get user from session")
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
if !test {
|
||||
go mustRunHook(context.Background(), queries, collection, event, record, user) //nolint:contextcheck
|
||||
} else {
|
||||
mustRunHook(ctx, queries, collection, event, record, user)
|
||||
}
|
||||
}
|
||||
|
||||
func mustRunHook(ctx context.Context, queries *sqlc.Queries, collection, event string, record any, auth *sqlc.User) {
|
||||
if err := runHook(ctx, queries, collection, event, record, auth); err != nil {
|
||||
slog.ErrorContext(ctx, fmt.Sprintf("failed to run hook reaction: %v", err))
|
||||
}
|
||||
}
|
||||
|
||||
func runHook(ctx context.Context, queries *sqlc.Queries, collection, event string, record any, auth *sqlc.User) error {
|
||||
payload, err := json.Marshal(&webhook.Payload{
|
||||
Action: event,
|
||||
Collection: collection,
|
||||
Record: record,
|
||||
Auth: auth,
|
||||
Admin: nil,
|
||||
})
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to marshal webhook payload: %w", err)
|
||||
}
|
||||
|
||||
hooks, err := findByHookTrigger(ctx, queries, collection, event)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to find hook by trigger: %w", err)
|
||||
}
|
||||
|
||||
if len(hooks) == 0 {
|
||||
return nil
|
||||
}
|
||||
|
||||
settings, err := settings.Load(ctx, queries)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to load settings: %w", err)
|
||||
}
|
||||
|
||||
var errs []error
|
||||
|
||||
for _, hook := range hooks {
|
||||
_, err = action.Run(ctx, settings.Meta.AppURL, queries, hook.Action, hook.Actiondata, payload)
|
||||
if err != nil {
|
||||
errs = append(errs, fmt.Errorf("failed to run hook reaction: %w", err))
|
||||
}
|
||||
}
|
||||
|
||||
return errors.Join(errs...)
|
||||
}
|
||||
|
||||
func findByHookTrigger(ctx context.Context, queries *sqlc.Queries, collection, event string) ([]*sqlc.ListReactionsByTriggerRow, error) {
|
||||
reactions, err := database.PaginateItems(ctx, func(ctx context.Context, offset, limit int64) ([]sqlc.ListReactionsByTriggerRow, error) {
|
||||
return queries.ListReactionsByTrigger(ctx, sqlc.ListReactionsByTriggerParams{Trigger: "hook", Limit: limit, Offset: offset})
|
||||
})
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to find hook reaction: %w", err)
|
||||
}
|
||||
|
||||
if len(reactions) == 0 {
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
var matchedRecords []*sqlc.ListReactionsByTriggerRow
|
||||
|
||||
for _, reaction := range reactions {
|
||||
var hook Hook
|
||||
if err := json.Unmarshal(reaction.Triggerdata, &hook); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if slices.Contains(hook.Collections, collection) && slices.Contains(hook.Events, event) {
|
||||
matchedRecords = append(matchedRecords, &reaction)
|
||||
}
|
||||
}
|
||||
|
||||
return matchedRecords, nil
|
||||
}
|
||||
15
app/reaction/trigger/webhook/request.go
Normal file
15
app/reaction/trigger/webhook/request.go
Normal file
@@ -0,0 +1,15 @@
|
||||
package webhook
|
||||
|
||||
import (
|
||||
"net/http"
|
||||
"net/url"
|
||||
)
|
||||
|
||||
type Request struct {
|
||||
Method string `json:"method"`
|
||||
Path string `json:"path"`
|
||||
Headers http.Header `json:"headers"`
|
||||
Query url.Values `json:"query"`
|
||||
Body string `json:"body"`
|
||||
IsBase64Encoded bool `json:"isBase64Encoded"`
|
||||
}
|
||||
162
app/reaction/trigger/webhook/webhook.go
Normal file
162
app/reaction/trigger/webhook/webhook.go
Normal file
@@ -0,0 +1,162 @@
|
||||
package webhook
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/base64"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"log/slog"
|
||||
"net/http"
|
||||
"strings"
|
||||
|
||||
"github.com/go-chi/chi/v5"
|
||||
|
||||
"github.com/SecurityBrewery/catalyst/app/database"
|
||||
"github.com/SecurityBrewery/catalyst/app/database/sqlc"
|
||||
"github.com/SecurityBrewery/catalyst/app/reaction/action"
|
||||
"github.com/SecurityBrewery/catalyst/app/reaction/action/webhook"
|
||||
"github.com/SecurityBrewery/catalyst/app/settings"
|
||||
)
|
||||
|
||||
type Webhook struct {
|
||||
Token string `json:"token"`
|
||||
Path string `json:"path"`
|
||||
}
|
||||
|
||||
const prefix = "/reaction/"
|
||||
|
||||
func BindHooks(router chi.Router, queries *sqlc.Queries) {
|
||||
router.HandleFunc(prefix+"*", handle(queries))
|
||||
}
|
||||
|
||||
func handle(queries *sqlc.Queries) http.HandlerFunc {
|
||||
return func(w http.ResponseWriter, r *http.Request) {
|
||||
reaction, payload, status, err := parseRequest(queries, r)
|
||||
if err != nil {
|
||||
http.Error(w, err.Error(), status)
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
settings, err := settings.Load(r.Context(), queries)
|
||||
if err != nil {
|
||||
http.Error(w, "failed to load settings: "+err.Error(), http.StatusInternalServerError)
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
output, err := action.Run(r.Context(), settings.Meta.AppURL, queries, reaction.Action, reaction.Actiondata, payload)
|
||||
if err != nil {
|
||||
http.Error(w, err.Error(), http.StatusInternalServerError)
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
if err := writeOutput(w, output); err != nil {
|
||||
slog.ErrorContext(r.Context(), "failed to write output", "error", err.Error())
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func parseRequest(queries *sqlc.Queries, r *http.Request) (*sqlc.ListReactionsByTriggerRow, []byte, int, error) {
|
||||
if !strings.HasPrefix(r.URL.Path, prefix) {
|
||||
return nil, nil, http.StatusNotFound, fmt.Errorf("wrong prefix")
|
||||
}
|
||||
|
||||
reactionName := strings.TrimPrefix(r.URL.Path, prefix)
|
||||
|
||||
reaction, trigger, found, err := findByWebhookTrigger(r.Context(), queries, reactionName)
|
||||
if err != nil {
|
||||
return nil, nil, http.StatusNotFound, err
|
||||
}
|
||||
|
||||
if !found {
|
||||
return nil, nil, http.StatusNotFound, fmt.Errorf("reaction not found")
|
||||
}
|
||||
|
||||
if trigger.Token != "" {
|
||||
auth := r.Header.Get("Authorization")
|
||||
|
||||
if !strings.HasPrefix(auth, "Bearer ") {
|
||||
return nil, nil, http.StatusUnauthorized, fmt.Errorf("missing token")
|
||||
}
|
||||
|
||||
if trigger.Token != strings.TrimPrefix(auth, "Bearer ") {
|
||||
return nil, nil, http.StatusUnauthorized, fmt.Errorf("invalid token")
|
||||
}
|
||||
}
|
||||
|
||||
body, isBase64Encoded := webhook.EncodeBody(r.Body)
|
||||
|
||||
payload, err := json.Marshal(&Request{
|
||||
Method: r.Method,
|
||||
Path: r.URL.EscapedPath(),
|
||||
Headers: r.Header,
|
||||
Query: r.URL.Query(),
|
||||
Body: body,
|
||||
IsBase64Encoded: isBase64Encoded,
|
||||
})
|
||||
if err != nil {
|
||||
return nil, nil, http.StatusInternalServerError, err
|
||||
}
|
||||
|
||||
return reaction, payload, http.StatusOK, nil
|
||||
}
|
||||
|
||||
func findByWebhookTrigger(ctx context.Context, queries *sqlc.Queries, path string) (*sqlc.ListReactionsByTriggerRow, *Webhook, bool, error) {
|
||||
reactions, err := database.PaginateItems(ctx, func(ctx context.Context, offset, limit int64) ([]sqlc.ListReactionsByTriggerRow, error) {
|
||||
return queries.ListReactionsByTrigger(ctx, sqlc.ListReactionsByTriggerParams{Trigger: "webhook", Limit: limit, Offset: offset})
|
||||
})
|
||||
if err != nil {
|
||||
return nil, nil, false, err
|
||||
}
|
||||
|
||||
if len(reactions) == 0 {
|
||||
return nil, nil, false, nil
|
||||
}
|
||||
|
||||
for _, reaction := range reactions {
|
||||
var webhook Webhook
|
||||
if err := json.Unmarshal(reaction.Triggerdata, &webhook); err != nil {
|
||||
return nil, nil, false, err
|
||||
}
|
||||
|
||||
if webhook.Path == path {
|
||||
return &reaction, &webhook, true, nil
|
||||
}
|
||||
}
|
||||
|
||||
return nil, nil, false, nil
|
||||
}
|
||||
|
||||
func writeOutput(w http.ResponseWriter, output []byte) error {
|
||||
var catalystResponse webhook.Response
|
||||
if err := json.Unmarshal(output, &catalystResponse); err == nil && catalystResponse.StatusCode != 0 {
|
||||
for key, values := range catalystResponse.Headers {
|
||||
for _, value := range values {
|
||||
w.Header().Add(key, value)
|
||||
}
|
||||
}
|
||||
|
||||
if catalystResponse.IsBase64Encoded {
|
||||
output, err = base64.StdEncoding.DecodeString(catalystResponse.Body)
|
||||
if err != nil {
|
||||
return fmt.Errorf("error decoding base64 body: %w", err)
|
||||
}
|
||||
} else {
|
||||
output = []byte(catalystResponse.Body)
|
||||
}
|
||||
}
|
||||
|
||||
if json.Valid(output) {
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
w.WriteHeader(http.StatusOK)
|
||||
_, _ = w.Write(output)
|
||||
} else {
|
||||
w.Header().Set("Content-Type", "text/plain")
|
||||
w.WriteHeader(http.StatusOK)
|
||||
_, _ = w.Write(output)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
Reference in New Issue
Block a user