// Files
// catalyst/app/reaction/trigger/webhook/webhook.go
// 2025-09-02 21:58:08 +02:00
//
// 163 lines
// 4.3 KiB
// Go

package webhook
import (
	"context"
	"crypto/subtle"
	"encoding/base64"
	"encoding/json"
	"fmt"
	"log/slog"
	"net/http"
	"strings"

	"github.com/go-chi/chi/v5"

	"github.com/SecurityBrewery/catalyst/app/database"
	"github.com/SecurityBrewery/catalyst/app/database/sqlc"
	"github.com/SecurityBrewery/catalyst/app/reaction/action"
	"github.com/SecurityBrewery/catalyst/app/reaction/action/webhook"
	"github.com/SecurityBrewery/catalyst/app/settings"
)
// Webhook is the trigger configuration of a reaction whose trigger is
// "webhook"; it is decoded from the reaction's Triggerdata JSON.
type Webhook struct {
	// Token, when non-empty, must be sent by callers as
	// "Authorization: Bearer <Token>".
	Token string `json:"token"`

	// Path is the part of the request URL after prefix that selects this
	// reaction.
	Path string `json:"path"`
}
// prefix is the URL path prefix under which BindHooks mounts the webhook
// handler; the remainder of the path names the reaction's webhook trigger.
const (
	prefix = "/reaction/"
)
// BindHooks mounts the webhook trigger endpoint on router. Every request whose
// path starts with prefix is dispatched to the handler built by handle, which
// resolves reactions through queries.
func BindHooks(router chi.Router, queries *sqlc.Queries) {
	handler := handle(queries)
	router.HandleFunc(prefix+"*", handler)
}
// handle returns the HTTP handler that serves webhook-triggered reactions.
//
// For each request it resolves and authorizes the target reaction via
// parseRequest, loads the application settings, runs the reaction's action
// with the request payload, and sends the action's output back via
// writeOutput. Errors up to and including the action run are returned to the
// caller with http.Error; an error from writeOutput is only logged.
func handle(queries *sqlc.Queries) http.HandlerFunc {
	return func(w http.ResponseWriter, r *http.Request) {
		ctx := r.Context()

		reaction, payload, status, err := parseRequest(queries, r)
		if err != nil {
			http.Error(w, err.Error(), status)
			return
		}

		// Named appSettings so the imported settings package is not shadowed.
		appSettings, err := settings.Load(ctx, queries)
		if err != nil {
			http.Error(w, "failed to load settings: "+err.Error(), http.StatusInternalServerError)
			return
		}

		output, err := action.Run(ctx, appSettings.Meta.AppURL, queries, reaction.Action, reaction.Actiondata, payload)
		if err != nil {
			http.Error(w, err.Error(), http.StatusInternalServerError)
			return
		}

		if err = writeOutput(w, output); err != nil {
			slog.ErrorContext(ctx, "failed to write output", "error", err.Error())
		}
	}
}
// parseRequest resolves the reaction addressed by r and builds the JSON
// payload handed to its action.
//
// The request path after prefix selects the reaction whose webhook trigger
// has a matching Path. When that trigger has a non-empty Token, the request
// must carry "Authorization: Bearer <Token>".
//
// On success it returns the reaction, the JSON-encoded Request payload and
// http.StatusOK. On failure the returned status describes the error:
// 404 for a foreign prefix or unknown reaction, 401 for a missing or wrong
// token, and 500 for lookup or payload-encoding failures.
func parseRequest(queries *sqlc.Queries, r *http.Request) (*sqlc.ListReactionsByTriggerRow, []byte, int, error) {
	reactionName, ok := strings.CutPrefix(r.URL.Path, prefix)
	if !ok {
		return nil, nil, http.StatusNotFound, fmt.Errorf("wrong prefix")
	}

	reaction, trigger, found, err := findByWebhookTrigger(r.Context(), queries, reactionName)
	if err != nil {
		// A failed lookup is a server-side problem, not a missing reaction.
		return nil, nil, http.StatusInternalServerError, err
	}
	if !found {
		return nil, nil, http.StatusNotFound, fmt.Errorf("reaction not found")
	}

	if trigger.Token != "" {
		token, ok := strings.CutPrefix(r.Header.Get("Authorization"), "Bearer ")
		if !ok {
			return nil, nil, http.StatusUnauthorized, fmt.Errorf("missing token")
		}
		// Constant-time comparison so response timing does not reveal how
		// many leading bytes of a guessed token were correct.
		if subtle.ConstantTimeCompare([]byte(token), []byte(trigger.Token)) != 1 {
			return nil, nil, http.StatusUnauthorized, fmt.Errorf("invalid token")
		}
	}

	body, isBase64Encoded := webhook.EncodeBody(r.Body)

	payload, err := json.Marshal(&Request{
		Method:          r.Method,
		Path:            r.URL.EscapedPath(),
		Headers:         r.Header,
		Query:           r.URL.Query(),
		Body:            body,
		IsBase64Encoded: isBase64Encoded,
	})
	if err != nil {
		return nil, nil, http.StatusInternalServerError, err
	}

	return reaction, payload, http.StatusOK, nil
}
// findByWebhookTrigger looks up the reaction whose webhook trigger is bound
// to path.
//
// All reactions with the "webhook" trigger are fetched through
// database.PaginateItems (offset/limit paging). Each reaction's Triggerdata is
// decoded as a Webhook, and the first one whose Path equals path is returned
// together with its decoded trigger and found == true.
//
// found is false, with a nil error, when no reaction matches. A non-nil error
// is returned only when listing the reactions fails. A reaction whose
// Triggerdata cannot be decoded is logged and skipped, so one misconfigured
// reaction does not take every other webhook offline.
func findByWebhookTrigger(ctx context.Context, queries *sqlc.Queries, path string) (*sqlc.ListReactionsByTriggerRow, *Webhook, bool, error) {
	reactions, err := database.PaginateItems(ctx, func(ctx context.Context, offset, limit int64) ([]sqlc.ListReactionsByTriggerRow, error) {
		return queries.ListReactionsByTrigger(ctx, sqlc.ListReactionsByTriggerParams{Trigger: "webhook", Limit: limit, Offset: offset})
	})
	if err != nil {
		return nil, nil, false, err
	}

	// Index instead of ranging by value so the returned pointer refers to the
	// slice element rather than a per-iteration copy.
	for i := range reactions {
		var hook Webhook // not named "webhook", which would shadow the imported package
		if err := json.Unmarshal(reactions[i].Triggerdata, &hook); err != nil {
			slog.WarnContext(ctx, "skipping reaction with invalid webhook trigger data", "error", err.Error())
			continue
		}

		if hook.Path == path {
			return &reactions[i], &hook, true, nil
		}
	}

	return nil, nil, false, nil
}
// writeOutput sends the output of a reaction's action to the webhook caller.
//
// When output decodes as a webhook.Response with a non-zero StatusCode, it is
// treated as a structured HTTP response: its Headers are copied onto w, its
// Body becomes the payload (base64-decoded first when IsBase64Encoded is set),
// and its StatusCode is sent. Status codes outside 200-599 fall back to 200,
// because net/http panics on codes below 100 or above 999 and 1xx codes are
// not final responses. Any other output is sent verbatim with status 200.
//
// If no Content-Type has been set (for example by the structured response's
// Headers), one is added: application/json when the payload is valid JSON,
// text/plain otherwise.
//
// If a base64 body cannot be decoded, a 500 response is written and the
// decoding error is returned. An error is also returned when writing the
// payload to w fails.
func writeOutput(w http.ResponseWriter, output []byte) error {
	status := http.StatusOK

	var catalystResponse webhook.Response
	if err := json.Unmarshal(output, &catalystResponse); err == nil && catalystResponse.StatusCode != 0 {
		body := []byte(catalystResponse.Body)
		if catalystResponse.IsBase64Encoded {
			decoded, err := base64.StdEncoding.DecodeString(catalystResponse.Body)
			if err != nil {
				// Decode before touching w's headers so the caller gets a clean
				// error response instead of an empty 200 carrying the action's headers.
				http.Error(w, "invalid base64 response body", http.StatusInternalServerError)
				return fmt.Errorf("error decoding base64 body: %w", err)
			}
			body = decoded
		}

		for key, values := range catalystResponse.Headers {
			for _, value := range values {
				w.Header().Add(key, value)
			}
		}

		if code := catalystResponse.StatusCode; code >= http.StatusOK && code <= 599 {
			status = code
		}

		output = body
	}

	// Respect a Content-Type supplied by the action; only fill in a default.
	if w.Header().Get("Content-Type") == "" {
		if json.Valid(output) {
			w.Header().Set("Content-Type", "application/json")
		} else {
			w.Header().Set("Content-Type", "text/plain")
		}
	}

	w.WriteHeader(status)
	if _, err := w.Write(output); err != nil {
		return fmt.Errorf("writing response body: %w", err)
	}

	return nil
}