mirror of
https://github.com/SecurityBrewery/catalyst.git
synced 2026-04-01 00:12:43 +02:00
20
.github/workflows/ci.yml
vendored
20
.github/workflows/ci.yml
vendored
@@ -9,13 +9,25 @@ env:
|
||||
IMAGE_NAME: ${{ github.repository }}
|
||||
|
||||
jobs:
|
||||
lint:
|
||||
name: Lint
|
||||
runs-on: ubuntu-latest
|
||||
steps:
|
||||
- uses: actions/setup-go@v3
|
||||
with: { go-version: '1.18' }
|
||||
- uses: actions/checkout@v2
|
||||
- run: |
|
||||
mkdir -p ui/dist/img
|
||||
touch ui/dist/index.html ui/dist/favicon.ico ui/dist/manifest.json ui/dist/img/fake.png
|
||||
- uses: golangci/golangci-lint-action@v3
|
||||
|
||||
test:
|
||||
name: Test
|
||||
runs-on: ubuntu-latest
|
||||
env: { GIN_MODE: test }
|
||||
steps:
|
||||
- uses: actions/setup-go@v2
|
||||
with: { go-version: '1.17' }
|
||||
- uses: actions/setup-go@v3
|
||||
with: { go-version: '1.18' }
|
||||
- uses: actions/setup-node@v2
|
||||
with: { node-version: '14' }
|
||||
- uses: actions/checkout@v2
|
||||
@@ -51,8 +63,8 @@ jobs:
|
||||
runs-on: ubuntu-latest
|
||||
needs: [ build-npm, test ]
|
||||
steps:
|
||||
- uses: actions/setup-go@v2
|
||||
with: { go-version: '1.17' }
|
||||
- uses: actions/setup-go@v3
|
||||
with: { go-version: '1.18' }
|
||||
- uses: actions/checkout@v2
|
||||
- uses: actions/download-artifact@v2
|
||||
with: { name: ui, path: ui/dist }
|
||||
|
||||
116
.golangci.yml
Normal file
116
.golangci.yml
Normal file
@@ -0,0 +1,116 @@
|
||||
run:
|
||||
go: "1.18"
|
||||
skip-dirs:
|
||||
- generated
|
||||
linters:
|
||||
enable:
|
||||
- asciicheck
|
||||
- containedctx
|
||||
- decorder
|
||||
- depguard
|
||||
- dogsled
|
||||
- durationcheck
|
||||
- errchkjson
|
||||
- errname
|
||||
- errorlint
|
||||
- exhaustive
|
||||
- exportloopref
|
||||
- forbidigo
|
||||
- forcetypeassert
|
||||
- gci
|
||||
- gocritic
|
||||
- godot
|
||||
- gofmt
|
||||
- gofumpt
|
||||
- goheader
|
||||
- goimports
|
||||
- gomodguard
|
||||
- goprintffuncname
|
||||
- gosec
|
||||
- grouper
|
||||
- ifshort
|
||||
- importas
|
||||
- ireturn
|
||||
- misspell
|
||||
- nakedret
|
||||
- nilnil
|
||||
- nlreturn
|
||||
- nolintlint
|
||||
- paralleltest
|
||||
- predeclared
|
||||
- promlinter
|
||||
- revive
|
||||
- tenv
|
||||
- thelper
|
||||
- unconvert
|
||||
- whitespace
|
||||
|
||||
disable:
|
||||
# go 1.18
|
||||
- bodyclose
|
||||
- contextcheck
|
||||
- gosimple
|
||||
- nilerr
|
||||
- noctx
|
||||
- rowserrcheck
|
||||
- sqlclosecheck
|
||||
- staticcheck
|
||||
- stylecheck
|
||||
- tparallel
|
||||
- unparam
|
||||
- unused
|
||||
- wastedassign
|
||||
|
||||
# complexity
|
||||
- cyclop
|
||||
- gocognit
|
||||
- gocyclo
|
||||
- maintidx
|
||||
- nestif
|
||||
|
||||
# disable
|
||||
- dupl
|
||||
- exhaustivestruct
|
||||
- funlen
|
||||
- gochecknoglobals
|
||||
- gochecknoinits
|
||||
- goconst
|
||||
- godox
|
||||
- goerr113
|
||||
- gomnd
|
||||
- gomoddirectives
|
||||
- lll
|
||||
- makezero
|
||||
- prealloc
|
||||
- structcheck
|
||||
- tagliatelle
|
||||
- testpackage
|
||||
- varnamelen
|
||||
- wrapcheck
|
||||
- wsl
|
||||
linters-settings:
|
||||
gci:
|
||||
sections:
|
||||
- standard
|
||||
- default
|
||||
- prefix(github.com/SecurityBrewery/catalyst)
|
||||
ireturn:
|
||||
allow:
|
||||
- error
|
||||
- context.Context
|
||||
- go-driver.Cursor
|
||||
- go-driver.Collection
|
||||
- go-driver.Database
|
||||
- chi.Router
|
||||
issues:
|
||||
exclude-rules:
|
||||
- path: caql
|
||||
text: "var-naming: don't use underscores"
|
||||
- path: database/user.go
|
||||
text: "G404"
|
||||
linters: [ gosec ]
|
||||
- path: caql/function.go
|
||||
text: "G404"
|
||||
linters: [ gosec ]
|
||||
- path: caql
|
||||
linters: [ forcetypeassert ]
|
||||
48
auth.go
48
auth.go
@@ -2,15 +2,16 @@ package catalyst
|
||||
|
||||
import (
|
||||
"context"
|
||||
"crypto/rand"
|
||||
"crypto/sha256"
|
||||
"encoding/base64"
|
||||
"errors"
|
||||
"fmt"
|
||||
"math/rand"
|
||||
"net/http"
|
||||
"strings"
|
||||
|
||||
"github.com/coreos/go-oidc/v3/oidc"
|
||||
"golang.org/x/exp/slices"
|
||||
"golang.org/x/oauth2"
|
||||
|
||||
"github.com/SecurityBrewery/catalyst/database"
|
||||
@@ -43,6 +44,7 @@ func (c *AuthConfig) Verifier(ctx context.Context) (*oidc.IDTokenVerifier, error
|
||||
return nil, err
|
||||
}
|
||||
}
|
||||
|
||||
return c.provider.Verifier(&oidc.Config{SkipClientIDCheck: true}), nil
|
||||
}
|
||||
|
||||
@@ -81,12 +83,14 @@ func bearerAuth(db *database.Database, authHeader string, iss string, config *Au
|
||||
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
if !strings.HasPrefix(authHeader, "Bearer ") {
|
||||
api.JSONErrorStatus(w, http.StatusUnauthorized, errors.New("no bearer token"))
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
claims, apiError := verifyClaims(r, config, authHeader[7:])
|
||||
if apiError != nil {
|
||||
api.JSONErrorStatus(w, apiError.Status, apiError.Internal)
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
@@ -100,6 +104,7 @@ func bearerAuth(db *database.Database, authHeader string, iss string, config *Au
|
||||
r, err := setContextClaims(r, db, claims, config)
|
||||
if err != nil {
|
||||
api.JSONErrorStatus(w, http.StatusInternalServerError, fmt.Errorf("could not load user: %w", err))
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
@@ -116,6 +121,7 @@ func keyAuth(db *database.Database, keyHeader string) func(next http.Handler) ht
|
||||
key, err := db.UserByHash(r.Context(), h)
|
||||
if err != nil {
|
||||
api.JSONErrorStatus(w, http.StatusInternalServerError, fmt.Errorf("could not verify private token: %w", err))
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
@@ -132,16 +138,19 @@ func sessionAuth(db *database.Database, config *AuthConfig) func(next http.Handl
|
||||
claims, noCookie, err := claimsCookie(r)
|
||||
if err != nil {
|
||||
api.JSONError(w, err)
|
||||
|
||||
return
|
||||
}
|
||||
if noCookie {
|
||||
redirectToLogin(w, r, config.OAuth2)
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
r, err = setContextClaims(r, db, claims, config)
|
||||
if err != nil {
|
||||
api.JSONErrorStatus(w, http.StatusInternalServerError, fmt.Errorf("could not load user: %w", err))
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
@@ -150,7 +159,7 @@ func sessionAuth(db *database.Database, config *AuthConfig) func(next http.Handl
|
||||
}
|
||||
}
|
||||
|
||||
func setContextClaims(r *http.Request, db *database.Database, claims map[string]interface{}, config *AuthConfig) (*http.Request, error) {
|
||||
func setContextClaims(r *http.Request, db *database.Database, claims map[string]any, config *AuthConfig) (*http.Request, error) {
|
||||
newUser, newSetting, err := mapUserAndSettings(claims, config)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
@@ -182,7 +191,7 @@ func setContextUser(r *http.Request, user *model.UserResponse, hooks *hooks.Hook
|
||||
return busdb.SetContext(r, user)
|
||||
}
|
||||
|
||||
func mapUserAndSettings(claims map[string]interface{}, config *AuthConfig) (*model.UserForm, *model.UserData, error) {
|
||||
func mapUserAndSettings(claims map[string]any, config *AuthConfig) (*model.UserForm, *model.UserData, error) {
|
||||
// handle Bearer tokens
|
||||
// if typ, ok := claims["typ"]; ok && typ == "Bearer" {
|
||||
// return &model.User{
|
||||
@@ -208,8 +217,8 @@ func mapUserAndSettings(claims map[string]interface{}, config *AuthConfig) (*mod
|
||||
name = ""
|
||||
}
|
||||
|
||||
var roles = role.Strings(config.AuthDefaultRoles)
|
||||
if contains(config.AuthAdminUsers, username) {
|
||||
roles := role.Strings(config.AuthDefaultRoles)
|
||||
if slices.Contains(config.AuthAdminUsers, username) {
|
||||
roles = append(roles, role.Admin)
|
||||
}
|
||||
|
||||
@@ -223,20 +232,12 @@ func mapUserAndSettings(claims map[string]interface{}, config *AuthConfig) (*mod
|
||||
}, nil
|
||||
}
|
||||
|
||||
func contains(l []string, s string) bool {
|
||||
for _, e := range l {
|
||||
if e == s {
|
||||
return true
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
func getString(m map[string]interface{}, key string) (string, error) {
|
||||
func getString(m map[string]any, key string) (string, error) {
|
||||
if v, ok := m[key]; ok {
|
||||
if s, ok := v.(string); ok {
|
||||
return s, nil
|
||||
}
|
||||
|
||||
return "", fmt.Errorf("mapping of %s failed, wrong type (%T)", key, v)
|
||||
}
|
||||
|
||||
@@ -247,12 +248,14 @@ func redirectToLogin(w http.ResponseWriter, r *http.Request, oauth2Config *oauth
|
||||
state, err := state()
|
||||
if err != nil {
|
||||
api.JSONErrorStatus(w, http.StatusInternalServerError, errors.New("generating state failed"))
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
setStateCookie(w, state)
|
||||
|
||||
http.Redirect(w, r, oauth2Config.AuthCodeURL(state), http.StatusFound)
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
@@ -262,11 +265,13 @@ func AuthorizeBlockedUser() func(http.Handler) http.Handler {
|
||||
user, ok := busdb.UserFromContext(r.Context())
|
||||
if !ok {
|
||||
api.JSONErrorStatus(w, http.StatusInternalServerError, errors.New("no user in context"))
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
if user.Blocked {
|
||||
api.JSONErrorStatus(w, http.StatusForbidden, errors.New("user is blocked"))
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
@@ -281,11 +286,13 @@ func AuthorizeRole(roles []string) func(http.Handler) http.Handler {
|
||||
user, ok := busdb.UserFromContext(r.Context())
|
||||
if !ok {
|
||||
api.JSONErrorStatus(w, http.StatusInternalServerError, errors.New("no user in context"))
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
if !role.UserHasRoles(user, role.FromStrings(roles)) {
|
||||
api.JSONErrorStatus(w, http.StatusForbidden, fmt.Errorf("missing role %s has %s", roles, user.Roles))
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
@@ -299,17 +306,20 @@ func callback(config *AuthConfig) http.HandlerFunc {
|
||||
state, err := stateCookie(r)
|
||||
if err != nil || state == "" {
|
||||
api.JSONErrorStatus(w, http.StatusInternalServerError, errors.New("state missing"))
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
if state != r.URL.Query().Get("state") {
|
||||
api.JSONErrorStatus(w, http.StatusInternalServerError, errors.New("state mismatch"))
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
oauth2Token, err := config.OAuth2.Exchange(r.Context(), r.URL.Query().Get("code"))
|
||||
if err != nil {
|
||||
api.JSONErrorStatus(w, http.StatusInternalServerError, fmt.Errorf("oauth2 exchange failed: %w", err))
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
@@ -317,12 +327,14 @@ func callback(config *AuthConfig) http.HandlerFunc {
|
||||
rawIDToken, ok := oauth2Token.Extra("id_token").(string)
|
||||
if !ok {
|
||||
api.JSONErrorStatus(w, http.StatusInternalServerError, errors.New("missing id token"))
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
claims, apiError := verifyClaims(r, config, rawIDToken)
|
||||
if apiError != nil {
|
||||
api.JSONErrorStatus(w, apiError.Status, apiError.Internal)
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
@@ -337,10 +349,11 @@ func state() (string, error) {
|
||||
if _, err := rand.Read(rnd); err != nil {
|
||||
return "", err
|
||||
}
|
||||
|
||||
return base64.URLEncoding.EncodeToString(rnd), nil
|
||||
}
|
||||
|
||||
func verifyClaims(r *http.Request, config *AuthConfig, rawIDToken string) (map[string]interface{}, *api.HTTPError) {
|
||||
func verifyClaims(r *http.Request, config *AuthConfig, rawIDToken string) (map[string]any, *api.HTTPError) {
|
||||
verifier, err := config.Verifier(r.Context())
|
||||
if err != nil {
|
||||
return nil, &api.HTTPError{Status: http.StatusUnauthorized, Internal: fmt.Errorf("could not verify: %w", err)}
|
||||
@@ -350,9 +363,10 @@ func verifyClaims(r *http.Request, config *AuthConfig, rawIDToken string) (map[s
|
||||
return nil, &api.HTTPError{Status: http.StatusInternalServerError, Internal: fmt.Errorf("could not verify bearer token: %w", err)}
|
||||
}
|
||||
|
||||
var claims map[string]interface{}
|
||||
var claims map[string]any
|
||||
if err := authToken.Claims(&claims); err != nil {
|
||||
return nil, &api.HTTPError{Status: http.StatusInternalServerError, Internal: fmt.Errorf("failed to parse claims: %w", err)}
|
||||
}
|
||||
|
||||
return claims, nil
|
||||
}
|
||||
|
||||
@@ -41,7 +41,10 @@ func Backup(catalystStorage *storage.Storage, c *database.Config, writer io.Writ
|
||||
archive := zip.NewWriter(writer)
|
||||
defer archive.Close()
|
||||
|
||||
archive.SetComment(GetVersion())
|
||||
err := archive.SetComment(GetVersion())
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// S3
|
||||
if err := backupS3(catalystStorage, archive); err != nil {
|
||||
@@ -86,6 +89,7 @@ func backupS3(catalystStorage *storage.Storage, archive *zip.Writer) error {
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
@@ -105,6 +109,7 @@ func backupArango(c *database.Config, archive *zip.Writer) error {
|
||||
|
||||
func zipDump(dir string, archive *zip.Writer) error {
|
||||
fsys := os.DirFS(dir)
|
||||
|
||||
return fs.WalkDir(fsys, ".", func(p string, d fs.DirEntry, err error) error {
|
||||
if err != nil {
|
||||
return err
|
||||
@@ -127,6 +132,7 @@ func zipDump(dir string, archive *zip.Writer) error {
|
||||
if _, err := io.Copy(a, f); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
return nil
|
||||
})
|
||||
}
|
||||
@@ -144,5 +150,6 @@ func arangodump(dir string, config *database.Config) error {
|
||||
"--server.database", name,
|
||||
}
|
||||
cmd := exec.Command("arangodump", args...)
|
||||
|
||||
return cmd.Run()
|
||||
}
|
||||
|
||||
@@ -50,7 +50,7 @@ func New(c *Config) (*Bus, error) {
|
||||
return &Bus{config: c, client: client}, err
|
||||
}
|
||||
|
||||
func (b *Bus) jsonPublish(msg interface{}, channel, key string) error {
|
||||
func (b *Bus) jsonPublish(msg any, channel, key string) error {
|
||||
payload, err := json.Marshal(msg)
|
||||
if err != nil {
|
||||
return err
|
||||
@@ -65,5 +65,6 @@ func (b *Bus) safeSubscribe(key, channel string, handler func(c *emitter.Client,
|
||||
log.Printf("Recovered %s in channel %s\n", r, channel)
|
||||
}
|
||||
}()
|
||||
|
||||
return b.client.Subscribe(key, channel, handler)
|
||||
}
|
||||
|
||||
@@ -35,6 +35,7 @@ func (b *Bus) SubscribeDatabaseUpdate(f func(msg *DatabaseUpdateMsg)) error {
|
||||
var msg DatabaseUpdateMsg
|
||||
if err := json.Unmarshal(m.Payload(), &msg); err != nil {
|
||||
log.Println(err)
|
||||
|
||||
return
|
||||
}
|
||||
go f(&msg)
|
||||
|
||||
@@ -18,7 +18,7 @@ type JobMsg struct {
|
||||
Message *model.Message `json:"message"`
|
||||
}
|
||||
|
||||
func (b *Bus) PublishJob(id, automation string, payload interface{}, context *model.Context, origin *model.Origin) error {
|
||||
func (b *Bus) PublishJob(id, automation string, payload any, context *model.Context, origin *model.Origin) error {
|
||||
return b.jsonPublish(&JobMsg{
|
||||
ID: id,
|
||||
Automation: automation,
|
||||
@@ -35,6 +35,7 @@ func (b *Bus) SubscribeJob(f func(msg *JobMsg)) error {
|
||||
var msg JobMsg
|
||||
if err := json.Unmarshal(m.Payload(), &msg); err != nil {
|
||||
log.Println(err)
|
||||
|
||||
return
|
||||
}
|
||||
go f(&msg)
|
||||
|
||||
@@ -29,6 +29,7 @@ func (b *Bus) SubscribeRequest(f func(msg *RequestMsg)) error {
|
||||
msg := &RequestMsg{}
|
||||
if err := json.Unmarshal(m.Payload(), msg); err != nil {
|
||||
log.Println(err)
|
||||
|
||||
return
|
||||
}
|
||||
go f(msg)
|
||||
|
||||
@@ -12,12 +12,12 @@ import (
|
||||
const channelResult = "result"
|
||||
|
||||
type ResultMsg struct {
|
||||
Automation string `json:"automation"`
|
||||
Data map[string]interface{} `json:"data,omitempty"`
|
||||
Target *model.Origin `json:"target"`
|
||||
Automation string `json:"automation"`
|
||||
Data map[string]any `json:"data,omitempty"`
|
||||
Target *model.Origin `json:"target"`
|
||||
}
|
||||
|
||||
func (b *Bus) PublishResult(automation string, data map[string]interface{}, target *model.Origin) error {
|
||||
func (b *Bus) PublishResult(automation string, data map[string]any, target *model.Origin) error {
|
||||
return b.jsonPublish(&ResultMsg{Automation: automation, Data: data, Target: target}, channelResult, b.config.resultBusKey)
|
||||
}
|
||||
|
||||
@@ -26,6 +26,7 @@ func (b *Bus) SubscribeResult(f func(msg *ResultMsg)) error {
|
||||
msg := &ResultMsg{}
|
||||
if err := json.Unmarshal(m.Payload(), msg); err != nil {
|
||||
log.Println(err)
|
||||
|
||||
return
|
||||
}
|
||||
go f(msg)
|
||||
|
||||
@@ -21,7 +21,6 @@ type busService struct {
|
||||
}
|
||||
|
||||
func New(apiURL, apikey, network string, catalystBus *bus.Bus, db *database.Database) error {
|
||||
|
||||
h := &busService{db: db, apiURL: apiURL, apiKey: apikey, network: network, catalystBus: catalystBus}
|
||||
|
||||
if err := catalystBus.SubscribeRequest(h.logRequest); err != nil {
|
||||
@@ -40,6 +39,7 @@ func New(apiURL, apikey, network string, catalystBus *bus.Bus, db *database.Data
|
||||
func busContext() context.Context {
|
||||
// TODO: change roles?
|
||||
bot := &model.UserResponse{ID: "bot", Roles: []string{role.Admin}}
|
||||
|
||||
return busdb.UserContext(context.Background(), bot)
|
||||
}
|
||||
|
||||
|
||||
@@ -59,13 +59,15 @@ func pullImage(ctx context.Context, cli *client.Client, image string) (string, e
|
||||
|
||||
buf := &bytes.Buffer{}
|
||||
_, err = io.Copy(buf, reader)
|
||||
|
||||
return buf.String(), err
|
||||
}
|
||||
|
||||
func copyFile(ctx context.Context, cli *client.Client, path string, contentString string, id string) error {
|
||||
tarBuf := &bytes.Buffer{}
|
||||
tw := tar.NewWriter(tarBuf)
|
||||
if err := tw.WriteHeader(&tar.Header{Name: path, Mode: 0755, Size: int64(len(contentString))}); err != nil {
|
||||
header := &tar.Header{Name: path, Mode: 0o755, Size: int64(len(contentString))}
|
||||
if err := tw.WriteHeader(header); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
@@ -90,7 +92,12 @@ func runDocker(ctx context.Context, jobID, containerID string, db *database.Data
|
||||
return nil, nil, err
|
||||
}
|
||||
|
||||
defer cli.ContainerRemove(ctx, containerID, types.ContainerRemoveOptions{Force: true})
|
||||
defer func(cli *client.Client, ctx context.Context, containerID string, options types.ContainerRemoveOptions) {
|
||||
err := cli.ContainerRemove(ctx, containerID, options)
|
||||
if err != nil {
|
||||
log.Println(err)
|
||||
}
|
||||
}(cli, ctx, containerID, types.ContainerRemoveOptions{Force: true})
|
||||
|
||||
if err := cli.ContainerStart(ctx, containerID, types.ContainerStartOptions{}); err != nil {
|
||||
return nil, nil, err
|
||||
@@ -123,13 +130,16 @@ func streamStdErr(ctx context.Context, cli *client.Client, jobID, containerID st
|
||||
err := scanLines(ctx, jobID, containerLogs, stderrBuf, db)
|
||||
if err != nil {
|
||||
log.Println(err)
|
||||
|
||||
return
|
||||
}
|
||||
if err := containerLogs.Close(); err != nil {
|
||||
log.Println(err)
|
||||
|
||||
return
|
||||
}
|
||||
}()
|
||||
|
||||
return stderrBuf, nil
|
||||
}
|
||||
|
||||
@@ -139,24 +149,28 @@ func scanLines(ctx context.Context, jobID string, input io.ReadCloser, output io
|
||||
_, err := stdcopy.StdCopy(w, w, input)
|
||||
if err != nil {
|
||||
log.Println(err)
|
||||
|
||||
return
|
||||
}
|
||||
if err := w.Close(); err != nil {
|
||||
log.Println(err)
|
||||
|
||||
return
|
||||
}
|
||||
}()
|
||||
s := bufio.NewScanner(r)
|
||||
for s.Scan() {
|
||||
b := s.Bytes()
|
||||
output.Write(b)
|
||||
output.Write([]byte("\n"))
|
||||
_, _ = output.Write(b)
|
||||
_, _ = output.Write([]byte("\n"))
|
||||
|
||||
if err := db.JobLogAppend(ctx, jobID, string(b)+"\n"); err != nil {
|
||||
log.Println(err)
|
||||
|
||||
continue
|
||||
}
|
||||
}
|
||||
|
||||
return s.Err()
|
||||
}
|
||||
|
||||
@@ -172,6 +186,7 @@ func waitForContainer(ctx context.Context, cli *client.Client, containerID strin
|
||||
return fmt.Errorf("container returned status code %d: stderr: %s", exitStatus.StatusCode, stderrBuf.String())
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
|
||||
@@ -19,17 +19,20 @@ func (h *busService) handleJob(automationMsg *bus.JobMsg) {
|
||||
})
|
||||
if err != nil {
|
||||
log.Println(err)
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
automation, err := h.db.AutomationGet(ctx, automationMsg.Automation)
|
||||
if err != nil {
|
||||
log.Println(err)
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
if automation.Script == "" {
|
||||
log.Println("automation is empty")
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
@@ -39,11 +42,17 @@ func (h *busService) handleJob(automationMsg *bus.JobMsg) {
|
||||
automationMsg.Message.Secrets["catalyst_apikey"] = h.apiKey
|
||||
automationMsg.Message.Secrets["catalyst_apiurl"] = h.apiURL
|
||||
|
||||
scriptMessage, _ := json.Marshal(automationMsg.Message)
|
||||
scriptMessage, err := json.Marshal(automationMsg.Message)
|
||||
if err != nil {
|
||||
log.Println(err)
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
containerID, logs, err := createContainer(ctx, automation.Image, automation.Script, string(scriptMessage), h.network)
|
||||
if err != nil {
|
||||
log.Println(err)
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
@@ -55,18 +64,19 @@ func (h *busService) handleJob(automationMsg *bus.JobMsg) {
|
||||
Status: job.Status,
|
||||
}); err != nil {
|
||||
log.Println(err)
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
var result map[string]interface{}
|
||||
var result map[string]any
|
||||
|
||||
stdout, _, err := runDocker(ctx, automationMsg.ID, containerID, h.db)
|
||||
if err != nil {
|
||||
result = map[string]interface{}{"error": fmt.Sprintf("error running script %s %s", err, string(stdout))}
|
||||
result = map[string]any{"error": fmt.Sprintf("error running script %s %s", err, string(stdout))}
|
||||
} else {
|
||||
var data map[string]interface{}
|
||||
var data map[string]any
|
||||
if err := json.Unmarshal(stdout, &data); err != nil {
|
||||
result = map[string]interface{}{"error": string(stdout)}
|
||||
result = map[string]any{"error": string(stdout)}
|
||||
} else {
|
||||
result = data
|
||||
}
|
||||
@@ -78,6 +88,7 @@ func (h *busService) handleJob(automationMsg *bus.JobMsg) {
|
||||
|
||||
if err := h.db.JobComplete(ctx, automationMsg.ID, result); err != nil {
|
||||
log.Println(err)
|
||||
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
@@ -8,7 +8,7 @@ import (
|
||||
"github.com/SecurityBrewery/catalyst/generated/caql/parser"
|
||||
)
|
||||
|
||||
var TooComplexError = errors.New("unsupported features for index queries, use advanced search instead")
|
||||
var ErrTooComplex = errors.New("unsupported features for index queries, use advanced search instead")
|
||||
|
||||
type bleveBuilder struct {
|
||||
*parser.BaseCAQLParserListener
|
||||
@@ -35,8 +35,9 @@ func (s *bleveBuilder) pop() (n string) {
|
||||
return
|
||||
}
|
||||
|
||||
func (s *bleveBuilder) binaryPop() (interface{}, interface{}) {
|
||||
func (s *bleveBuilder) binaryPop() (any, any) {
|
||||
right, left := s.pop(), s.pop()
|
||||
|
||||
return left, right
|
||||
}
|
||||
|
||||
@@ -48,9 +49,7 @@ func (s *bleveBuilder) ExitExpression(ctx *parser.ExpressionContext) {
|
||||
case ctx.Reference() != nil:
|
||||
// pass
|
||||
case ctx.Operator_unary() != nil:
|
||||
s.err = TooComplexError
|
||||
return
|
||||
|
||||
s.err = ErrTooComplex
|
||||
case ctx.T_PLUS() != nil:
|
||||
fallthrough
|
||||
case ctx.T_MINUS() != nil:
|
||||
@@ -60,13 +59,9 @@ func (s *bleveBuilder) ExitExpression(ctx *parser.ExpressionContext) {
|
||||
case ctx.T_DIV() != nil:
|
||||
fallthrough
|
||||
case ctx.T_MOD() != nil:
|
||||
s.err = TooComplexError
|
||||
return
|
||||
|
||||
s.err = ErrTooComplex
|
||||
case ctx.T_RANGE() != nil:
|
||||
s.err = TooComplexError
|
||||
return
|
||||
|
||||
s.err = ErrTooComplex
|
||||
case ctx.T_LT() != nil && ctx.GetEq_op() == nil:
|
||||
left, right := s.binaryPop()
|
||||
s.push(fmt.Sprintf("%s:<%s", left, right))
|
||||
@@ -79,64 +74,46 @@ func (s *bleveBuilder) ExitExpression(ctx *parser.ExpressionContext) {
|
||||
case ctx.T_GE() != nil && ctx.GetEq_op() == nil:
|
||||
left, right := s.binaryPop()
|
||||
s.push(fmt.Sprintf("%s:>=%s", left, right))
|
||||
|
||||
case ctx.T_IN() != nil && ctx.GetEq_op() == nil:
|
||||
s.err = TooComplexError
|
||||
return
|
||||
|
||||
s.err = ErrTooComplex
|
||||
case ctx.T_EQ() != nil && ctx.GetEq_op() == nil:
|
||||
left, right := s.binaryPop()
|
||||
s.push(fmt.Sprintf("%s:%s", left, right))
|
||||
case ctx.T_NE() != nil && ctx.GetEq_op() == nil:
|
||||
left, right := s.binaryPop()
|
||||
s.push(fmt.Sprintf("-%s:%s", left, right))
|
||||
|
||||
case ctx.T_ALL() != nil && ctx.GetEq_op() != nil:
|
||||
fallthrough
|
||||
case ctx.T_ANY() != nil && ctx.GetEq_op() != nil:
|
||||
fallthrough
|
||||
case ctx.T_NONE() != nil && ctx.GetEq_op() != nil:
|
||||
s.err = TooComplexError
|
||||
return
|
||||
|
||||
s.err = ErrTooComplex
|
||||
case ctx.T_ALL() != nil && ctx.T_NOT() != nil && ctx.T_IN() != nil:
|
||||
fallthrough
|
||||
case ctx.T_ANY() != nil && ctx.T_NOT() != nil && ctx.T_IN() != nil:
|
||||
fallthrough
|
||||
case ctx.T_NONE() != nil && ctx.T_NOT() != nil && ctx.T_IN() != nil:
|
||||
s.err = TooComplexError
|
||||
return
|
||||
|
||||
s.err = ErrTooComplex
|
||||
case ctx.T_LIKE() != nil:
|
||||
s.err = errors.New("index queries are like queries by default")
|
||||
return
|
||||
|
||||
case ctx.T_REGEX_MATCH() != nil:
|
||||
left, right := s.binaryPop()
|
||||
if ctx.T_NOT() != nil {
|
||||
s.err = TooComplexError
|
||||
return
|
||||
s.err = ErrTooComplex
|
||||
} else {
|
||||
s.push(fmt.Sprintf("%s:/%s/", left, right))
|
||||
}
|
||||
case ctx.T_REGEX_NON_MATCH() != nil:
|
||||
s.err = errors.New("index query cannot contain regex non matches, use advanced search instead")
|
||||
return
|
||||
|
||||
case ctx.T_AND() != nil:
|
||||
left, right := s.binaryPop()
|
||||
s.push(fmt.Sprintf("%s %s", left, right))
|
||||
case ctx.T_OR() != nil:
|
||||
s.err = errors.New("index query cannot contain OR, use advanced search instead")
|
||||
return
|
||||
|
||||
case ctx.T_QUESTION() != nil && len(ctx.AllExpression()) == 3:
|
||||
s.err = errors.New("index query cannot contain ternary operations, use advanced search instead")
|
||||
return
|
||||
case ctx.T_QUESTION() != nil && len(ctx.AllExpression()) == 2:
|
||||
s.err = errors.New("index query cannot contain ternary operations, use advanced search instead")
|
||||
return
|
||||
|
||||
default:
|
||||
panic("unknown expression")
|
||||
}
|
||||
@@ -152,17 +129,13 @@ func (s *bleveBuilder) ExitReference(ctx *parser.ReferenceContext) {
|
||||
case ctx.T_STRING() != nil:
|
||||
s.push(ctx.T_STRING().GetText())
|
||||
case ctx.Compound_value() != nil:
|
||||
s.err = TooComplexError
|
||||
return
|
||||
s.err = ErrTooComplex
|
||||
case ctx.Function_call() != nil:
|
||||
s.err = TooComplexError
|
||||
return
|
||||
s.err = ErrTooComplex
|
||||
case ctx.T_OPEN() != nil:
|
||||
s.err = TooComplexError
|
||||
return
|
||||
s.err = ErrTooComplex
|
||||
case ctx.T_ARRAY_OPEN() != nil:
|
||||
s.err = TooComplexError
|
||||
return
|
||||
s.err = ErrTooComplex
|
||||
default:
|
||||
panic(fmt.Sprintf("unexpected value: %s", ctx.GetText()))
|
||||
}
|
||||
|
||||
@@ -1,10 +1,14 @@
|
||||
package caql
|
||||
package caql_test
|
||||
|
||||
import (
|
||||
"testing"
|
||||
|
||||
"github.com/SecurityBrewery/catalyst/caql"
|
||||
)
|
||||
|
||||
func TestBleveBuilder(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
saql string
|
||||
@@ -18,15 +22,20 @@ func TestBleveBuilder(t *testing.T) {
|
||||
{name: "Search 4", saql: `title == 'malware' AND 'wannacry'`, wantBleve: `title:"malware" "wannacry"`},
|
||||
}
|
||||
for _, tt := range tests {
|
||||
parser := &Parser{}
|
||||
tt := tt
|
||||
|
||||
parser := &caql.Parser{}
|
||||
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
expr, err := parser.Parse(tt.saql)
|
||||
if (err != nil) != tt.wantParseErr {
|
||||
t.Errorf("Parse() error = %v, wantErr %v", err, tt.wantParseErr)
|
||||
if expr != nil {
|
||||
t.Error(expr.String())
|
||||
}
|
||||
|
||||
return
|
||||
}
|
||||
if err != nil {
|
||||
@@ -37,6 +46,7 @@ func TestBleveBuilder(t *testing.T) {
|
||||
if (err != nil) != tt.wantRebuildErr {
|
||||
t.Error(expr.String())
|
||||
t.Errorf("String() error = %v, wantErr %v", err, tt.wantParseErr)
|
||||
|
||||
return
|
||||
}
|
||||
if err != nil {
|
||||
|
||||
@@ -5,6 +5,8 @@ import (
|
||||
"strconv"
|
||||
"strings"
|
||||
|
||||
"golang.org/x/exp/slices"
|
||||
|
||||
"github.com/SecurityBrewery/catalyst/generated/caql/parser"
|
||||
)
|
||||
|
||||
@@ -40,6 +42,7 @@ func (s *aqlBuilder) pop() (n string) {
|
||||
|
||||
func (s *aqlBuilder) binaryPop() (string, string) {
|
||||
right, left := s.pop(), s.pop()
|
||||
|
||||
return left, right
|
||||
}
|
||||
|
||||
@@ -181,8 +184,10 @@ func (s *aqlBuilder) toBoolString(v string) string {
|
||||
if err != nil {
|
||||
panic("invalid search " + err.Error())
|
||||
}
|
||||
|
||||
return fmt.Sprintf(`d._key IN ["%s"]`, strings.Join(ids, `","`))
|
||||
}
|
||||
|
||||
return v
|
||||
}
|
||||
|
||||
@@ -246,7 +251,7 @@ func (s *aqlBuilder) ExitFunction_call(ctx *parser.Function_callContext) {
|
||||
}
|
||||
parameter := strings.Join(array, ", ")
|
||||
|
||||
if !stringSliceContains(functionNames, strings.ToUpper(ctx.T_STRING().GetText())) {
|
||||
if !slices.Contains(functionNames, strings.ToUpper(ctx.T_STRING().GetText())) {
|
||||
panic("unknown function")
|
||||
}
|
||||
|
||||
|
||||
136
caql/function.go
136
caql/function.go
@@ -16,7 +16,6 @@ import (
|
||||
|
||||
func (s *aqlInterpreter) function(ctx *parser.Function_callContext) {
|
||||
switch strings.ToUpper(ctx.T_STRING().GetText()) {
|
||||
|
||||
default:
|
||||
s.appendErrors(errors.New("unknown function"))
|
||||
|
||||
@@ -26,8 +25,8 @@ func (s *aqlInterpreter) function(ctx *parser.Function_callContext) {
|
||||
if len(ctx.AllExpression()) == 3 {
|
||||
u = s.pop().(bool)
|
||||
}
|
||||
seen := map[interface{}]bool{}
|
||||
values, anyArray := s.pop().([]interface{}), s.pop().([]interface{})
|
||||
seen := map[any]bool{}
|
||||
values, anyArray := s.pop().([]any), s.pop().([]any)
|
||||
|
||||
if u {
|
||||
for _, e := range anyArray {
|
||||
@@ -45,18 +44,18 @@ func (s *aqlInterpreter) function(ctx *parser.Function_callContext) {
|
||||
s.push(anyArray)
|
||||
case "COUNT_DISTINCT", "COUNT_UNIQUE":
|
||||
count := 0
|
||||
seen := map[interface{}]bool{}
|
||||
array := s.pop().([]interface{})
|
||||
seen := map[any]bool{}
|
||||
array := s.pop().([]any)
|
||||
for _, e := range array {
|
||||
_, ok := seen[e]
|
||||
if !ok {
|
||||
seen[e] = true
|
||||
count += 1
|
||||
count++
|
||||
}
|
||||
}
|
||||
s.push(float64(count))
|
||||
case "FIRST":
|
||||
array := s.pop().([]interface{})
|
||||
array := s.pop().([]any)
|
||||
if len(array) == 0 {
|
||||
s.push(nil)
|
||||
} else {
|
||||
@@ -65,16 +64,16 @@ func (s *aqlInterpreter) function(ctx *parser.Function_callContext) {
|
||||
// case "FLATTEN":
|
||||
// case "INTERLEAVE":
|
||||
case "INTERSECTION":
|
||||
iset := New(s.pop().([]interface{})...)
|
||||
iset := NewSet(s.pop().([]any)...)
|
||||
|
||||
for i := 1; i < len(ctx.AllExpression()); i++ {
|
||||
iset = iset.Intersection(New(s.pop().([]interface{})...))
|
||||
iset = iset.Intersection(NewSet(s.pop().([]any)...))
|
||||
}
|
||||
|
||||
s.push(iset.Values())
|
||||
// case "JACCARD":
|
||||
case "LAST":
|
||||
array := s.pop().([]interface{})
|
||||
array := s.pop().([]any)
|
||||
if len(array) == 0 {
|
||||
s.push(nil)
|
||||
} else {
|
||||
@@ -94,9 +93,9 @@ func (s *aqlInterpreter) function(ctx *parser.Function_callContext) {
|
||||
s.push(float64(len(fmt.Sprint(v))))
|
||||
case string:
|
||||
s.push(float64(utf8.RuneCountInString(v)))
|
||||
case []interface{}:
|
||||
case []any:
|
||||
s.push(float64(len(v)))
|
||||
case map[string]interface{}:
|
||||
case map[string]any:
|
||||
s.push(float64(len(v)))
|
||||
default:
|
||||
panic("unknown type")
|
||||
@@ -104,7 +103,7 @@ func (s *aqlInterpreter) function(ctx *parser.Function_callContext) {
|
||||
case "MINUS":
|
||||
var sets []*Set
|
||||
for i := 0; i < len(ctx.AllExpression()); i++ {
|
||||
sets = append(sets, New(s.pop().([]interface{})...))
|
||||
sets = append(sets, NewSet(s.pop().([]any)...))
|
||||
}
|
||||
|
||||
iset := sets[len(sets)-1]
|
||||
@@ -116,7 +115,7 @@ func (s *aqlInterpreter) function(ctx *parser.Function_callContext) {
|
||||
s.push(iset.Values())
|
||||
case "NTH":
|
||||
pos := s.pop().(float64)
|
||||
array := s.pop().([]interface{})
|
||||
array := s.pop().([]any)
|
||||
if int(pos) >= len(array) || pos < 0 {
|
||||
s.push(nil)
|
||||
} else {
|
||||
@@ -124,16 +123,16 @@ func (s *aqlInterpreter) function(ctx *parser.Function_callContext) {
|
||||
}
|
||||
// case "OUTERSECTION":
|
||||
// array := s.pop().([]interface{})
|
||||
// union := New(array...)
|
||||
// intersection := New(s.pop().([]interface{})...)
|
||||
// union := NewSet(array...)
|
||||
// intersection := NewSet(s.pop().([]interface{})...)
|
||||
// for i := 1; i < len(ctx.AllExpression()); i++ {
|
||||
// array = s.pop().([]interface{})
|
||||
// union = union.Union(New(array...))
|
||||
// intersection = intersection.Intersection(New(array...))
|
||||
// union = union.Union(NewSet(array...))
|
||||
// intersection = intersection.Intersection(NewSet(array...))
|
||||
// }
|
||||
// s.push(union.Minus(intersection).Values())
|
||||
case "POP":
|
||||
array := s.pop().([]interface{})
|
||||
array := s.pop().([]any)
|
||||
s.push(array[:len(array)-1])
|
||||
case "POSITION", "CONTAINS_ARRAY":
|
||||
returnIndex := false
|
||||
@@ -141,7 +140,7 @@ func (s *aqlInterpreter) function(ctx *parser.Function_callContext) {
|
||||
returnIndex = s.pop().(bool)
|
||||
}
|
||||
search := s.pop()
|
||||
array := s.pop().([]interface{})
|
||||
array := s.pop().([]any)
|
||||
|
||||
for idx, e := range array {
|
||||
if e == search {
|
||||
@@ -164,7 +163,7 @@ func (s *aqlInterpreter) function(ctx *parser.Function_callContext) {
|
||||
u = s.pop().(bool)
|
||||
}
|
||||
element := s.pop()
|
||||
array := s.pop().([]interface{})
|
||||
array := s.pop().([]any)
|
||||
|
||||
if u && contains(array, element) {
|
||||
s.push(array)
|
||||
@@ -173,13 +172,13 @@ func (s *aqlInterpreter) function(ctx *parser.Function_callContext) {
|
||||
}
|
||||
case "REMOVE_NTH":
|
||||
position := s.pop().(float64)
|
||||
anyArray := s.pop().([]interface{})
|
||||
anyArray := s.pop().([]any)
|
||||
|
||||
if position < 0 {
|
||||
position = float64(len(anyArray) + int(position))
|
||||
}
|
||||
|
||||
result := []interface{}{}
|
||||
result := []any{}
|
||||
for idx, e := range anyArray {
|
||||
if idx != int(position) {
|
||||
result = append(result, e)
|
||||
@@ -193,7 +192,7 @@ func (s *aqlInterpreter) function(ctx *parser.Function_callContext) {
|
||||
}
|
||||
replaceValue := s.pop().(string)
|
||||
position := s.pop().(float64)
|
||||
anyArray := s.pop().([]interface{})
|
||||
anyArray := s.pop().([]any)
|
||||
|
||||
if position < 0 {
|
||||
position = float64(len(anyArray) + int(position))
|
||||
@@ -224,8 +223,8 @@ func (s *aqlInterpreter) function(ctx *parser.Function_callContext) {
|
||||
limit = s.pop().(float64)
|
||||
}
|
||||
value := s.pop()
|
||||
array := s.pop().([]interface{})
|
||||
result := []interface{}{}
|
||||
array := s.pop().([]any)
|
||||
result := []any{}
|
||||
for idx, e := range array {
|
||||
if e != value || float64(idx) > limit {
|
||||
result = append(result, e)
|
||||
@@ -233,9 +232,9 @@ func (s *aqlInterpreter) function(ctx *parser.Function_callContext) {
|
||||
}
|
||||
s.push(result)
|
||||
case "REMOVE_VALUES":
|
||||
values := s.pop().([]interface{})
|
||||
array := s.pop().([]interface{})
|
||||
result := []interface{}{}
|
||||
values := s.pop().([]any)
|
||||
array := s.pop().([]any)
|
||||
result := []any{}
|
||||
for _, e := range array {
|
||||
if !contains(values, e) {
|
||||
result = append(result, e)
|
||||
@@ -243,14 +242,14 @@ func (s *aqlInterpreter) function(ctx *parser.Function_callContext) {
|
||||
}
|
||||
s.push(result)
|
||||
case "REVERSE":
|
||||
array := s.pop().([]interface{})
|
||||
var reverse []interface{}
|
||||
array := s.pop().([]any)
|
||||
var reverse []any
|
||||
for _, e := range array {
|
||||
reverse = append([]interface{}{e}, reverse...)
|
||||
reverse = append([]any{e}, reverse...)
|
||||
}
|
||||
s.push(reverse)
|
||||
case "SHIFT":
|
||||
s.push(s.pop().([]interface{})[1:])
|
||||
s.push(s.pop().([]any)[1:])
|
||||
case "SLICE":
|
||||
length := float64(-1)
|
||||
full := true
|
||||
@@ -259,7 +258,7 @@ func (s *aqlInterpreter) function(ctx *parser.Function_callContext) {
|
||||
full = false
|
||||
}
|
||||
start := int64(s.pop().(float64))
|
||||
array := s.pop().([]interface{})
|
||||
array := s.pop().([]any)
|
||||
|
||||
if start < 0 {
|
||||
start = int64(len(array)) + start
|
||||
@@ -276,43 +275,43 @@ func (s *aqlInterpreter) function(ctx *parser.Function_callContext) {
|
||||
}
|
||||
s.push(array[start:end])
|
||||
case "SORTED":
|
||||
array := s.pop().([]interface{})
|
||||
array := s.pop().([]any)
|
||||
sort.Slice(array, func(i, j int) bool { return lt(array[i], array[j]) })
|
||||
s.push(array)
|
||||
case "SORTED_UNIQUE":
|
||||
array := s.pop().([]interface{})
|
||||
array := s.pop().([]any)
|
||||
sort.Slice(array, func(i, j int) bool { return lt(array[i], array[j]) })
|
||||
s.push(unique(array))
|
||||
case "UNION":
|
||||
array := s.pop().([]interface{})
|
||||
array := s.pop().([]any)
|
||||
|
||||
for i := 1; i < len(ctx.AllExpression()); i++ {
|
||||
array = append(array, s.pop().([]interface{})...)
|
||||
array = append(array, s.pop().([]any)...)
|
||||
}
|
||||
|
||||
sort.Slice(array, func(i, j int) bool { return lt(array[i], array[j]) })
|
||||
s.push(array)
|
||||
case "UNION_DISTINCT":
|
||||
iset := New(s.pop().([]interface{})...)
|
||||
iset := NewSet(s.pop().([]any)...)
|
||||
|
||||
for i := 1; i < len(ctx.AllExpression()); i++ {
|
||||
iset = iset.Union(New(s.pop().([]interface{})...))
|
||||
iset = iset.Union(NewSet(s.pop().([]any)...))
|
||||
}
|
||||
|
||||
s.push(unique(iset.Values()))
|
||||
case "UNIQUE":
|
||||
s.push(unique(s.pop().([]interface{})))
|
||||
s.push(unique(s.pop().([]any)))
|
||||
case "UNSHIFT":
|
||||
u := false
|
||||
if len(ctx.AllExpression()) == 3 {
|
||||
u = s.pop().(bool)
|
||||
}
|
||||
element := s.pop()
|
||||
array := s.pop().([]interface{})
|
||||
array := s.pop().([]any)
|
||||
if u && contains(array, element) {
|
||||
s.push(array)
|
||||
} else {
|
||||
s.push(append([]interface{}{element}, array...))
|
||||
s.push(append([]any{element}, array...))
|
||||
}
|
||||
|
||||
// Bit https://www.arangodb.com/docs/stable/aql/functions-bit.html
|
||||
@@ -367,8 +366,8 @@ func (s *aqlInterpreter) function(ctx *parser.Function_callContext) {
|
||||
if len(ctx.AllExpression()) >= 2 {
|
||||
removeInternal = s.pop().(bool)
|
||||
}
|
||||
var keys []interface{}
|
||||
for k := range s.pop().(map[string]interface{}) {
|
||||
var keys []any
|
||||
for k := range s.pop().(map[string]any) {
|
||||
isInternalKey := strings.HasPrefix(k, "_")
|
||||
if !removeInternal || !isInternalKey {
|
||||
keys = append(keys, k)
|
||||
@@ -379,20 +378,20 @@ func (s *aqlInterpreter) function(ctx *parser.Function_callContext) {
|
||||
// case "COUNT":
|
||||
case "HAS":
|
||||
right, left := s.pop(), s.pop()
|
||||
_, ok := left.(map[string]interface{})[right.(string)]
|
||||
_, ok := left.(map[string]any)[right.(string)]
|
||||
s.push(ok)
|
||||
// case "KEEP":
|
||||
// case "LENGTH":
|
||||
// case "MATCHES":
|
||||
case "MERGE":
|
||||
var docs []map[string]interface{}
|
||||
var docs []map[string]any
|
||||
if len(ctx.AllExpression()) == 1 {
|
||||
for _, doc := range s.pop().([]interface{}) {
|
||||
docs = append([]map[string]interface{}{doc.(map[string]interface{})}, docs...)
|
||||
for _, doc := range s.pop().([]any) {
|
||||
docs = append([]map[string]any{doc.(map[string]any)}, docs...)
|
||||
}
|
||||
} else {
|
||||
for i := 0; i < len(ctx.AllExpression()); i++ {
|
||||
docs = append(docs, s.pop().(map[string]interface{}))
|
||||
docs = append(docs, s.pop().(map[string]any))
|
||||
}
|
||||
}
|
||||
|
||||
@@ -404,9 +403,9 @@ func (s *aqlInterpreter) function(ctx *parser.Function_callContext) {
|
||||
}
|
||||
s.push(doc)
|
||||
case "MERGE_RECURSIVE":
|
||||
var doc map[string]interface{}
|
||||
var doc map[string]any
|
||||
for i := 0; i < len(ctx.AllExpression()); i++ {
|
||||
err := mergo.Merge(&doc, s.pop().(map[string]interface{}))
|
||||
err := mergo.Merge(&doc, s.pop().(map[string]any))
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
@@ -421,8 +420,8 @@ func (s *aqlInterpreter) function(ctx *parser.Function_callContext) {
|
||||
if len(ctx.AllExpression()) == 2 {
|
||||
removeInternal = s.pop().(bool)
|
||||
}
|
||||
var values []interface{}
|
||||
for k, v := range s.pop().(map[string]interface{}) {
|
||||
var values []any
|
||||
for k, v := range s.pop().(map[string]any) {
|
||||
isInternalKey := strings.HasPrefix(k, "_")
|
||||
if !removeInternal || !isInternalKey {
|
||||
values = append(values, v)
|
||||
@@ -458,10 +457,10 @@ func (s *aqlInterpreter) function(ctx *parser.Function_callContext) {
|
||||
case "AVERAGE", "AVG":
|
||||
count := 0
|
||||
sum := float64(0)
|
||||
array := s.pop().([]interface{})
|
||||
array := s.pop().([]any)
|
||||
for _, element := range array {
|
||||
if element != nil {
|
||||
count += 1
|
||||
count++
|
||||
sum += toNumber(element)
|
||||
}
|
||||
}
|
||||
@@ -506,7 +505,7 @@ func (s *aqlInterpreter) function(ctx *parser.Function_callContext) {
|
||||
case "MAX":
|
||||
var set bool
|
||||
var max float64
|
||||
array := s.pop().([]interface{})
|
||||
array := s.pop().([]any)
|
||||
for _, element := range array {
|
||||
if element != nil {
|
||||
if !set || toNumber(element) > max {
|
||||
@@ -521,7 +520,7 @@ func (s *aqlInterpreter) function(ctx *parser.Function_callContext) {
|
||||
s.push(nil)
|
||||
}
|
||||
case "MEDIAN":
|
||||
array := s.pop().([]interface{})
|
||||
array := s.pop().([]any)
|
||||
var numbers []float64
|
||||
for _, element := range array {
|
||||
if f, ok := element.(float64); ok {
|
||||
@@ -544,7 +543,7 @@ func (s *aqlInterpreter) function(ctx *parser.Function_callContext) {
|
||||
case "MIN":
|
||||
var set bool
|
||||
var min float64
|
||||
array := s.pop().([]interface{})
|
||||
array := s.pop().([]any)
|
||||
for _, element := range array {
|
||||
if element != nil {
|
||||
if !set || toNumber(element) < min {
|
||||
@@ -566,7 +565,7 @@ func (s *aqlInterpreter) function(ctx *parser.Function_callContext) {
|
||||
s.push(math.Pow(left.(float64), right.(float64)))
|
||||
case "PRODUCT":
|
||||
product := float64(1)
|
||||
array := s.pop().([]interface{})
|
||||
array := s.pop().([]any)
|
||||
for _, element := range array {
|
||||
if element != nil {
|
||||
product *= toNumber(element)
|
||||
@@ -578,7 +577,7 @@ func (s *aqlInterpreter) function(ctx *parser.Function_callContext) {
|
||||
case "RAND":
|
||||
s.push(rand.Float64())
|
||||
case "RANGE":
|
||||
var array []interface{}
|
||||
var array []any
|
||||
var start, end, step float64
|
||||
if len(ctx.AllExpression()) == 2 {
|
||||
right, left := s.pop(), s.pop()
|
||||
@@ -612,7 +611,7 @@ func (s *aqlInterpreter) function(ctx *parser.Function_callContext) {
|
||||
// case "STDDEV":
|
||||
case "SUM":
|
||||
sum := float64(0)
|
||||
array := s.pop().([]interface{})
|
||||
array := s.pop().([]any)
|
||||
for _, element := range array {
|
||||
sum += toNumber(element)
|
||||
}
|
||||
@@ -691,7 +690,6 @@ func (s *aqlInterpreter) function(ctx *parser.Function_callContext) {
|
||||
// case "IS_IPV4":
|
||||
// case "IS_KEY":
|
||||
// case "TYPENAME":
|
||||
|
||||
}
|
||||
}
|
||||
|
||||
@@ -705,6 +703,7 @@ func unique(array []interface{}) []interface{} {
|
||||
filtered = append(filtered, e)
|
||||
}
|
||||
}
|
||||
|
||||
return filtered
|
||||
}
|
||||
|
||||
@@ -714,15 +713,7 @@ func contains(values []interface{}, e interface{}) bool {
|
||||
return true
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
func stringSliceContains(values []string, e string) bool {
|
||||
for _, v := range values {
|
||||
if e == v {
|
||||
return true
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
@@ -747,4 +738,5 @@ var functionNames = []string{
|
||||
"REGEX_REPLACE", "REVERSE", "RIGHT", "RTRIM", "SHA1", "SHA512", "SOUNDEX", "SPLIT", "STARTS_WITH", "SUBSTITUTE",
|
||||
"SUBSTRING", "TOKENS", "TO_BASE64", "TO_HEX", "TRIM", "UPPER", "UUID", "TO_BOOL", "TO_NUMBER", "TO_STRING",
|
||||
"TO_ARRAY", "TO_LIST", "IS_NULL", "IS_BOOL", "IS_NUMBER", "IS_STRING", "IS_ARRAY", "IS_LIST", "IS_OBJECT",
|
||||
"IS_DOCUMENT", "IS_DATESTRING", "IS_IPV4", "IS_KEY", "TYPENAME"}
|
||||
"IS_DOCUMENT", "IS_DATESTRING", "IS_IPV4", "IS_KEY", "TYPENAME",
|
||||
}
|
||||
|
||||
@@ -1,18 +1,22 @@
|
||||
package caql
|
||||
package caql_test
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"math"
|
||||
"reflect"
|
||||
"testing"
|
||||
|
||||
"github.com/SecurityBrewery/catalyst/caql"
|
||||
)
|
||||
|
||||
func TestFunctions(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
saql string
|
||||
wantRebuild string
|
||||
wantValue interface{}
|
||||
wantValue any
|
||||
wantParseErr bool
|
||||
wantRebuildErr bool
|
||||
wantEvalErr bool
|
||||
@@ -266,13 +270,13 @@ func TestFunctions(t *testing.T) {
|
||||
{name: "RADIANS", saql: `RADIANS(0)`, wantRebuild: `RADIANS(0)`, wantValue: 0},
|
||||
// {name: "RAND", saql: `RAND()`, wantRebuild: `RAND()`, wantValue: 0.3503170117504508},
|
||||
// {name: "RAND", saql: `RAND()`, wantRebuild: `RAND()`, wantValue: 0.6138226173882478},
|
||||
{name: "RANGE", saql: `RANGE(1, 4)`, wantRebuild: `RANGE(1, 4)`, wantValue: []interface{}{float64(1), float64(2), float64(3), float64(4)}},
|
||||
{name: "RANGE", saql: `RANGE(1, 4, 2)`, wantRebuild: `RANGE(1, 4, 2)`, wantValue: []interface{}{float64(1), float64(3)}},
|
||||
{name: "RANGE", saql: `RANGE(1, 4, 3)`, wantRebuild: `RANGE(1, 4, 3)`, wantValue: []interface{}{float64(1), float64(4)}},
|
||||
{name: "RANGE", saql: `RANGE(1.5, 2.5)`, wantRebuild: `RANGE(1.5, 2.5)`, wantValue: []interface{}{float64(1), float64(2)}},
|
||||
{name: "RANGE", saql: `RANGE(1.5, 2.5, 1)`, wantRebuild: `RANGE(1.5, 2.5, 1)`, wantValue: []interface{}{1.5, 2.5}},
|
||||
{name: "RANGE", saql: `RANGE(1.5, 2.5, 0.5)`, wantRebuild: `RANGE(1.5, 2.5, 0.5)`, wantValue: []interface{}{1.5, 2.0, 2.5}},
|
||||
{name: "RANGE", saql: `RANGE(-0.75, 1.1, 0.5)`, wantRebuild: `RANGE(-0.75, 1.1, 0.5)`, wantValue: []interface{}{-0.75, -0.25, 0.25, 0.75}},
|
||||
{name: "RANGE", saql: `RANGE(1, 4)`, wantRebuild: `RANGE(1, 4)`, wantValue: []any{float64(1), float64(2), float64(3), float64(4)}},
|
||||
{name: "RANGE", saql: `RANGE(1, 4, 2)`, wantRebuild: `RANGE(1, 4, 2)`, wantValue: []any{float64(1), float64(3)}},
|
||||
{name: "RANGE", saql: `RANGE(1, 4, 3)`, wantRebuild: `RANGE(1, 4, 3)`, wantValue: []any{float64(1), float64(4)}},
|
||||
{name: "RANGE", saql: `RANGE(1.5, 2.5)`, wantRebuild: `RANGE(1.5, 2.5)`, wantValue: []any{float64(1), float64(2)}},
|
||||
{name: "RANGE", saql: `RANGE(1.5, 2.5, 1)`, wantRebuild: `RANGE(1.5, 2.5, 1)`, wantValue: []any{1.5, 2.5}},
|
||||
{name: "RANGE", saql: `RANGE(1.5, 2.5, 0.5)`, wantRebuild: `RANGE(1.5, 2.5, 0.5)`, wantValue: []any{1.5, 2.0, 2.5}},
|
||||
{name: "RANGE", saql: `RANGE(-0.75, 1.1, 0.5)`, wantRebuild: `RANGE(-0.75, 1.1, 0.5)`, wantValue: []any{-0.75, -0.25, 0.25, 0.75}},
|
||||
{name: "ROUND", saql: `ROUND(2.49)`, wantRebuild: `ROUND(2.49)`, wantValue: 2},
|
||||
{name: "ROUND", saql: `ROUND(2.50)`, wantRebuild: `ROUND(2.50)`, wantValue: 3},
|
||||
{name: "ROUND", saql: `ROUND(-2.50)`, wantRebuild: `ROUND(-2.50)`, wantValue: -2},
|
||||
@@ -299,15 +303,20 @@ func TestFunctions(t *testing.T) {
|
||||
{name: "Function Error 3", saql: `ABS("abs")`, wantRebuild: `ABS("abs")`, wantEvalErr: true},
|
||||
}
|
||||
for _, tt := range tests {
|
||||
parser := &Parser{}
|
||||
tt := tt
|
||||
|
||||
parser := &caql.Parser{}
|
||||
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
expr, err := parser.Parse(tt.saql)
|
||||
if (err != nil) != tt.wantParseErr {
|
||||
t.Errorf("Parse() error = %v, wantErr %v", err, tt.wantParseErr)
|
||||
if expr != nil {
|
||||
t.Error(expr.String())
|
||||
}
|
||||
|
||||
return
|
||||
}
|
||||
if err != nil {
|
||||
@@ -318,6 +327,7 @@ func TestFunctions(t *testing.T) {
|
||||
if (err != nil) != tt.wantRebuildErr {
|
||||
t.Error(expr.String())
|
||||
t.Errorf("String() error = %v, wantErr %v", err, tt.wantParseErr)
|
||||
|
||||
return
|
||||
}
|
||||
if err != nil {
|
||||
@@ -327,18 +337,19 @@ func TestFunctions(t *testing.T) {
|
||||
t.Errorf("String() got = %v, want %v", got, tt.wantRebuild)
|
||||
}
|
||||
|
||||
var myJson map[string]interface{}
|
||||
var myJSON map[string]any
|
||||
if tt.values != "" {
|
||||
err = json.Unmarshal([]byte(tt.values), &myJson)
|
||||
err = json.Unmarshal([]byte(tt.values), &myJSON)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
}
|
||||
|
||||
value, err := expr.Eval(myJson)
|
||||
value, err := expr.Eval(myJSON)
|
||||
if (err != nil) != tt.wantEvalErr {
|
||||
t.Error(expr.String())
|
||||
t.Errorf("Parse() error = %v, wantErr %v", err, tt.wantParseErr)
|
||||
|
||||
return
|
||||
}
|
||||
if err != nil {
|
||||
@@ -367,14 +378,15 @@ func TestFunctions(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
func jsonParse(s string) interface{} {
|
||||
func jsonParse(s string) any {
|
||||
if s == "" {
|
||||
return nil
|
||||
}
|
||||
var j interface{}
|
||||
var j any
|
||||
err := json.Unmarshal([]byte(s), &j)
|
||||
if err != nil {
|
||||
panic(s + err.Error())
|
||||
}
|
||||
|
||||
return j
|
||||
}
|
||||
|
||||
@@ -10,22 +10,23 @@ import (
|
||||
|
||||
type aqlInterpreter struct {
|
||||
*parser.BaseCAQLParserListener
|
||||
values map[string]interface{}
|
||||
stack []interface{}
|
||||
values map[string]any
|
||||
stack []any
|
||||
errs []error
|
||||
}
|
||||
|
||||
// push is a helper function for pushing new node to the listener Stack.
|
||||
func (s *aqlInterpreter) push(i interface{}) {
|
||||
func (s *aqlInterpreter) push(i any) {
|
||||
s.stack = append(s.stack, i)
|
||||
}
|
||||
|
||||
// pop is a helper function for poping a node from the listener Stack.
|
||||
func (s *aqlInterpreter) pop() (n interface{}) {
|
||||
func (s *aqlInterpreter) pop() (n any) {
|
||||
// Check that we have nodes in the stack.
|
||||
size := len(s.stack)
|
||||
if size < 1 {
|
||||
s.appendErrors(ErrStack)
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
@@ -35,8 +36,9 @@ func (s *aqlInterpreter) pop() (n interface{}) {
|
||||
return
|
||||
}
|
||||
|
||||
func (s *aqlInterpreter) binaryPop() (interface{}, interface{}) {
|
||||
func (s *aqlInterpreter) binaryPop() (any, any) {
|
||||
right, left := s.pop(), s.pop()
|
||||
|
||||
return left, right
|
||||
}
|
||||
|
||||
@@ -54,17 +56,14 @@ func (s *aqlInterpreter) ExitExpression(ctx *parser.ExpressionContext) {
|
||||
s.push(plus(s.binaryPop()))
|
||||
case ctx.T_MINUS() != nil:
|
||||
s.push(minus(s.binaryPop()))
|
||||
|
||||
case ctx.T_TIMES() != nil:
|
||||
s.push(times(s.binaryPop()))
|
||||
case ctx.T_DIV() != nil:
|
||||
s.push(div(s.binaryPop()))
|
||||
case ctx.T_MOD() != nil:
|
||||
s.push(mod(s.binaryPop()))
|
||||
|
||||
case ctx.T_RANGE() != nil:
|
||||
s.push(aqlrange(s.binaryPop()))
|
||||
|
||||
case ctx.T_LT() != nil && ctx.GetEq_op() == nil:
|
||||
s.push(lt(s.binaryPop()))
|
||||
case ctx.T_GT() != nil && ctx.GetEq_op() == nil:
|
||||
@@ -73,35 +72,30 @@ func (s *aqlInterpreter) ExitExpression(ctx *parser.ExpressionContext) {
|
||||
s.push(le(s.binaryPop()))
|
||||
case ctx.T_GE() != nil && ctx.GetEq_op() == nil:
|
||||
s.push(ge(s.binaryPop()))
|
||||
|
||||
case ctx.T_IN() != nil && ctx.GetEq_op() == nil:
|
||||
s.push(maybeNot(ctx, in(s.binaryPop())))
|
||||
|
||||
case ctx.T_EQ() != nil && ctx.GetEq_op() == nil:
|
||||
s.push(eq(s.binaryPop()))
|
||||
case ctx.T_NE() != nil && ctx.GetEq_op() == nil:
|
||||
s.push(ne(s.binaryPop()))
|
||||
|
||||
case ctx.T_ALL() != nil && ctx.GetEq_op() != nil:
|
||||
right, left := s.pop(), s.pop()
|
||||
s.push(all(left.([]interface{}), getOp(ctx.GetEq_op().GetTokenType()), right))
|
||||
s.push(all(left.([]any), getOp(ctx.GetEq_op().GetTokenType()), right))
|
||||
case ctx.T_ANY() != nil && ctx.GetEq_op() != nil:
|
||||
right, left := s.pop(), s.pop()
|
||||
s.push(any(left.([]interface{}), getOp(ctx.GetEq_op().GetTokenType()), right))
|
||||
s.push(anyElement(left.([]any), getOp(ctx.GetEq_op().GetTokenType()), right))
|
||||
case ctx.T_NONE() != nil && ctx.GetEq_op() != nil:
|
||||
right, left := s.pop(), s.pop()
|
||||
s.push(none(left.([]interface{}), getOp(ctx.GetEq_op().GetTokenType()), right))
|
||||
|
||||
s.push(none(left.([]any), getOp(ctx.GetEq_op().GetTokenType()), right))
|
||||
case ctx.T_ALL() != nil && ctx.T_NOT() != nil && ctx.T_IN() != nil:
|
||||
right, left := s.pop(), s.pop()
|
||||
s.push(all(left.([]interface{}), in, right))
|
||||
s.push(all(left.([]any), in, right))
|
||||
case ctx.T_ANY() != nil && ctx.T_NOT() != nil && ctx.T_IN() != nil:
|
||||
right, left := s.pop(), s.pop()
|
||||
s.push(any(left.([]interface{}), in, right))
|
||||
s.push(anyElement(left.([]any), in, right))
|
||||
case ctx.T_NONE() != nil && ctx.T_NOT() != nil && ctx.T_IN() != nil:
|
||||
right, left := s.pop(), s.pop()
|
||||
s.push(none(left.([]interface{}), in, right))
|
||||
|
||||
s.push(none(left.([]any), in, right))
|
||||
case ctx.T_LIKE() != nil:
|
||||
m, err := like(s.binaryPop())
|
||||
s.appendErrors(err)
|
||||
@@ -114,21 +108,18 @@ func (s *aqlInterpreter) ExitExpression(ctx *parser.ExpressionContext) {
|
||||
m, err := regexNonMatch(s.binaryPop())
|
||||
s.appendErrors(err)
|
||||
s.push(maybeNot(ctx, m))
|
||||
|
||||
case ctx.T_AND() != nil:
|
||||
s.push(and(s.binaryPop()))
|
||||
case ctx.T_OR() != nil:
|
||||
s.push(or(s.binaryPop()))
|
||||
|
||||
case ctx.T_QUESTION() != nil && len(ctx.AllExpression()) == 3:
|
||||
right, middle, left := s.pop(), s.pop(), s.pop()
|
||||
s.push(ternary(left, middle, right))
|
||||
case ctx.T_QUESTION() != nil && len(ctx.AllExpression()) == 2:
|
||||
right, left := s.pop(), s.pop()
|
||||
s.push(ternary(left, nil, right))
|
||||
|
||||
default:
|
||||
panic("unkown expression")
|
||||
panic("unknown expression")
|
||||
}
|
||||
}
|
||||
|
||||
@@ -159,7 +150,7 @@ func (s *aqlInterpreter) ExitReference(ctx *parser.ReferenceContext) {
|
||||
case ctx.DOT() != nil:
|
||||
reference := s.pop()
|
||||
|
||||
s.push(reference.(map[string]interface{})[ctx.T_STRING().GetText()])
|
||||
s.push(reference.(map[string]any)[ctx.T_STRING().GetText()])
|
||||
case ctx.T_STRING() != nil:
|
||||
s.push(s.getVar(ctx.T_STRING().GetText()))
|
||||
case ctx.Compound_value() != nil:
|
||||
@@ -175,14 +166,15 @@ func (s *aqlInterpreter) ExitReference(ctx *parser.ReferenceContext) {
|
||||
if f, ok := key.(float64); ok {
|
||||
index := int(f)
|
||||
if index < 0 {
|
||||
index = len(reference.([]interface{})) + index
|
||||
index = len(reference.([]any)) + index
|
||||
}
|
||||
|
||||
s.push(reference.([]interface{})[index])
|
||||
s.push(reference.([]any)[index])
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
s.push(reference.(map[string]interface{})[key.(string)])
|
||||
s.push(reference.(map[string]any)[key.(string)])
|
||||
default:
|
||||
panic(fmt.Sprintf("unexpected value: %s", ctx.GetText()))
|
||||
}
|
||||
@@ -239,17 +231,17 @@ func (s *aqlInterpreter) ExitValue_literal(ctx *parser.Value_literalContext) {
|
||||
|
||||
// ExitArray is called when production array is exited.
|
||||
func (s *aqlInterpreter) ExitArray(ctx *parser.ArrayContext) {
|
||||
array := []interface{}{}
|
||||
array := []any{}
|
||||
for range ctx.AllExpression() {
|
||||
// prepend element
|
||||
array = append([]interface{}{s.pop()}, array...)
|
||||
array = append([]any{s.pop()}, array...)
|
||||
}
|
||||
s.push(array)
|
||||
}
|
||||
|
||||
// ExitObject is called when production object is exited.
|
||||
func (s *aqlInterpreter) ExitObject(ctx *parser.ObjectContext) {
|
||||
object := map[string]interface{}{}
|
||||
object := map[string]any{}
|
||||
for range ctx.AllObject_element() {
|
||||
key, value := s.pop(), s.pop()
|
||||
|
||||
@@ -290,7 +282,7 @@ func (s *aqlInterpreter) ExitObject_element_name(ctx *parser.Object_element_name
|
||||
}
|
||||
}
|
||||
|
||||
func (s *aqlInterpreter) getVar(identifier string) interface{} {
|
||||
func (s *aqlInterpreter) getVar(identifier string) any {
|
||||
v, ok := s.values[identifier]
|
||||
if !ok {
|
||||
s.appendErrors(ErrUndefined)
|
||||
@@ -303,10 +295,11 @@ func maybeNot(ctx *parser.ExpressionContext, m bool) bool {
|
||||
if ctx.T_NOT() != nil {
|
||||
return !m
|
||||
}
|
||||
|
||||
return m
|
||||
}
|
||||
|
||||
func getOp(tokenType int) func(left, right interface{}) bool {
|
||||
func getOp(tokenType int) func(left, right any) bool {
|
||||
switch tokenType {
|
||||
case parser.CAQLLexerT_EQ:
|
||||
return eq
|
||||
@@ -323,33 +316,36 @@ func getOp(tokenType int) func(left, right interface{}) bool {
|
||||
case parser.CAQLLexerT_IN:
|
||||
return in
|
||||
default:
|
||||
panic("unkown token type")
|
||||
panic("unknown token type")
|
||||
}
|
||||
}
|
||||
|
||||
func all(slice []interface{}, op func(interface{}, interface{}) bool, expr interface{}) bool {
|
||||
func all(slice []any, op func(any, any) bool, expr any) bool {
|
||||
for _, e := range slice {
|
||||
if !op(e, expr) {
|
||||
return false
|
||||
}
|
||||
}
|
||||
|
||||
return true
|
||||
}
|
||||
|
||||
func any(slice []interface{}, op func(interface{}, interface{}) bool, expr interface{}) bool {
|
||||
func anyElement(slice []any, op func(any, any) bool, expr any) bool {
|
||||
for _, e := range slice {
|
||||
if op(e, expr) {
|
||||
return true
|
||||
}
|
||||
}
|
||||
|
||||
return false
|
||||
}
|
||||
|
||||
func none(slice []interface{}, op func(interface{}, interface{}) bool, expr interface{}) bool {
|
||||
func none(slice []any, op func(any, any) bool, expr any) bool {
|
||||
for _, e := range slice {
|
||||
if op(e, expr) {
|
||||
return false
|
||||
}
|
||||
}
|
||||
|
||||
return true
|
||||
}
|
||||
|
||||
@@ -10,21 +10,23 @@ import (
|
||||
|
||||
// Logical operators https://www.arangodb.com/docs/3.7/aql/operators.html#logical-operators
|
||||
|
||||
func or(left, right interface{}) interface{} {
|
||||
func or(left, right any) any {
|
||||
if toBool(left) {
|
||||
return left
|
||||
}
|
||||
|
||||
return right
|
||||
}
|
||||
|
||||
func and(left, right interface{}) interface{} {
|
||||
func and(left, right any) any {
|
||||
if !toBool(left) {
|
||||
return left
|
||||
}
|
||||
|
||||
return right
|
||||
}
|
||||
|
||||
func toBool(i interface{}) bool {
|
||||
func toBool(i any) bool {
|
||||
switch v := i.(type) {
|
||||
case nil:
|
||||
return false
|
||||
@@ -36,9 +38,9 @@ func toBool(i interface{}) bool {
|
||||
return v != 0
|
||||
case string:
|
||||
return v != ""
|
||||
case []interface{}:
|
||||
case []any:
|
||||
return true
|
||||
case map[string]interface{}:
|
||||
case map[string]any:
|
||||
return true
|
||||
default:
|
||||
panic("bool conversion failed")
|
||||
@@ -47,15 +49,15 @@ func toBool(i interface{}) bool {
|
||||
|
||||
// Arithmetic operators https://www.arangodb.com/docs/3.7/aql/operators.html#arithmetic-operators
|
||||
|
||||
func plus(left, right interface{}) float64 {
|
||||
func plus(left, right any) float64 {
|
||||
return toNumber(left) + toNumber(right)
|
||||
}
|
||||
|
||||
func minus(left, right interface{}) float64 {
|
||||
func minus(left, right any) float64 {
|
||||
return toNumber(left) - toNumber(right)
|
||||
}
|
||||
|
||||
func times(left, right interface{}) float64 {
|
||||
func times(left, right any) float64 {
|
||||
return round(toNumber(left) * toNumber(right))
|
||||
}
|
||||
|
||||
@@ -63,19 +65,20 @@ func round(r float64) float64 {
|
||||
return math.Round(r*100000) / 100000
|
||||
}
|
||||
|
||||
func div(left, right interface{}) float64 {
|
||||
func div(left, right any) float64 {
|
||||
b := toNumber(right)
|
||||
if b == 0 {
|
||||
return 0
|
||||
}
|
||||
|
||||
return round(toNumber(left) / b)
|
||||
}
|
||||
|
||||
func mod(left, right interface{}) float64 {
|
||||
func mod(left, right any) float64 {
|
||||
return math.Mod(toNumber(left), toNumber(right))
|
||||
}
|
||||
|
||||
func toNumber(i interface{}) float64 {
|
||||
func toNumber(i any) float64 {
|
||||
switch v := i.(type) {
|
||||
case nil:
|
||||
return 0
|
||||
@@ -83,6 +86,7 @@ func toNumber(i interface{}) float64 {
|
||||
if v {
|
||||
return 1
|
||||
}
|
||||
|
||||
return 0
|
||||
case float64:
|
||||
switch {
|
||||
@@ -91,22 +95,25 @@ func toNumber(i interface{}) float64 {
|
||||
case math.IsInf(v, 0):
|
||||
return 0
|
||||
}
|
||||
|
||||
return v
|
||||
case string:
|
||||
f, err := strconv.ParseFloat(strings.TrimSpace(v), 64)
|
||||
if err != nil {
|
||||
return 0
|
||||
}
|
||||
|
||||
return f
|
||||
case []interface{}:
|
||||
case []any:
|
||||
if len(v) == 0 {
|
||||
return 0
|
||||
}
|
||||
if len(v) == 1 {
|
||||
return toNumber(v[0])
|
||||
}
|
||||
|
||||
return 0
|
||||
case map[string]interface{}:
|
||||
case map[string]any:
|
||||
return 0
|
||||
default:
|
||||
panic("number conversion error")
|
||||
@@ -116,7 +123,7 @@ func toNumber(i interface{}) float64 {
|
||||
// Logical operators https://www.arangodb.com/docs/3.7/aql/operators.html#logical-operators
|
||||
// Order https://www.arangodb.com/docs/3.7/aql/fundamentals-type-value-order.html
|
||||
|
||||
func eq(left, right interface{}) bool {
|
||||
func eq(left, right any) bool {
|
||||
leftV, rightV := typeValue(left), typeValue(right)
|
||||
if leftV != rightV {
|
||||
return false
|
||||
@@ -126,15 +133,15 @@ func eq(left, right interface{}) bool {
|
||||
return true
|
||||
case bool, float64, string:
|
||||
return left == right
|
||||
case []interface{}:
|
||||
ra := right.([]interface{})
|
||||
case []any:
|
||||
ra := right.([]any)
|
||||
max := len(l)
|
||||
if len(ra) > max {
|
||||
max = len(ra)
|
||||
}
|
||||
for i := 0; i < max; i++ {
|
||||
var li interface{} = nil
|
||||
var rai interface{} = nil
|
||||
var li any
|
||||
var rai any
|
||||
if len(l) > i {
|
||||
li = l[i]
|
||||
}
|
||||
@@ -146,13 +153,14 @@ func eq(left, right interface{}) bool {
|
||||
return false
|
||||
}
|
||||
}
|
||||
|
||||
return true
|
||||
case map[string]interface{}:
|
||||
ro := right.(map[string]interface{})
|
||||
case map[string]any:
|
||||
ro := right.(map[string]any)
|
||||
|
||||
for _, key := range keys(l, ro) {
|
||||
var li interface{} = nil
|
||||
var rai interface{} = nil
|
||||
var li any
|
||||
var rai any
|
||||
if lv, ok := l[key]; ok {
|
||||
li = lv
|
||||
}
|
||||
@@ -164,17 +172,18 @@ func eq(left, right interface{}) bool {
|
||||
return false
|
||||
}
|
||||
}
|
||||
|
||||
return true
|
||||
default:
|
||||
panic("unknown type")
|
||||
}
|
||||
}
|
||||
|
||||
func ne(left, right interface{}) bool {
|
||||
func ne(left, right any) bool {
|
||||
return !eq(left, right)
|
||||
}
|
||||
|
||||
func lt(left, right interface{}) bool {
|
||||
func lt(left, right any) bool {
|
||||
leftV, rightV := typeValue(left), typeValue(right)
|
||||
if leftV != rightV {
|
||||
return leftV < rightV
|
||||
@@ -190,15 +199,15 @@ func lt(left, right interface{}) bool {
|
||||
return l < right.(float64)
|
||||
case string:
|
||||
return l < right.(string)
|
||||
case []interface{}:
|
||||
ra := right.([]interface{})
|
||||
case []any:
|
||||
ra := right.([]any)
|
||||
max := len(l)
|
||||
if len(ra) > max {
|
||||
max = len(ra)
|
||||
}
|
||||
for i := 0; i < max; i++ {
|
||||
var li interface{} = nil
|
||||
var rai interface{} = nil
|
||||
var li any
|
||||
var rai any
|
||||
if len(l) > i {
|
||||
li = l[i]
|
||||
}
|
||||
@@ -210,13 +219,14 @@ func lt(left, right interface{}) bool {
|
||||
return lt(li, rai)
|
||||
}
|
||||
}
|
||||
|
||||
return false
|
||||
case map[string]interface{}:
|
||||
ro := right.(map[string]interface{})
|
||||
case map[string]any:
|
||||
ro := right.(map[string]any)
|
||||
|
||||
for _, key := range keys(l, ro) {
|
||||
var li interface{} = nil
|
||||
var rai interface{} = nil
|
||||
var li any
|
||||
var rai any
|
||||
if lv, ok := l[key]; ok {
|
||||
li = lv
|
||||
}
|
||||
@@ -228,16 +238,17 @@ func lt(left, right interface{}) bool {
|
||||
return lt(li, rai)
|
||||
}
|
||||
}
|
||||
|
||||
return false
|
||||
default:
|
||||
panic("unknown type")
|
||||
}
|
||||
}
|
||||
|
||||
func keys(l map[string]interface{}, ro map[string]interface{}) []string {
|
||||
func keys(l map[string]any, ro map[string]any) []string {
|
||||
var keys []string
|
||||
seen := map[string]bool{}
|
||||
for _, a := range []map[string]interface{}{l, ro} {
|
||||
for _, a := range []map[string]any{l, ro} {
|
||||
for k := range a {
|
||||
if _, ok := seen[k]; !ok {
|
||||
seen[k] = true
|
||||
@@ -246,10 +257,11 @@ func keys(l map[string]interface{}, ro map[string]interface{}) []string {
|
||||
}
|
||||
}
|
||||
sort.Strings(keys)
|
||||
|
||||
return keys
|
||||
}
|
||||
|
||||
func gt(left, right interface{}) bool {
|
||||
func gt(left, right any) bool {
|
||||
leftV, rightV := typeValue(left), typeValue(right)
|
||||
if leftV != rightV {
|
||||
return leftV > rightV
|
||||
@@ -265,15 +277,15 @@ func gt(left, right interface{}) bool {
|
||||
return l > right.(float64)
|
||||
case string:
|
||||
return l > right.(string)
|
||||
case []interface{}:
|
||||
ra := right.([]interface{})
|
||||
case []any:
|
||||
ra := right.([]any)
|
||||
max := len(l)
|
||||
if len(ra) > max {
|
||||
max = len(ra)
|
||||
}
|
||||
for i := 0; i < max; i++ {
|
||||
var li interface{} = nil
|
||||
var rai interface{} = nil
|
||||
var li any
|
||||
var rai any
|
||||
if len(l) > i {
|
||||
li = l[i]
|
||||
}
|
||||
@@ -285,13 +297,14 @@ func gt(left, right interface{}) bool {
|
||||
return gt(li, rai)
|
||||
}
|
||||
}
|
||||
|
||||
return false
|
||||
case map[string]interface{}:
|
||||
ro := right.(map[string]interface{})
|
||||
case map[string]any:
|
||||
ro := right.(map[string]any)
|
||||
|
||||
for _, key := range keys(l, ro) {
|
||||
var li interface{} = nil
|
||||
var rai interface{} = nil
|
||||
var li any
|
||||
var rai any
|
||||
if lv, ok := l[key]; ok {
|
||||
li = lv
|
||||
}
|
||||
@@ -303,13 +316,14 @@ func gt(left, right interface{}) bool {
|
||||
return gt(li, rai)
|
||||
}
|
||||
}
|
||||
|
||||
return false
|
||||
default:
|
||||
panic("unknown type")
|
||||
}
|
||||
}
|
||||
|
||||
func le(left, right interface{}) bool {
|
||||
func le(left, right any) bool {
|
||||
leftV, rightV := typeValue(left), typeValue(right)
|
||||
if leftV != rightV {
|
||||
return leftV <= rightV
|
||||
@@ -325,15 +339,15 @@ func le(left, right interface{}) bool {
|
||||
return l <= right.(float64)
|
||||
case string:
|
||||
return l <= right.(string)
|
||||
case []interface{}:
|
||||
ra := right.([]interface{})
|
||||
case []any:
|
||||
ra := right.([]any)
|
||||
max := len(l)
|
||||
if len(ra) > max {
|
||||
max = len(ra)
|
||||
}
|
||||
for i := 0; i < max; i++ {
|
||||
var li interface{} = nil
|
||||
var rai interface{} = nil
|
||||
var li any
|
||||
var rai any
|
||||
if len(l) > i {
|
||||
li = l[i]
|
||||
}
|
||||
@@ -345,13 +359,14 @@ func le(left, right interface{}) bool {
|
||||
return le(li, rai)
|
||||
}
|
||||
}
|
||||
|
||||
return true
|
||||
case map[string]interface{}:
|
||||
ro := right.(map[string]interface{})
|
||||
case map[string]any:
|
||||
ro := right.(map[string]any)
|
||||
|
||||
for _, key := range keys(l, ro) {
|
||||
var li interface{} = nil
|
||||
var rai interface{} = nil
|
||||
var li any
|
||||
var rai any
|
||||
if lv, ok := l[key]; ok {
|
||||
li = lv
|
||||
}
|
||||
@@ -363,13 +378,14 @@ func le(left, right interface{}) bool {
|
||||
return lt(li, rai)
|
||||
}
|
||||
}
|
||||
|
||||
return true
|
||||
default:
|
||||
panic("unknown type")
|
||||
}
|
||||
}
|
||||
|
||||
func ge(left, right interface{}) bool {
|
||||
func ge(left, right any) bool {
|
||||
leftV, rightV := typeValue(left), typeValue(right)
|
||||
if leftV != rightV {
|
||||
return leftV >= rightV
|
||||
@@ -385,15 +401,15 @@ func ge(left, right interface{}) bool {
|
||||
return l >= right.(float64)
|
||||
case string:
|
||||
return l >= right.(string)
|
||||
case []interface{}:
|
||||
ra := right.([]interface{})
|
||||
case []any:
|
||||
ra := right.([]any)
|
||||
max := len(l)
|
||||
if len(ra) > max {
|
||||
max = len(ra)
|
||||
}
|
||||
for i := 0; i < max; i++ {
|
||||
var li interface{} = nil
|
||||
var rai interface{} = nil
|
||||
var li any
|
||||
var rai any
|
||||
if len(l) > i {
|
||||
li = l[i]
|
||||
}
|
||||
@@ -405,13 +421,14 @@ func ge(left, right interface{}) bool {
|
||||
return ge(li, rai)
|
||||
}
|
||||
}
|
||||
|
||||
return true
|
||||
case map[string]interface{}:
|
||||
ro := right.(map[string]interface{})
|
||||
case map[string]any:
|
||||
ro := right.(map[string]any)
|
||||
|
||||
for _, key := range keys(l, ro) {
|
||||
var li interface{} = nil
|
||||
var rai interface{} = nil
|
||||
var li any
|
||||
var rai any
|
||||
if lv, ok := l[key]; ok {
|
||||
li = lv
|
||||
}
|
||||
@@ -423,14 +440,15 @@ func ge(left, right interface{}) bool {
|
||||
return gt(li, rai)
|
||||
}
|
||||
}
|
||||
|
||||
return true
|
||||
default:
|
||||
panic("unknown type")
|
||||
}
|
||||
}
|
||||
|
||||
func in(left, right interface{}) bool {
|
||||
a, ok := right.([]interface{})
|
||||
func in(left, right any) bool {
|
||||
a, ok := right.([]any)
|
||||
if !ok {
|
||||
return false
|
||||
}
|
||||
@@ -439,23 +457,25 @@ func in(left, right interface{}) bool {
|
||||
return true
|
||||
}
|
||||
}
|
||||
|
||||
return false
|
||||
}
|
||||
|
||||
func like(left, right interface{}) (bool, error) {
|
||||
func like(left, right any) (bool, error) {
|
||||
return match(right.(string), left.(string))
|
||||
}
|
||||
|
||||
func regexMatch(left, right interface{}) (bool, error) {
|
||||
func regexMatch(left, right any) (bool, error) {
|
||||
return regexp.Match(right.(string), []byte(left.(string)))
|
||||
}
|
||||
|
||||
func regexNonMatch(left, right interface{}) (bool, error) {
|
||||
func regexNonMatch(left, right any) (bool, error) {
|
||||
m, err := regexp.Match(right.(string), []byte(left.(string)))
|
||||
|
||||
return !m, err
|
||||
}
|
||||
|
||||
func typeValue(v interface{}) int {
|
||||
func typeValue(v any) int {
|
||||
switch v.(type) {
|
||||
case nil:
|
||||
return 0
|
||||
@@ -465,9 +485,9 @@ func typeValue(v interface{}) int {
|
||||
return 2
|
||||
case string:
|
||||
return 3
|
||||
case []interface{}:
|
||||
case []any:
|
||||
return 4
|
||||
case map[string]interface{}:
|
||||
case map[string]any:
|
||||
return 5
|
||||
default:
|
||||
panic("unknown type")
|
||||
@@ -476,22 +496,25 @@ func typeValue(v interface{}) int {
|
||||
|
||||
// Ternary operator https://www.arangodb.com/docs/3.7/aql/operators.html#ternary-operator
|
||||
|
||||
func ternary(left, middle, right interface{}) interface{} {
|
||||
func ternary(left, middle, right any) any {
|
||||
if toBool(left) {
|
||||
if middle != nil {
|
||||
return middle
|
||||
}
|
||||
|
||||
return left
|
||||
}
|
||||
|
||||
return right
|
||||
}
|
||||
|
||||
// Range operators https://www.arangodb.com/docs/3.7/aql/operators.html#range-operator
|
||||
|
||||
func aqlrange(left, right interface{}) []float64 {
|
||||
func aqlrange(left, right any) []float64 {
|
||||
var v []float64
|
||||
for i := int(left.(float64)); i <= int(right.(float64)); i++ {
|
||||
v = append(v, float64(i))
|
||||
}
|
||||
|
||||
return v
|
||||
}
|
||||
|
||||
@@ -21,7 +21,7 @@ func (p *Parser) Parse(aql string) (t *Tree, err error) {
|
||||
err = fmt.Errorf("%s", r)
|
||||
}
|
||||
}()
|
||||
// Setup the input
|
||||
// Set up the input
|
||||
inputStream := antlr.NewInputStream(aql)
|
||||
|
||||
errorListener := &errorListener{}
|
||||
@@ -52,7 +52,7 @@ type Tree struct {
|
||||
prefix string
|
||||
}
|
||||
|
||||
func (t *Tree) Eval(values map[string]interface{}) (i interface{}, err error) {
|
||||
func (t *Tree) Eval(values map[string]any) (i any, err error) {
|
||||
defer func() {
|
||||
if r := recover(); r != nil {
|
||||
err = fmt.Errorf("%s", r)
|
||||
@@ -65,6 +65,7 @@ func (t *Tree) Eval(values map[string]interface{}) (i interface{}, err error) {
|
||||
if interpreter.errs != nil {
|
||||
return nil, interpreter.errs[0]
|
||||
}
|
||||
|
||||
return interpreter.stack[0], nil
|
||||
}
|
||||
|
||||
@@ -103,7 +104,7 @@ type errorListener struct {
|
||||
errs []error
|
||||
}
|
||||
|
||||
func (el *errorListener) SyntaxError(recognizer antlr.Recognizer, offendingSymbol interface{}, line, column int, msg string, e antlr.RecognitionException) {
|
||||
func (el *errorListener) SyntaxError(recognizer antlr.Recognizer, offendingSymbol any, line, column int, msg string, e antlr.RecognitionException) {
|
||||
el.errs = append(el.errs, fmt.Errorf("line "+strconv.Itoa(line)+":"+strconv.Itoa(column)+" "+msg))
|
||||
}
|
||||
|
||||
|
||||
@@ -1,9 +1,11 @@
|
||||
package caql
|
||||
package caql_test
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"reflect"
|
||||
"testing"
|
||||
|
||||
"github.com/SecurityBrewery/catalyst/caql"
|
||||
)
|
||||
|
||||
type MockSearcher struct{}
|
||||
@@ -13,11 +15,13 @@ func (m MockSearcher) Search(_ string) (ids []string, err error) {
|
||||
}
|
||||
|
||||
func TestParseSAQLEval(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
saql string
|
||||
wantRebuild string
|
||||
wantValue interface{}
|
||||
wantValue any
|
||||
wantParseErr bool
|
||||
wantRebuildErr bool
|
||||
wantEvalErr bool
|
||||
@@ -89,15 +93,15 @@ func TestParseSAQLEval(t *testing.T) {
|
||||
// {name: "String 9", saql: `'this is a longer string.'`, wantRebuild: `"this is a longer string."`, wantValue: "this is a longer string."},
|
||||
// {name: "String 10", saql: `'the path separator on Windows is \\'`, wantRebuild: `"the path separator on Windows is \\"`, wantValue: `the path separator on Windows is \`},
|
||||
|
||||
{name: "Array 1", saql: "[]", wantRebuild: "[]", wantValue: []interface{}{}},
|
||||
{name: "Array 2", saql: `[true]`, wantRebuild: `[true]`, wantValue: []interface{}{true}},
|
||||
{name: "Array 3", saql: `[1, 2, 3]`, wantRebuild: `[1, 2, 3]`, wantValue: []interface{}{float64(1), float64(2), float64(3)}},
|
||||
{name: "Array 1", saql: "[]", wantRebuild: "[]", wantValue: []any{}},
|
||||
{name: "Array 2", saql: `[true]`, wantRebuild: `[true]`, wantValue: []any{true}},
|
||||
{name: "Array 3", saql: `[1, 2, 3]`, wantRebuild: `[1, 2, 3]`, wantValue: []any{float64(1), float64(2), float64(3)}},
|
||||
{
|
||||
name: "Array 4", saql: `[-99, "yikes!", [false, ["no"], []], 1]`, wantRebuild: `[-99, "yikes!", [false, ["no"], []], 1]`,
|
||||
wantValue: []interface{}{-99.0, "yikes!", []interface{}{false, []interface{}{"no"}, []interface{}{}}, float64(1)},
|
||||
wantValue: []any{-99.0, "yikes!", []any{false, []any{"no"}, []any{}}, float64(1)},
|
||||
},
|
||||
{name: "Array 5", saql: `[["fox", "marshal"]]`, wantRebuild: `[["fox", "marshal"]]`, wantValue: []interface{}{[]interface{}{"fox", "marshal"}}},
|
||||
{name: "Array 6", saql: `[1, 2, 3,]`, wantRebuild: `[1, 2, 3]`, wantValue: []interface{}{float64(1), float64(2), float64(3)}},
|
||||
{name: "Array 5", saql: `[["fox", "marshal"]]`, wantRebuild: `[["fox", "marshal"]]`, wantValue: []any{[]any{"fox", "marshal"}}},
|
||||
{name: "Array 6", saql: `[1, 2, 3,]`, wantRebuild: `[1, 2, 3]`, wantValue: []any{float64(1), float64(2), float64(3)}},
|
||||
|
||||
{name: "Array Error 1", saql: "(1,2,3)", wantParseErr: true},
|
||||
{name: "Array Access 1", saql: "u.friends[0]", wantRebuild: "u.friends[0]", wantValue: 7, values: `{"u": {"friends": [7,8,9]}}`},
|
||||
@@ -105,14 +109,14 @@ func TestParseSAQLEval(t *testing.T) {
|
||||
{name: "Array Access 3", saql: "u.friends[-1]", wantRebuild: "u.friends[-1]", wantValue: 9, values: `{"u": {"friends": [7,8,9]}}`},
|
||||
{name: "Array Access 4", saql: "u.friends[-2]", wantRebuild: "u.friends[-2]", wantValue: 8, values: `{"u": {"friends": [7,8,9]}}`},
|
||||
|
||||
{name: "Object 1", saql: "{}", wantRebuild: "{}", wantValue: map[string]interface{}{}},
|
||||
{name: "Object 2", saql: `{a: 1}`, wantRebuild: "{a: 1}", wantValue: map[string]interface{}{"a": float64(1)}},
|
||||
{name: "Object 3", saql: `{'a': 1}`, wantRebuild: `{'a': 1}`, wantValue: map[string]interface{}{"a": float64(1)}},
|
||||
{name: "Object 4", saql: `{"a": 1}`, wantRebuild: `{"a": 1}`, wantValue: map[string]interface{}{"a": float64(1)}},
|
||||
{name: "Object 5", saql: `{'return': 1}`, wantRebuild: `{'return': 1}`, wantValue: map[string]interface{}{"return": float64(1)}},
|
||||
{name: "Object 6", saql: `{"return": 1}`, wantRebuild: `{"return": 1}`, wantValue: map[string]interface{}{"return": float64(1)}},
|
||||
{name: "Object 9", saql: `{a: 1,}`, wantRebuild: "{a: 1}", wantValue: map[string]interface{}{"a": float64(1)}},
|
||||
{name: "Object 10", saql: `{"a": 1,}`, wantRebuild: `{"a": 1}`, wantValue: map[string]interface{}{"a": float64(1)}},
|
||||
{name: "Object 1", saql: "{}", wantRebuild: "{}", wantValue: map[string]any{}},
|
||||
{name: "Object 2", saql: `{a: 1}`, wantRebuild: "{a: 1}", wantValue: map[string]any{"a": float64(1)}},
|
||||
{name: "Object 3", saql: `{'a': 1}`, wantRebuild: `{'a': 1}`, wantValue: map[string]any{"a": float64(1)}},
|
||||
{name: "Object 4", saql: `{"a": 1}`, wantRebuild: `{"a": 1}`, wantValue: map[string]any{"a": float64(1)}},
|
||||
{name: "Object 5", saql: `{'return': 1}`, wantRebuild: `{'return': 1}`, wantValue: map[string]any{"return": float64(1)}},
|
||||
{name: "Object 6", saql: `{"return": 1}`, wantRebuild: `{"return": 1}`, wantValue: map[string]any{"return": float64(1)}},
|
||||
{name: "Object 9", saql: `{a: 1,}`, wantRebuild: "{a: 1}", wantValue: map[string]any{"a": float64(1)}},
|
||||
{name: "Object 10", saql: `{"a": 1,}`, wantRebuild: `{"a": 1}`, wantValue: map[string]any{"a": float64(1)}},
|
||||
// {"Object 8", "{`return`: 1}", `{"return": 1}`, true},
|
||||
// {"Object 7", "{´return´: 1}", `{"return": 1}`, true},
|
||||
{name: "Object Error 1: return is a keyword", saql: `{like: 1}`, wantParseErr: true},
|
||||
@@ -272,7 +276,7 @@ func TestParseSAQLEval(t *testing.T) {
|
||||
{name: "Arithmetic 17", saql: `23 * {}`, wantRebuild: `23 * {}`, wantValue: 0},
|
||||
{name: "Arithmetic 18", saql: `5 * [7]`, wantRebuild: `5 * [7]`, wantValue: 35},
|
||||
{name: "Arithmetic 19", saql: `24 / "12"`, wantRebuild: `24 / "12"`, wantValue: 2},
|
||||
{name: "Arithmetic Error 1: Divison by zero", saql: `1 / 0`, wantRebuild: `1 / 0`, wantValue: 0},
|
||||
{name: "Arithmetic Error 1: Division by zero", saql: `1 / 0`, wantRebuild: `1 / 0`, wantValue: 0},
|
||||
|
||||
// https://www.arangodb.com/docs/3.7/aql/operators.html#ternary-operator
|
||||
{name: "Ternary 1", saql: `u.age > 15 || u.active == true ? u.userId : null`, wantRebuild: `u.age > 15 OR u.active == true ? u.userId : null`, wantValue: 45, values: `{"u": {"active": true, "age": 2, "userId": 45}}`},
|
||||
@@ -287,20 +291,24 @@ func TestParseSAQLEval(t *testing.T) {
|
||||
{name: "Security 2", saql: `doc.value == 1 || true INSERT {foo: "bar"} IN collection //`, wantParseErr: true},
|
||||
|
||||
// https://www.arangodb.com/docs/3.7/aql/operators.html#operator-precedence
|
||||
{name: "Precendence", saql: `2 > 15 && "a" != ""`, wantRebuild: `2 > 15 AND "a" != ""`, wantValue: false},
|
||||
{name: "Precedence", saql: `2 > 15 && "a" != ""`, wantRebuild: `2 > 15 AND "a" != ""`, wantValue: false},
|
||||
}
|
||||
for _, tt := range tests {
|
||||
parser := &Parser{
|
||||
tt := tt
|
||||
parser := &caql.Parser{
|
||||
Searcher: &MockSearcher{},
|
||||
}
|
||||
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
expr, err := parser.Parse(tt.saql)
|
||||
if (err != nil) != tt.wantParseErr {
|
||||
t.Errorf("Parse() error = %v, wantErr %v", err, tt.wantParseErr)
|
||||
if expr != nil {
|
||||
t.Error(expr.String())
|
||||
}
|
||||
|
||||
return
|
||||
}
|
||||
if err != nil {
|
||||
@@ -311,6 +319,7 @@ func TestParseSAQLEval(t *testing.T) {
|
||||
if (err != nil) != tt.wantRebuildErr {
|
||||
t.Error(expr.String())
|
||||
t.Errorf("String() error = %v, wantErr %v", err, tt.wantParseErr)
|
||||
|
||||
return
|
||||
}
|
||||
if err != nil {
|
||||
@@ -320,18 +329,19 @@ func TestParseSAQLEval(t *testing.T) {
|
||||
t.Errorf("String() got = %v, want %v", got, tt.wantRebuild)
|
||||
}
|
||||
|
||||
var myJson map[string]interface{}
|
||||
var myJSON map[string]any
|
||||
if tt.values != "" {
|
||||
err = json.Unmarshal([]byte(tt.values), &myJson)
|
||||
err = json.Unmarshal([]byte(tt.values), &myJSON)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
}
|
||||
|
||||
value, err := expr.Eval(myJson)
|
||||
value, err := expr.Eval(myJSON)
|
||||
if (err != nil) != tt.wantEvalErr {
|
||||
t.Error(expr.String())
|
||||
t.Errorf("Parse() error = %v, wantErr %v", err, tt.wantParseErr)
|
||||
|
||||
return
|
||||
}
|
||||
if err != nil {
|
||||
|
||||
57
caql/set.go
57
caql/set.go
@@ -22,19 +22,18 @@
|
||||
|
||||
package caql
|
||||
|
||||
import "sort"
|
||||
|
||||
type (
|
||||
Set struct {
|
||||
hash map[interface{}]nothing
|
||||
}
|
||||
|
||||
nothing struct{}
|
||||
import (
|
||||
"sort"
|
||||
)
|
||||
|
||||
// Create a new set
|
||||
func New(initial ...interface{}) *Set {
|
||||
s := &Set{make(map[interface{}]nothing)}
|
||||
type Set struct {
|
||||
hash map[any]nothing
|
||||
}
|
||||
|
||||
type nothing struct{}
|
||||
|
||||
func NewSet(initial ...any) *Set {
|
||||
s := &Set{make(map[any]nothing)}
|
||||
|
||||
for _, v := range initial {
|
||||
s.Insert(v)
|
||||
@@ -43,9 +42,8 @@ func New(initial ...interface{}) *Set {
|
||||
return s
|
||||
}
|
||||
|
||||
// Find the difference between two sets
|
||||
func (s *Set) Difference(set *Set) *Set {
|
||||
n := make(map[interface{}]nothing)
|
||||
n := make(map[any]nothing)
|
||||
|
||||
for k := range s.hash {
|
||||
if _, exists := set.hash[k]; !exists {
|
||||
@@ -56,27 +54,18 @@ func (s *Set) Difference(set *Set) *Set {
|
||||
return &Set{n}
|
||||
}
|
||||
|
||||
// Call f for each item in the set
|
||||
func (s *Set) Do(f func(interface{})) {
|
||||
for k := range s.hash {
|
||||
f(k)
|
||||
}
|
||||
}
|
||||
|
||||
// Test to see whether or not the element is in the set
|
||||
func (s *Set) Has(element interface{}) bool {
|
||||
func (s *Set) Has(element any) bool {
|
||||
_, exists := s.hash[element]
|
||||
|
||||
return exists
|
||||
}
|
||||
|
||||
// Add an element to the set
|
||||
func (s *Set) Insert(element interface{}) {
|
||||
func (s *Set) Insert(element any) {
|
||||
s.hash[element] = nothing{}
|
||||
}
|
||||
|
||||
// Find the intersection of two sets
|
||||
func (s *Set) Intersection(set *Set) *Set {
|
||||
n := make(map[interface{}]nothing)
|
||||
n := make(map[any]nothing)
|
||||
|
||||
for k := range s.hash {
|
||||
if _, exists := set.hash[k]; exists {
|
||||
@@ -87,23 +76,20 @@ func (s *Set) Intersection(set *Set) *Set {
|
||||
return &Set{n}
|
||||
}
|
||||
|
||||
// Return the number of items in the set
|
||||
func (s *Set) Len() int {
|
||||
return len(s.hash)
|
||||
}
|
||||
|
||||
// Test whether or not this set is a proper subset of "set"
|
||||
func (s *Set) ProperSubsetOf(set *Set) bool {
|
||||
return s.SubsetOf(set) && s.Len() < set.Len()
|
||||
}
|
||||
|
||||
// Remove an element from the set
|
||||
func (s *Set) Remove(element interface{}) {
|
||||
func (s *Set) Remove(element any) {
|
||||
delete(s.hash, element)
|
||||
}
|
||||
|
||||
func (s *Set) Minus(set *Set) *Set {
|
||||
n := make(map[interface{}]nothing)
|
||||
n := make(map[any]nothing)
|
||||
for k := range s.hash {
|
||||
n[k] = nothing{}
|
||||
}
|
||||
@@ -115,7 +101,6 @@ func (s *Set) Minus(set *Set) *Set {
|
||||
return &Set{n}
|
||||
}
|
||||
|
||||
// Test whether or not this set is a subset of "set"
|
||||
func (s *Set) SubsetOf(set *Set) bool {
|
||||
if s.Len() > set.Len() {
|
||||
return false
|
||||
@@ -125,12 +110,12 @@ func (s *Set) SubsetOf(set *Set) bool {
|
||||
return false
|
||||
}
|
||||
}
|
||||
|
||||
return true
|
||||
}
|
||||
|
||||
// Find the union of two sets
|
||||
func (s *Set) Union(set *Set) *Set {
|
||||
n := make(map[interface{}]nothing)
|
||||
n := make(map[any]nothing)
|
||||
|
||||
for k := range s.hash {
|
||||
n[k] = nothing{}
|
||||
@@ -142,8 +127,8 @@ func (s *Set) Union(set *Set) *Set {
|
||||
return &Set{n}
|
||||
}
|
||||
|
||||
func (s *Set) Values() []interface{} {
|
||||
values := []interface{}{}
|
||||
func (s *Set) Values() []any {
|
||||
values := []any{}
|
||||
|
||||
for k := range s.hash {
|
||||
values = append(values, k)
|
||||
|
||||
@@ -27,7 +27,9 @@ import (
|
||||
)
|
||||
|
||||
func Test(t *testing.T) {
|
||||
s := New()
|
||||
t.Parallel()
|
||||
|
||||
s := NewSet()
|
||||
|
||||
s.Insert(5)
|
||||
|
||||
@@ -50,8 +52,8 @@ func Test(t *testing.T) {
|
||||
}
|
||||
|
||||
// Difference
|
||||
s1 := New(1, 2, 3, 4, 5, 6)
|
||||
s2 := New(4, 5, 6)
|
||||
s1 := NewSet(1, 2, 3, 4, 5, 6)
|
||||
s2 := NewSet(4, 5, 6)
|
||||
s3 := s1.Difference(s2)
|
||||
|
||||
if s3.Len() != 3 {
|
||||
@@ -73,7 +75,7 @@ func Test(t *testing.T) {
|
||||
}
|
||||
|
||||
// Union
|
||||
s4 := New(7, 8, 9)
|
||||
s4 := NewSet(7, 8, 9)
|
||||
s3 = s2.Union(s4)
|
||||
|
||||
if s3.Len() != 6 {
|
||||
@@ -92,5 +94,4 @@ func Test(t *testing.T) {
|
||||
if s1.ProperSubsetOf(s1) {
|
||||
t.Errorf("set should not be a subset of itself")
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
@@ -39,8 +39,10 @@ func unquote(s string) (string, error) {
|
||||
buf = append(buf, s[i])
|
||||
}
|
||||
}
|
||||
|
||||
return string(buf), nil
|
||||
}
|
||||
|
||||
return s, nil
|
||||
}
|
||||
if quote != '"' && quote != '\'' {
|
||||
@@ -75,5 +77,6 @@ func unquote(s string) (string, error) {
|
||||
buf = append(buf, runeTmp[:n]...)
|
||||
}
|
||||
}
|
||||
|
||||
return string(buf), nil
|
||||
}
|
||||
|
||||
@@ -8,26 +8,25 @@
|
||||
package caql
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"strconv"
|
||||
"testing"
|
||||
)
|
||||
|
||||
type quoteTest struct {
|
||||
in string
|
||||
out string
|
||||
ascii string
|
||||
graphic string
|
||||
in string
|
||||
out string
|
||||
}
|
||||
|
||||
var quotetests = []quoteTest{
|
||||
{in: "\a\b\f\r\n\t\v", out: `"\a\b\f\r\n\t\v"`, ascii: `"\a\b\f\r\n\t\v"`, graphic: `"\a\b\f\r\n\t\v"`},
|
||||
{"\\", `"\\"`, `"\\"`, `"\\"`},
|
||||
{"abc\xffdef", `"abc\xffdef"`, `"abc\xffdef"`, `"abc\xffdef"`},
|
||||
{"\u263a", `"☺"`, `"\u263a"`, `"☺"`},
|
||||
{"\U0010ffff", `"\U0010ffff"`, `"\U0010ffff"`, `"\U0010ffff"`},
|
||||
{"\x04", `"\x04"`, `"\x04"`, `"\x04"`},
|
||||
{in: "\a\b\f\r\n\t\v", out: `"\a\b\f\r\n\t\v"`},
|
||||
{"\\", `"\\"`},
|
||||
{"abc\xffdef", `"abc\xffdef"`},
|
||||
{"\u263a", `"☺"`},
|
||||
{"\U0010ffff", `"\U0010ffff"`},
|
||||
{"\x04", `"\x04"`},
|
||||
// Some non-printable but graphic runes. Final column is double-quoted.
|
||||
{"!\u00a0!\u2000!\u3000!", `"!\u00a0!\u2000!\u3000!"`, `"!\u00a0!\u2000!\u3000!"`, "\"!\u00a0!\u2000!\u3000!\""},
|
||||
{"!\u00a0!\u2000!\u3000!", `"!\u00a0!\u2000!\u3000!"`},
|
||||
}
|
||||
|
||||
type unQuoteTest struct {
|
||||
@@ -104,6 +103,8 @@ var misquoted = []string{
|
||||
}
|
||||
|
||||
func TestUnquote(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
for _, tt := range unquotetests {
|
||||
if out, err := unquote(tt.in); err != nil || out != tt.out {
|
||||
t.Errorf("unquote(%#q) = %q, %v want %q, nil", tt.in, out, err, tt.out)
|
||||
@@ -118,7 +119,7 @@ func TestUnquote(t *testing.T) {
|
||||
}
|
||||
|
||||
for _, s := range misquoted {
|
||||
if out, err := unquote(s); out != "" || err != strconv.ErrSyntax {
|
||||
if out, err := unquote(s); out != "" || !errors.Is(err, strconv.ErrSyntax) {
|
||||
t.Errorf("unquote(%#q) = %q, %v want %q, %v", s, out, err, "", strconv.ErrSyntax)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -48,6 +48,7 @@ Pattern:
|
||||
// using the star
|
||||
if ok && (len(t) == 0 || len(pattern) > 0) {
|
||||
name = t
|
||||
|
||||
continue
|
||||
}
|
||||
if err != nil {
|
||||
@@ -64,6 +65,7 @@ Pattern:
|
||||
continue
|
||||
}
|
||||
name = t
|
||||
|
||||
continue Pattern
|
||||
}
|
||||
if err != nil {
|
||||
@@ -79,8 +81,10 @@ Pattern:
|
||||
return false, err
|
||||
}
|
||||
}
|
||||
|
||||
return false, nil
|
||||
}
|
||||
|
||||
return len(name) == 0, nil
|
||||
}
|
||||
|
||||
@@ -104,6 +108,7 @@ Scan:
|
||||
break Scan
|
||||
}
|
||||
}
|
||||
|
||||
return star, pattern[0:i], pattern[i:]
|
||||
}
|
||||
|
||||
@@ -120,7 +125,6 @@ func matchChunk(chunk, s string) (rest string, ok bool, err error) {
|
||||
failed = true
|
||||
}
|
||||
switch chunk[0] {
|
||||
|
||||
case '_':
|
||||
if !failed {
|
||||
if s[0] == '/' {
|
||||
@@ -130,14 +134,13 @@ func matchChunk(chunk, s string) (rest string, ok bool, err error) {
|
||||
s = s[n:]
|
||||
}
|
||||
chunk = chunk[1:]
|
||||
|
||||
case '\\':
|
||||
chunk = chunk[1:]
|
||||
if len(chunk) == 0 {
|
||||
return "", false, ErrBadPattern
|
||||
}
|
||||
fallthrough
|
||||
|
||||
fallthrough
|
||||
default:
|
||||
if !failed {
|
||||
if chunk[0] != s[0] {
|
||||
@@ -151,5 +154,6 @@ func matchChunk(chunk, s string) (rest string, ok bool, err error) {
|
||||
if failed {
|
||||
return "", false, nil
|
||||
}
|
||||
|
||||
return s, true, nil
|
||||
}
|
||||
|
||||
@@ -7,7 +7,10 @@
|
||||
|
||||
package caql
|
||||
|
||||
import "testing"
|
||||
import (
|
||||
"errors"
|
||||
"testing"
|
||||
)
|
||||
|
||||
type MatchTest struct {
|
||||
pattern, s string
|
||||
@@ -41,9 +44,11 @@ var matchTests = []MatchTest{
|
||||
}
|
||||
|
||||
func TestMatch(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
for _, tt := range matchTests {
|
||||
ok, err := match(tt.pattern, tt.s)
|
||||
if ok != tt.match || err != tt.err {
|
||||
if ok != tt.match || !errors.Is(err, tt.err) {
|
||||
t.Errorf("match(%#q, %#q) = %v, %v want %v, %v", tt.pattern, tt.s, ok, err, tt.match, tt.err)
|
||||
}
|
||||
}
|
||||
|
||||
16
cmd/cmd.go
16
cmd/cmd.go
@@ -4,6 +4,7 @@ import (
|
||||
"github.com/alecthomas/kong"
|
||||
kongyaml "github.com/alecthomas/kong-yaml"
|
||||
"github.com/coreos/go-oidc/v3/oidc"
|
||||
"golang.org/x/exp/slices"
|
||||
"golang.org/x/oauth2"
|
||||
|
||||
"github.com/SecurityBrewery/catalyst"
|
||||
@@ -62,7 +63,7 @@ func MapConfig(cli CLI) (*catalyst.Config, error) {
|
||||
roles = append(roles, role.Explodes(cli.AuthDefaultRoles)...)
|
||||
roles = role.Explodes(role.Strings(roles))
|
||||
|
||||
scopes := unique(append([]string{oidc.ScopeOpenID, "profile", "email"}, cli.OIDCScopes...))
|
||||
scopes := slices.Compact(append([]string{oidc.ScopeOpenID, "profile", "email"}, cli.OIDCScopes...))
|
||||
config := &catalyst.Config{
|
||||
IndexPath: cli.IndexPath,
|
||||
Network: cli.Network,
|
||||
@@ -83,17 +84,6 @@ func MapConfig(cli CLI) (*catalyst.Config, error) {
|
||||
Bus: &bus.Config{Host: cli.EmitterIOHost, Key: cli.EmitterIORKey, APIUrl: cli.CatalystAddress + "/api"},
|
||||
InitialAPIKey: cli.InitialAPIKey,
|
||||
}
|
||||
|
||||
return config, nil
|
||||
}
|
||||
|
||||
func unique(l []string) []string {
|
||||
keys := make(map[string]bool)
|
||||
var list []string
|
||||
for _, entry := range l {
|
||||
if _, value := keys[entry]; !value {
|
||||
keys[entry] = true
|
||||
list = append(list, entry)
|
||||
}
|
||||
}
|
||||
return list
|
||||
}
|
||||
|
||||
16
cookie.go
16
cookie.go
@@ -5,6 +5,7 @@ import (
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"fmt"
|
||||
"log"
|
||||
"net/http"
|
||||
)
|
||||
|
||||
@@ -22,15 +23,21 @@ func stateCookie(r *http.Request) (string, error) {
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
|
||||
return stateCookie.Value, nil
|
||||
}
|
||||
|
||||
func setClaimsCookie(w http.ResponseWriter, claims map[string]interface{}) {
|
||||
b, _ := json.Marshal(claims)
|
||||
func setClaimsCookie(w http.ResponseWriter, claims map[string]any) {
|
||||
b, err := json.Marshal(claims)
|
||||
if err != nil {
|
||||
log.Println(err)
|
||||
|
||||
return
|
||||
}
|
||||
http.SetCookie(w, &http.Cookie{Name: userSessionCookie, Value: base64.StdEncoding.EncodeToString(b)})
|
||||
}
|
||||
|
||||
func claimsCookie(r *http.Request) (map[string]interface{}, bool, error) {
|
||||
func claimsCookie(r *http.Request) (map[string]any, bool, error) {
|
||||
userCookie, err := r.Cookie(userSessionCookie)
|
||||
if err != nil {
|
||||
return nil, true, nil
|
||||
@@ -41,9 +48,10 @@ func claimsCookie(r *http.Request) (map[string]interface{}, bool, error) {
|
||||
return nil, false, fmt.Errorf("could not decode cookie: %w", err)
|
||||
}
|
||||
|
||||
var claims map[string]interface{}
|
||||
var claims map[string]any
|
||||
if err := json.Unmarshal(b, &claims); err != nil {
|
||||
return nil, false, errors.New("claims not in session")
|
||||
}
|
||||
|
||||
return claims, false, err
|
||||
}
|
||||
|
||||
20
dag/dag.go
20
dag/dag.go
@@ -25,6 +25,9 @@ package dag
|
||||
import (
|
||||
"errors"
|
||||
"sort"
|
||||
|
||||
"golang.org/x/exp/maps"
|
||||
"golang.org/x/exp/slices"
|
||||
)
|
||||
|
||||
type Graph struct {
|
||||
@@ -52,6 +55,7 @@ func (g *Graph) AddNode(name string) error {
|
||||
}
|
||||
g.outputs[name] = make(map[string]struct{})
|
||||
g.inputs[name] = 0
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
@@ -61,6 +65,7 @@ func (g *Graph) AddNodes(names ...string) error {
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
@@ -101,7 +106,9 @@ func (g *Graph) Toposort() ([]string, error) {
|
||||
L = append(L, n)
|
||||
|
||||
ms := make([]string, len(outputs[n]))
|
||||
for _, k := range keys(outputs[n]) {
|
||||
keys := maps.Keys(outputs[n])
|
||||
slices.Sort(keys)
|
||||
for _, k := range keys {
|
||||
m := k
|
||||
// i := outputs[n][m]
|
||||
// ms[i-1] = m
|
||||
@@ -130,15 +137,6 @@ func (g *Graph) Toposort() ([]string, error) {
|
||||
return L, nil
|
||||
}
|
||||
|
||||
func keys(m map[string]struct{}) []string {
|
||||
var keys []string
|
||||
for k := range m {
|
||||
keys = append(keys, k)
|
||||
}
|
||||
sort.Strings(keys)
|
||||
return keys
|
||||
}
|
||||
|
||||
func (g *Graph) GetParents(id string) []string {
|
||||
var parents []string
|
||||
for node, targets := range g.outputs {
|
||||
@@ -147,6 +145,7 @@ func (g *Graph) GetParents(id string) []string {
|
||||
}
|
||||
}
|
||||
sort.Strings(parents)
|
||||
|
||||
return parents
|
||||
}
|
||||
|
||||
@@ -160,5 +159,6 @@ func (g *Graph) GetRoot() (string, error) {
|
||||
if len(roots) != 1 {
|
||||
return "", errors.New("more than one root")
|
||||
}
|
||||
|
||||
return roots[0], nil
|
||||
}
|
||||
|
||||
@@ -20,23 +20,17 @@
|
||||
// IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN
|
||||
// CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
|
||||
|
||||
package dag
|
||||
package dag_test
|
||||
|
||||
import (
|
||||
"reflect"
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
)
|
||||
"golang.org/x/exp/slices"
|
||||
|
||||
func index(s []string, v string) int {
|
||||
for i, s := range s {
|
||||
if s == v {
|
||||
return i
|
||||
}
|
||||
}
|
||||
return -1
|
||||
}
|
||||
"github.com/SecurityBrewery/catalyst/dag"
|
||||
)
|
||||
|
||||
type Edge struct {
|
||||
From string
|
||||
@@ -44,13 +38,17 @@ type Edge struct {
|
||||
}
|
||||
|
||||
func TestDuplicatedNode(t *testing.T) {
|
||||
graph := NewGraph()
|
||||
t.Parallel()
|
||||
|
||||
graph := dag.NewGraph()
|
||||
assert.NoError(t, graph.AddNode("a"))
|
||||
assert.Error(t, graph.AddNode("a"))
|
||||
}
|
||||
|
||||
func TestWikipedia(t *testing.T) {
|
||||
graph := NewGraph()
|
||||
t.Parallel()
|
||||
|
||||
graph := dag.NewGraph()
|
||||
assert.NoError(t, graph.AddNodes("2", "3", "5", "7", "8", "9", "10", "11"))
|
||||
|
||||
edges := []Edge{
|
||||
@@ -79,27 +77,30 @@ func TestWikipedia(t *testing.T) {
|
||||
}
|
||||
|
||||
for _, e := range edges {
|
||||
if i, j := index(result, e.From), index(result, e.To); i > j {
|
||||
if i, j := slices.Index(result, e.From), slices.Index(result, e.To); i > j {
|
||||
t.Errorf("dependency failed: not satisfy %v(%v) > %v(%v)", e.From, i, e.To, j)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestCycle(t *testing.T) {
|
||||
graph := NewGraph()
|
||||
t.Parallel()
|
||||
|
||||
graph := dag.NewGraph()
|
||||
assert.NoError(t, graph.AddNodes("1", "2", "3"))
|
||||
|
||||
assert.NoError(t, graph.AddEdge("1", "2"))
|
||||
assert.NoError(t, graph.AddEdge("2", "3"))
|
||||
assert.NoError(t, graph.AddEdge("3", "1"))
|
||||
|
||||
_, err := graph.Toposort()
|
||||
if err == nil {
|
||||
if _, err := graph.Toposort(); err == nil {
|
||||
t.Errorf("closed path not detected in closed pathed graph")
|
||||
}
|
||||
}
|
||||
|
||||
func TestGraph_GetParents(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
type fields struct {
|
||||
nodes []string
|
||||
edges map[string]string
|
||||
@@ -117,8 +118,11 @@ func TestGraph_GetParents(t *testing.T) {
|
||||
{"parents 3", fields{nodes: []string{"1", "2", "3"}, edges: map[string]string{"1": "3", "2": "3"}}, args{id: "3"}, []string{"1", "2"}},
|
||||
}
|
||||
for _, tt := range tests {
|
||||
tt := tt
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
g := NewGraph()
|
||||
t.Parallel()
|
||||
|
||||
g := dag.NewGraph()
|
||||
for _, node := range tt.fields.nodes {
|
||||
assert.NoError(t, g.AddNode(node))
|
||||
}
|
||||
@@ -134,7 +138,9 @@ func TestGraph_GetParents(t *testing.T) {
|
||||
}
|
||||
|
||||
func TestDAG_AddNode(t *testing.T) {
|
||||
dag := NewGraph()
|
||||
t.Parallel()
|
||||
|
||||
dag := dag.NewGraph()
|
||||
|
||||
v := "1"
|
||||
assert.NoError(t, dag.AddNode(v))
|
||||
@@ -143,7 +149,9 @@ func TestDAG_AddNode(t *testing.T) {
|
||||
}
|
||||
|
||||
func TestDAG_AddEdge(t *testing.T) {
|
||||
dag := NewGraph()
|
||||
t.Parallel()
|
||||
|
||||
dag := dag.NewGraph()
|
||||
assert.NoError(t, dag.AddNode("0"))
|
||||
assert.NoError(t, dag.AddNode("1"))
|
||||
assert.NoError(t, dag.AddNode("2"))
|
||||
@@ -162,7 +170,9 @@ func TestDAG_AddEdge(t *testing.T) {
|
||||
}
|
||||
|
||||
func TestDAG_GetParents(t *testing.T) {
|
||||
dag := NewGraph()
|
||||
t.Parallel()
|
||||
|
||||
dag := dag.NewGraph()
|
||||
assert.NoError(t, dag.AddNode("1"))
|
||||
assert.NoError(t, dag.AddNode("2"))
|
||||
assert.NoError(t, dag.AddNode("3"))
|
||||
@@ -176,7 +186,9 @@ func TestDAG_GetParents(t *testing.T) {
|
||||
}
|
||||
|
||||
func TestDAG_GetDescendants(t *testing.T) {
|
||||
dag := NewGraph()
|
||||
t.Parallel()
|
||||
|
||||
dag := dag.NewGraph()
|
||||
assert.NoError(t, dag.AddNode("1"))
|
||||
assert.NoError(t, dag.AddNode("2"))
|
||||
assert.NoError(t, dag.AddNode("3"))
|
||||
@@ -188,7 +200,9 @@ func TestDAG_GetDescendants(t *testing.T) {
|
||||
}
|
||||
|
||||
func TestDAG_Topsort(t *testing.T) {
|
||||
dag := NewGraph()
|
||||
t.Parallel()
|
||||
|
||||
dag := dag.NewGraph()
|
||||
assert.NoError(t, dag.AddNode("1"))
|
||||
assert.NoError(t, dag.AddNode("2"))
|
||||
assert.NoError(t, dag.AddNode("3"))
|
||||
@@ -203,7 +217,9 @@ func TestDAG_Topsort(t *testing.T) {
|
||||
}
|
||||
|
||||
func TestDAG_TopsortStable(t *testing.T) {
|
||||
dag := NewGraph()
|
||||
t.Parallel()
|
||||
|
||||
dag := dag.NewGraph()
|
||||
assert.NoError(t, dag.AddNode("1"))
|
||||
assert.NoError(t, dag.AddNode("2"))
|
||||
assert.NoError(t, dag.AddNode("3"))
|
||||
@@ -216,7 +232,9 @@ func TestDAG_TopsortStable(t *testing.T) {
|
||||
}
|
||||
|
||||
func TestDAG_TopsortStable2(t *testing.T) {
|
||||
dag := NewGraph()
|
||||
t.Parallel()
|
||||
|
||||
dag := dag.NewGraph()
|
||||
|
||||
assert.NoError(t, dag.AddNodes("block-ioc", "block-iocs", "block-sender", "board", "fetch-iocs", "escalate", "extract-iocs", "mail-available", "search-email-gateway"))
|
||||
assert.NoError(t, dag.AddEdge("block-iocs", "block-ioc"))
|
||||
|
||||
@@ -23,7 +23,7 @@ func (db *Database) ArtifactGet(ctx context.Context, id int64, name string) (*mo
|
||||
FOR a in NOT_NULL(d.artifacts, [])
|
||||
FILTER a.name == @name
|
||||
RETURN a`
|
||||
cursor, _, err := db.Query(ctx, query, mergeMaps(ticketFilterVars, map[string]interface{}{
|
||||
cursor, _, err := db.Query(ctx, query, mergeMaps(ticketFilterVars, map[string]any{
|
||||
"@collection": TicketCollectionName,
|
||||
"ID": fmt.Sprint(id),
|
||||
"name": name,
|
||||
@@ -55,7 +55,8 @@ func (db *Database) ArtifactUpdate(ctx context.Context, id int64, name string, a
|
||||
LET newartifacts = APPEND(REMOVE_VALUE(d.artifacts, a), @artifact)
|
||||
UPDATE d WITH { "artifacts": newartifacts } IN @@collection
|
||||
RETURN NEW`
|
||||
return db.ticketGetQuery(ctx, id, query, mergeMaps(map[string]interface{}{
|
||||
|
||||
return db.ticketGetQuery(ctx, id, query, mergeMaps(map[string]any{
|
||||
"@collection": TicketCollectionName,
|
||||
"ID": id,
|
||||
"name": name,
|
||||
@@ -69,7 +70,7 @@ func (db *Database) ArtifactUpdate(ctx context.Context, id int64, name string, a
|
||||
}
|
||||
|
||||
func (db *Database) EnrichArtifact(ctx context.Context, id int64, name string, enrichmentForm *model.EnrichmentForm) (*model.TicketWithTickets, error) {
|
||||
enrichment := model.Enrichment{time.Now().UTC(), enrichmentForm.Data, enrichmentForm.Name}
|
||||
enrichment := model.Enrichment{Created: time.Now().UTC(), Data: enrichmentForm.Data, Name: enrichmentForm.Name}
|
||||
|
||||
ticketFilterQuery, ticketFilterVars, err := db.Hooks.TicketWriteFilter(ctx)
|
||||
if err != nil {
|
||||
@@ -85,7 +86,8 @@ func (db *Database) EnrichArtifact(ctx context.Context, id int64, name string, e
|
||||
LET newartifacts = APPEND(REMOVE_VALUE(d.artifacts, a), MERGE(a, { "enrichments": newenrichments }))
|
||||
UPDATE d WITH { "artifacts": newartifacts } IN @@collection
|
||||
RETURN NEW`
|
||||
return db.ticketGetQuery(ctx, id, query, mergeMaps(map[string]interface{}{
|
||||
|
||||
return db.ticketGetQuery(ctx, id, query, mergeMaps(map[string]any{
|
||||
"@collection": TicketCollectionName,
|
||||
"ID": id,
|
||||
"name": name,
|
||||
|
||||
@@ -10,7 +10,7 @@ import (
|
||||
"github.com/SecurityBrewery/catalyst/generated/model"
|
||||
)
|
||||
|
||||
func toAutomation(doc *model.AutomationForm) interface{} {
|
||||
func toAutomation(doc *model.AutomationForm) *model.Automation {
|
||||
return &model.Automation{
|
||||
Image: doc.Image,
|
||||
Script: doc.Script,
|
||||
@@ -72,12 +72,13 @@ func (db *Database) AutomationUpdate(ctx context.Context, id string, automation
|
||||
|
||||
func (db *Database) AutomationDelete(ctx context.Context, id string) error {
|
||||
_, err := db.automationCollection.RemoveDocument(ctx, id)
|
||||
|
||||
return err
|
||||
}
|
||||
|
||||
func (db *Database) AutomationList(ctx context.Context) ([]*model.AutomationResponse, error) {
|
||||
query := "FOR d IN @@collection SORT d._key ASC RETURN UNSET(d, 'script')"
|
||||
cursor, _, err := db.Query(ctx, query, map[string]interface{}{"@collection": AutomationCollectionName}, busdb.ReadOperation)
|
||||
cursor, _, err := db.Query(ctx, query, map[string]any{"@collection": AutomationCollectionName}, busdb.ReadOperation)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
@@ -40,10 +40,12 @@ type Operation struct {
|
||||
Ids []driver.DocumentID
|
||||
}
|
||||
|
||||
var CreateOperation = &Operation{Type: bus.DatabaseEntryCreated}
|
||||
var ReadOperation = &Operation{Type: bus.DatabaseEntryRead}
|
||||
var (
|
||||
CreateOperation = &Operation{Type: bus.DatabaseEntryCreated}
|
||||
ReadOperation = &Operation{Type: bus.DatabaseEntryRead}
|
||||
)
|
||||
|
||||
func (db BusDatabase) Query(ctx context.Context, query string, vars map[string]interface{}, operation *Operation) (cur driver.Cursor, logs *model.LogEntry, err error) {
|
||||
func (db *BusDatabase) Query(ctx context.Context, query string, vars map[string]any, operation *Operation) (cur driver.Cursor, logs *model.LogEntry, err error) {
|
||||
defer func() { err = toHTTPErr(err) }()
|
||||
|
||||
cur, err = db.internal.Query(ctx, query, vars)
|
||||
@@ -61,31 +63,31 @@ func (db BusDatabase) Query(ctx context.Context, query string, vars map[string]i
|
||||
return cur, logs, err
|
||||
}
|
||||
|
||||
func (db BusDatabase) Remove(ctx context.Context) (err error) {
|
||||
func (db *BusDatabase) Remove(ctx context.Context) (err error) {
|
||||
defer func() { err = toHTTPErr(err) }()
|
||||
|
||||
return db.internal.Remove(ctx)
|
||||
}
|
||||
|
||||
func (db BusDatabase) Collection(ctx context.Context, name string) (col driver.Collection, err error) {
|
||||
func (db *BusDatabase) Collection(ctx context.Context, name string) (col driver.Collection, err error) {
|
||||
defer func() { err = toHTTPErr(err) }()
|
||||
|
||||
return db.internal.Collection(ctx, name)
|
||||
}
|
||||
|
||||
type Collection struct {
|
||||
type Collection[T any] struct {
|
||||
internal driver.Collection
|
||||
db *BusDatabase
|
||||
}
|
||||
|
||||
func NewCollection(internal driver.Collection, db *BusDatabase) *Collection {
|
||||
return &Collection{internal: internal, db: db}
|
||||
func NewCollection[T any](internal driver.Collection, db *BusDatabase) *Collection[T] {
|
||||
return &Collection[T]{internal: internal, db: db}
|
||||
}
|
||||
|
||||
func (c Collection) CreateDocument(ctx, newctx context.Context, key string, document interface{}) (meta driver.DocumentMeta, err error) {
|
||||
func (c *Collection[T]) CreateDocument(ctx, newctx context.Context, key string, document *T) (meta driver.DocumentMeta, err error) {
|
||||
defer func() { err = toHTTPErr(err) }()
|
||||
|
||||
meta, err = c.internal.CreateDocument(newctx, &Keyed{Key: key, Doc: document})
|
||||
meta, err = c.internal.CreateDocument(newctx, &Keyed[T]{Key: key, Doc: document})
|
||||
if err != nil {
|
||||
return meta, err
|
||||
}
|
||||
@@ -94,10 +96,11 @@ func (c Collection) CreateDocument(ctx, newctx context.Context, key string, docu
|
||||
if err != nil {
|
||||
return meta, err
|
||||
}
|
||||
|
||||
return meta, nil
|
||||
}
|
||||
|
||||
func (c Collection) CreateEdge(ctx, newctx context.Context, edge *driver.EdgeDocument) (meta driver.DocumentMeta, err error) {
|
||||
func (c *Collection[T]) CreateEdge(ctx, newctx context.Context, edge *driver.EdgeDocument) (meta driver.DocumentMeta, err error) {
|
||||
defer func() { err = toHTTPErr(err) }()
|
||||
|
||||
meta, err = c.internal.CreateDocument(newctx, edge)
|
||||
@@ -109,10 +112,11 @@ func (c Collection) CreateEdge(ctx, newctx context.Context, edge *driver.EdgeDoc
|
||||
if err != nil {
|
||||
return meta, err
|
||||
}
|
||||
|
||||
return meta, nil
|
||||
}
|
||||
|
||||
func (c Collection) CreateEdges(ctx context.Context, edges []*driver.EdgeDocument) (meta driver.DocumentMetaSlice, err error) {
|
||||
func (c *Collection[T]) CreateEdges(ctx context.Context, edges []*driver.EdgeDocument) (meta driver.DocumentMetaSlice, err error) {
|
||||
defer func() { err = toHTTPErr(err) }()
|
||||
|
||||
metas, errs, err := c.internal.CreateDocuments(ctx, edges)
|
||||
@@ -136,13 +140,13 @@ func (c Collection) CreateEdges(ctx context.Context, edges []*driver.EdgeDocumen
|
||||
return metas, nil
|
||||
}
|
||||
|
||||
func (c Collection) DocumentExists(ctx context.Context, id string) (exists bool, err error) {
|
||||
func (c *Collection[T]) DocumentExists(ctx context.Context, id string) (exists bool, err error) {
|
||||
defer func() { err = toHTTPErr(err) }()
|
||||
|
||||
return c.internal.DocumentExists(ctx, id)
|
||||
}
|
||||
|
||||
func (c Collection) ReadDocument(ctx context.Context, key string, result interface{}) (meta driver.DocumentMeta, err error) {
|
||||
func (c *Collection[T]) ReadDocument(ctx context.Context, key string, result *T) (meta driver.DocumentMeta, err error) {
|
||||
defer func() { err = toHTTPErr(err) }()
|
||||
|
||||
meta, err = c.internal.ReadDocument(ctx, key, result)
|
||||
@@ -150,7 +154,7 @@ func (c Collection) ReadDocument(ctx context.Context, key string, result interfa
|
||||
return
|
||||
}
|
||||
|
||||
func (c Collection) UpdateDocument(ctx context.Context, key string, update interface{}) (meta driver.DocumentMeta, err error) {
|
||||
func (c *Collection[T]) UpdateDocument(ctx context.Context, key string, update any) (meta driver.DocumentMeta, err error) {
|
||||
defer func() { err = toHTTPErr(err) }()
|
||||
|
||||
meta, err = c.internal.UpdateDocument(ctx, key, update)
|
||||
@@ -161,7 +165,7 @@ func (c Collection) UpdateDocument(ctx context.Context, key string, update inter
|
||||
return meta, c.db.bus.PublishDatabaseUpdate([]driver.DocumentID{meta.ID}, bus.DatabaseEntryUpdated)
|
||||
}
|
||||
|
||||
func (c Collection) ReplaceDocument(ctx context.Context, key string, document interface{}) (meta driver.DocumentMeta, err error) {
|
||||
func (c *Collection[T]) ReplaceDocument(ctx context.Context, key string, document *T) (meta driver.DocumentMeta, err error) {
|
||||
defer func() { err = toHTTPErr(err) }()
|
||||
|
||||
meta, err = c.internal.ReplaceDocument(ctx, key, document)
|
||||
@@ -172,13 +176,13 @@ func (c Collection) ReplaceDocument(ctx context.Context, key string, document in
|
||||
return meta, c.db.bus.PublishDatabaseUpdate([]driver.DocumentID{meta.ID}, bus.DatabaseEntryUpdated)
|
||||
}
|
||||
|
||||
func (c Collection) RemoveDocument(ctx context.Context, formatInt string) (meta driver.DocumentMeta, err error) {
|
||||
func (c *Collection[T]) RemoveDocument(ctx context.Context, formatInt string) (meta driver.DocumentMeta, err error) {
|
||||
defer func() { err = toHTTPErr(err) }()
|
||||
|
||||
return c.internal.RemoveDocument(ctx, formatInt)
|
||||
}
|
||||
|
||||
func (c Collection) Truncate(ctx context.Context) (err error) {
|
||||
func (c *Collection[T]) Truncate(ctx context.Context) (err error) {
|
||||
defer func() { err = toHTTPErr(err) }()
|
||||
|
||||
return c.internal.Truncate(ctx)
|
||||
@@ -190,7 +194,9 @@ func toHTTPErr(err error) error {
|
||||
if errors.As(err, &ae) {
|
||||
return &api.HTTPError{Status: ae.Code, Internal: err}
|
||||
}
|
||||
|
||||
return err
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
@@ -8,9 +8,11 @@ import (
|
||||
"github.com/SecurityBrewery/catalyst/role"
|
||||
)
|
||||
|
||||
type contextKey string
|
||||
|
||||
const (
|
||||
userContextKey = "user"
|
||||
groupContextKey = "groups"
|
||||
userContextKey contextKey = "user"
|
||||
groupContextKey contextKey = "groups"
|
||||
)
|
||||
|
||||
func SetContext(r *http.Request, user *model.UserResponse) *http.Request {
|
||||
@@ -25,10 +27,12 @@ func SetGroupContext(r *http.Request, groups []string) *http.Request {
|
||||
|
||||
func UserContext(ctx context.Context, user *model.UserResponse) context.Context {
|
||||
user.Roles = role.Strings(role.Explodes(user.Roles))
|
||||
|
||||
return context.WithValue(ctx, userContextKey, user)
|
||||
}
|
||||
|
||||
func UserFromContext(ctx context.Context) (*model.UserResponse, bool) {
|
||||
u, ok := ctx.Value(userContextKey).(*model.UserResponse)
|
||||
|
||||
return u, ok
|
||||
}
|
||||
|
||||
@@ -2,18 +2,18 @@ package busdb
|
||||
|
||||
import "encoding/json"
|
||||
|
||||
type Keyed struct {
|
||||
type Keyed[T any] struct {
|
||||
Key string
|
||||
Doc interface{}
|
||||
Doc *T
|
||||
}
|
||||
|
||||
func (p Keyed) MarshalJSON() ([]byte, error) {
|
||||
func (p *Keyed[T]) MarshalJSON() ([]byte, error) {
|
||||
b, err := json.Marshal(p.Doc)
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
|
||||
var m map[string]interface{}
|
||||
var m map[string]any
|
||||
err = json.Unmarshal(b, &m)
|
||||
if err != nil {
|
||||
panic(err)
|
||||
|
||||
@@ -3,6 +3,7 @@ package busdb
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"log"
|
||||
"strings"
|
||||
|
||||
"github.com/arangodb/go-driver"
|
||||
@@ -45,7 +46,12 @@ func (db *BusDatabase) LogBatchCreate(ctx context.Context, logentries []*model.L
|
||||
}
|
||||
}
|
||||
if ids != nil {
|
||||
go db.bus.PublishDatabaseUpdate(ids, bus.DatabaseEntryCreated)
|
||||
go func() {
|
||||
err := db.bus.PublishDatabaseUpdate(ids, bus.DatabaseEntryCreated)
|
||||
if err != nil {
|
||||
log.Println(err)
|
||||
}
|
||||
}()
|
||||
}
|
||||
|
||||
_, errs, err := db.logCollection.CreateDocuments(ctx, logentries)
|
||||
@@ -62,7 +68,7 @@ func (db *BusDatabase) LogBatchCreate(ctx context.Context, logentries []*model.L
|
||||
|
||||
func (db *BusDatabase) LogList(ctx context.Context, reference string) ([]*model.LogEntry, error) {
|
||||
query := "FOR d IN @@collection FILTER d.reference == @reference SORT d.created DESC RETURN d"
|
||||
cursor, err := db.internal.Query(ctx, query, map[string]interface{}{
|
||||
cursor, err := db.internal.Query(ctx, query, map[string]any{
|
||||
"@collection": LogCollectionName,
|
||||
"reference": reference,
|
||||
})
|
||||
|
||||
@@ -72,12 +72,13 @@ func (db *Database) DashboardUpdate(ctx context.Context, id string, dashboard *m
|
||||
|
||||
func (db *Database) DashboardDelete(ctx context.Context, id string) error {
|
||||
_, err := db.dashboardCollection.RemoveDocument(ctx, id)
|
||||
|
||||
return err
|
||||
}
|
||||
|
||||
func (db *Database) DashboardList(ctx context.Context) ([]*model.DashboardResponse, error) {
|
||||
query := "FOR d IN @@collection RETURN d"
|
||||
cursor, _, err := db.Query(ctx, query, map[string]interface{}{"@collection": DashboardCollectionName}, busdb.ReadOperation)
|
||||
cursor, _, err := db.Query(ctx, query, map[string]any{"@collection": DashboardCollectionName}, busdb.ReadOperation)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
@@ -103,15 +104,16 @@ func (db *Database) parseWidgets(dashboard *model.Dashboard) error {
|
||||
|
||||
_, err := parser.Parse(widget.Aggregation)
|
||||
if err != nil {
|
||||
return fmt.Errorf("invalid aggregation query (%s): syntax error\n", widget.Aggregation)
|
||||
return fmt.Errorf("invalid aggregation query (%s): syntax error", widget.Aggregation)
|
||||
}
|
||||
|
||||
if widget.Filter != nil {
|
||||
_, err := parser.Parse(*widget.Filter)
|
||||
if err != nil {
|
||||
return fmt.Errorf("invalid filter query (%s): syntax error\n", *widget.Filter)
|
||||
return fmt.Errorf("invalid filter query (%s): syntax error", *widget.Filter)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
@@ -11,6 +11,7 @@ import (
|
||||
"github.com/SecurityBrewery/catalyst/bus"
|
||||
"github.com/SecurityBrewery/catalyst/database/busdb"
|
||||
"github.com/SecurityBrewery/catalyst/database/migrations"
|
||||
"github.com/SecurityBrewery/catalyst/generated/model"
|
||||
"github.com/SecurityBrewery/catalyst/hooks"
|
||||
"github.com/SecurityBrewery/catalyst/index"
|
||||
)
|
||||
@@ -38,18 +39,18 @@ type Database struct {
|
||||
bus *bus.Bus
|
||||
Hooks *hooks.Hooks
|
||||
|
||||
templateCollection *busdb.Collection
|
||||
ticketCollection *busdb.Collection
|
||||
playbookCollection *busdb.Collection
|
||||
automationCollection *busdb.Collection
|
||||
userdataCollection *busdb.Collection
|
||||
userCollection *busdb.Collection
|
||||
tickettypeCollection *busdb.Collection
|
||||
jobCollection *busdb.Collection
|
||||
settingsCollection *busdb.Collection
|
||||
dashboardCollection *busdb.Collection
|
||||
templateCollection *busdb.Collection[model.TicketTemplate]
|
||||
ticketCollection *busdb.Collection[model.Ticket]
|
||||
playbookCollection *busdb.Collection[model.PlaybookTemplate]
|
||||
automationCollection *busdb.Collection[model.Automation]
|
||||
userdataCollection *busdb.Collection[model.UserData]
|
||||
userCollection *busdb.Collection[model.User]
|
||||
tickettypeCollection *busdb.Collection[model.TicketType]
|
||||
jobCollection *busdb.Collection[model.Job]
|
||||
settingsCollection *busdb.Collection[model.Settings]
|
||||
dashboardCollection *busdb.Collection[model.Dashboard]
|
||||
|
||||
relatedCollection *busdb.Collection
|
||||
relatedCollection *busdb.Collection[driver.EdgeDocument]
|
||||
// containsCollection *busdb.Collection
|
||||
}
|
||||
|
||||
@@ -145,17 +146,17 @@ func New(ctx context.Context, index *index.Index, bus *bus.Bus, hooks *hooks.Hoo
|
||||
bus: bus,
|
||||
Index: index,
|
||||
Hooks: hooks,
|
||||
templateCollection: busdb.NewCollection(templateCollection, hookedDB),
|
||||
ticketCollection: busdb.NewCollection(ticketCollection, hookedDB),
|
||||
playbookCollection: busdb.NewCollection(playbookCollection, hookedDB),
|
||||
automationCollection: busdb.NewCollection(automationCollection, hookedDB),
|
||||
relatedCollection: busdb.NewCollection(relatedCollection, hookedDB),
|
||||
userdataCollection: busdb.NewCollection(userdataCollection, hookedDB),
|
||||
userCollection: busdb.NewCollection(userCollection, hookedDB),
|
||||
tickettypeCollection: busdb.NewCollection(tickettypeCollection, hookedDB),
|
||||
jobCollection: busdb.NewCollection(jobCollection, hookedDB),
|
||||
settingsCollection: busdb.NewCollection(settingsCollection, hookedDB),
|
||||
dashboardCollection: busdb.NewCollection(dashboardCollection, hookedDB),
|
||||
templateCollection: busdb.NewCollection[model.TicketTemplate](templateCollection, hookedDB),
|
||||
ticketCollection: busdb.NewCollection[model.Ticket](ticketCollection, hookedDB),
|
||||
playbookCollection: busdb.NewCollection[model.PlaybookTemplate](playbookCollection, hookedDB),
|
||||
automationCollection: busdb.NewCollection[model.Automation](automationCollection, hookedDB),
|
||||
userdataCollection: busdb.NewCollection[model.UserData](userdataCollection, hookedDB),
|
||||
userCollection: busdb.NewCollection[model.User](userCollection, hookedDB),
|
||||
tickettypeCollection: busdb.NewCollection[model.TicketType](tickettypeCollection, hookedDB),
|
||||
jobCollection: busdb.NewCollection[model.Job](jobCollection, hookedDB),
|
||||
settingsCollection: busdb.NewCollection[model.Settings](settingsCollection, hookedDB),
|
||||
dashboardCollection: busdb.NewCollection[model.Dashboard](dashboardCollection, hookedDB),
|
||||
relatedCollection: busdb.NewCollection[driver.EdgeDocument](relatedCollection, hookedDB),
|
||||
}
|
||||
|
||||
return db, nil
|
||||
@@ -194,16 +195,16 @@ func SetupDB(ctx context.Context, client driver.Client, dbName string) (driver.D
|
||||
}
|
||||
|
||||
func (db *Database) Truncate(ctx context.Context) {
|
||||
db.templateCollection.Truncate(ctx)
|
||||
db.ticketCollection.Truncate(ctx)
|
||||
db.playbookCollection.Truncate(ctx)
|
||||
db.automationCollection.Truncate(ctx)
|
||||
db.userdataCollection.Truncate(ctx)
|
||||
db.userCollection.Truncate(ctx)
|
||||
db.tickettypeCollection.Truncate(ctx)
|
||||
db.jobCollection.Truncate(ctx)
|
||||
db.relatedCollection.Truncate(ctx)
|
||||
db.settingsCollection.Truncate(ctx)
|
||||
db.dashboardCollection.Truncate(ctx)
|
||||
_ = db.templateCollection.Truncate(ctx)
|
||||
_ = db.ticketCollection.Truncate(ctx)
|
||||
_ = db.playbookCollection.Truncate(ctx)
|
||||
_ = db.automationCollection.Truncate(ctx)
|
||||
_ = db.userdataCollection.Truncate(ctx)
|
||||
_ = db.userCollection.Truncate(ctx)
|
||||
_ = db.tickettypeCollection.Truncate(ctx)
|
||||
_ = db.jobCollection.Truncate(ctx)
|
||||
_ = db.relatedCollection.Truncate(ctx)
|
||||
_ = db.settingsCollection.Truncate(ctx)
|
||||
_ = db.dashboardCollection.Truncate(ctx)
|
||||
// db.containsCollection.Truncate(ctx)
|
||||
}
|
||||
|
||||
@@ -38,7 +38,7 @@ func (db *Database) toJobResponse(ctx context.Context, key string, doc *model.Jo
|
||||
inspect, err := cli.ContainerInspect(ctx, key)
|
||||
if err != nil || inspect.State == nil {
|
||||
if update {
|
||||
db.JobUpdate(ctx, key, &model.JobUpdate{
|
||||
_, _ = db.JobUpdate(ctx, key, &model.JobUpdate{
|
||||
Status: doc.Status,
|
||||
Running: false,
|
||||
})
|
||||
@@ -46,7 +46,7 @@ func (db *Database) toJobResponse(ctx context.Context, key string, doc *model.Jo
|
||||
} else if doc.Status != inspect.State.Status {
|
||||
status = inspect.State.Status
|
||||
if update {
|
||||
db.JobUpdate(ctx, key, &model.JobUpdate{
|
||||
_, _ = db.JobUpdate(ctx, key, &model.JobUpdate{
|
||||
Status: status,
|
||||
Running: doc.Running,
|
||||
})
|
||||
@@ -107,7 +107,7 @@ func (db *Database) JobUpdate(ctx context.Context, id string, job *model.JobUpda
|
||||
func (db *Database) JobLogAppend(ctx context.Context, id string, logLine string) error {
|
||||
query := `LET d = DOCUMENT(@@collection, @ID)
|
||||
UPDATE d WITH { "log": CONCAT(NOT_NULL(d.log, ""), @logline) } IN @@collection`
|
||||
cur, _, err := db.Query(ctx, query, map[string]interface{}{
|
||||
cur, _, err := db.Query(ctx, query, map[string]any{
|
||||
"@collection": JobCollectionName,
|
||||
"ID": id,
|
||||
"logline": logLine,
|
||||
@@ -125,10 +125,10 @@ func (db *Database) JobLogAppend(ctx context.Context, id string, logLine string)
|
||||
return nil
|
||||
}
|
||||
|
||||
func (db *Database) JobComplete(ctx context.Context, id string, out interface{}) error {
|
||||
func (db *Database) JobComplete(ctx context.Context, id string, out any) error {
|
||||
query := `LET d = DOCUMENT(@@collection, @ID)
|
||||
UPDATE d WITH { "output": @out, "status": "completed", "running": false } IN @@collection`
|
||||
cur, _, err := db.Query(ctx, query, map[string]interface{}{
|
||||
cur, _, err := db.Query(ctx, query, map[string]any{
|
||||
"@collection": JobCollectionName,
|
||||
"ID": id,
|
||||
"out": out,
|
||||
@@ -148,12 +148,13 @@ func (db *Database) JobComplete(ctx context.Context, id string, out interface{})
|
||||
|
||||
func (db *Database) JobDelete(ctx context.Context, id string) error {
|
||||
_, err := db.jobCollection.RemoveDocument(ctx, id)
|
||||
|
||||
return err
|
||||
}
|
||||
|
||||
func (db *Database) JobList(ctx context.Context) ([]*model.JobResponse, error) {
|
||||
query := "FOR d IN @@collection RETURN d"
|
||||
cursor, _, err := db.Query(ctx, query, map[string]interface{}{"@collection": JobCollectionName}, busdb.ReadOperation)
|
||||
cursor, _, err := db.Query(ctx, query, map[string]any{"@collection": JobCollectionName}, busdb.ReadOperation)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
@@ -188,24 +189,24 @@ func publishJobMapping(id, automation string, contextStructs *model.Context, ori
|
||||
return publishJob(id, automation, contextStructs, origin, msg, db)
|
||||
}
|
||||
|
||||
func publishJob(id, automation string, contextStructs *model.Context, origin *model.Origin, payload map[string]interface{}, db *Database) error {
|
||||
func publishJob(id, automation string, contextStructs *model.Context, origin *model.Origin, payload map[string]any, db *Database) error {
|
||||
return db.bus.PublishJob(id, automation, payload, contextStructs, origin)
|
||||
}
|
||||
|
||||
func generatePayload(msgMapping map[string]string, contextStructs *model.Context) (map[string]interface{}, error) {
|
||||
contextJson, err := json.Marshal(contextStructs)
|
||||
func generatePayload(msgMapping map[string]string, contextStructs *model.Context) (map[string]any, error) {
|
||||
contextJSON, err := json.Marshal(contextStructs)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
automationContext := map[string]interface{}{}
|
||||
err = json.Unmarshal(contextJson, &automationContext)
|
||||
automationContext := map[string]any{}
|
||||
err = json.Unmarshal(contextJSON, &automationContext)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
parser := caql.Parser{}
|
||||
msg := map[string]interface{}{}
|
||||
msg := map[string]any{}
|
||||
for arg, expr := range msgMapping {
|
||||
tree, err := parser.Parse(expr)
|
||||
if err != nil {
|
||||
@@ -218,5 +219,6 @@ func generatePayload(msgMapping map[string]string, contextStructs *model.Context
|
||||
}
|
||||
msg[arg] = v
|
||||
}
|
||||
|
||||
return msg, nil
|
||||
}
|
||||
|
||||
@@ -32,34 +32,34 @@ func generateMigrations() ([]Migration, error) {
|
||||
|
||||
&createGraph{ID: "create-ticket-graph", Name: "Graph", EdgeDefinitions: []driver.EdgeDefinition{{Collection: "related", From: []string{"tickets"}, To: []string{"tickets"}}}},
|
||||
|
||||
&createDocument{ID: "create-template-default", Collection: "templates", Document: &busdb.Keyed{Key: "default", Doc: model.TicketTemplate{Schema: DefaultTemplateSchema, Name: "Default"}}},
|
||||
&createDocument{ID: "create-automation-vt.hash", Collection: "automations", Document: &busdb.Keyed{Key: "vt.hash", Doc: model.Automation{Image: "docker.io/python:3", Script: VTHashAutomation}}},
|
||||
&createDocument{ID: "create-automation-comment", Collection: "automations", Document: &busdb.Keyed{Key: "comment", Doc: model.Automation{Image: "docker.io/python:3", Script: CommentAutomation}}},
|
||||
&createDocument{ID: "create-automation-hash.sha1", Collection: "automations", Document: &busdb.Keyed{Key: "hash.sha1", Doc: model.Automation{Image: "docker.io/python:3", Script: SHA1HashAutomation}}},
|
||||
&createDocument{ID: "create-playbook-malware", Collection: "playbooks", Document: &busdb.Keyed{Key: "malware", Doc: model.PlaybookTemplate{Name: "Malware", Yaml: MalwarePlaybook}}},
|
||||
&createDocument{ID: "create-playbook-phishing", Collection: "playbooks", Document: &busdb.Keyed{Key: "phishing", Doc: model.PlaybookTemplate{Name: "Phishing", Yaml: PhishingPlaybook}}},
|
||||
&createDocument{ID: "create-tickettype-alert", Collection: "tickettypes", Document: &busdb.Keyed{Key: "alert", Doc: model.TicketType{Name: "Alerts", Icon: "mdi-alert", DefaultTemplate: "default", DefaultPlaybooks: []string{}, DefaultGroups: nil}}},
|
||||
&createDocument{ID: "create-tickettype-incident", Collection: "tickettypes", Document: &busdb.Keyed{Key: "incident", Doc: model.TicketType{Name: "Incidents", Icon: "mdi-radioactive", DefaultTemplate: "default", DefaultPlaybooks: []string{}, DefaultGroups: nil}}},
|
||||
&createDocument{ID: "create-tickettype-investigation", Collection: "tickettypes", Document: &busdb.Keyed{Key: "investigation", Doc: model.TicketType{Name: "Forensic Investigations", Icon: "mdi-fingerprint", DefaultTemplate: "default", DefaultPlaybooks: []string{}, DefaultGroups: nil}}},
|
||||
&createDocument{ID: "create-tickettype-hunt", Collection: "tickettypes", Document: &busdb.Keyed{Key: "hunt", Doc: model.TicketType{Name: "Threat Hunting", Icon: "mdi-target", DefaultTemplate: "default", DefaultPlaybooks: []string{}, DefaultGroups: nil}}},
|
||||
&createDocument[busdb.Keyed[model.TicketTemplate]]{ID: "create-template-default", Collection: "templates", Document: &busdb.Keyed[model.TicketTemplate]{Key: "default", Doc: &model.TicketTemplate{Schema: DefaultTemplateSchema, Name: "Default"}}},
|
||||
&createDocument[busdb.Keyed[model.Automation]]{ID: "create-automation-vt.hash", Collection: "automations", Document: &busdb.Keyed[model.Automation]{Key: "vt.hash", Doc: &model.Automation{Image: "docker.io/python:3", Script: VTHashAutomation}}},
|
||||
&createDocument[busdb.Keyed[model.Automation]]{ID: "create-automation-comment", Collection: "automations", Document: &busdb.Keyed[model.Automation]{Key: "comment", Doc: &model.Automation{Image: "docker.io/python:3", Script: CommentAutomation}}},
|
||||
&createDocument[busdb.Keyed[model.Automation]]{ID: "create-automation-hash.sha1", Collection: "automations", Document: &busdb.Keyed[model.Automation]{Key: "hash.sha1", Doc: &model.Automation{Image: "docker.io/python:3", Script: SHA1HashAutomation}}},
|
||||
&createDocument[busdb.Keyed[model.PlaybookTemplate]]{ID: "create-playbook-malware", Collection: "playbooks", Document: &busdb.Keyed[model.PlaybookTemplate]{Key: "malware", Doc: &model.PlaybookTemplate{Name: "Malware", Yaml: MalwarePlaybook}}},
|
||||
&createDocument[busdb.Keyed[model.PlaybookTemplate]]{ID: "create-playbook-phishing", Collection: "playbooks", Document: &busdb.Keyed[model.PlaybookTemplate]{Key: "phishing", Doc: &model.PlaybookTemplate{Name: "Phishing", Yaml: PhishingPlaybook}}},
|
||||
&createDocument[busdb.Keyed[model.TicketType]]{ID: "create-tickettype-alert", Collection: "tickettypes", Document: &busdb.Keyed[model.TicketType]{Key: "alert", Doc: &model.TicketType{Name: "Alerts", Icon: "mdi-alert", DefaultTemplate: "default", DefaultPlaybooks: []string{}, DefaultGroups: nil}}},
|
||||
&createDocument[busdb.Keyed[model.TicketType]]{ID: "create-tickettype-incident", Collection: "tickettypes", Document: &busdb.Keyed[model.TicketType]{Key: "incident", Doc: &model.TicketType{Name: "Incidents", Icon: "mdi-radioactive", DefaultTemplate: "default", DefaultPlaybooks: []string{}, DefaultGroups: nil}}},
|
||||
&createDocument[busdb.Keyed[model.TicketType]]{ID: "create-tickettype-investigation", Collection: "tickettypes", Document: &busdb.Keyed[model.TicketType]{Key: "investigation", Doc: &model.TicketType{Name: "Forensic Investigations", Icon: "mdi-fingerprint", DefaultTemplate: "default", DefaultPlaybooks: []string{}, DefaultGroups: nil}}},
|
||||
&createDocument[busdb.Keyed[model.TicketType]]{ID: "create-tickettype-hunt", Collection: "tickettypes", Document: &busdb.Keyed[model.TicketType]{Key: "hunt", Doc: &model.TicketType{Name: "Threat Hunting", Icon: "mdi-target", DefaultTemplate: "default", DefaultPlaybooks: []string{}, DefaultGroups: nil}}},
|
||||
|
||||
&updateSchema{ID: "update-automation-collection-1", Name: "automations", DataType: "automation", Schema: `{"properties":{"image":{"type":"string"},"script":{"type":"string"}},"required":["image","script"],"type":"object"}`},
|
||||
&updateDocument{ID: "update-automation-vt.hash-1", Collection: "automations", Key: "vt.hash", Document: model.Automation{Image: "docker.io/python:3", Script: VTHashAutomation, Schema: pointer.String(`{"title":"Input","type":"object","properties":{"default":{"type":"string","title":"Value"}},"required":["default"]}`), Type: []string{"global", "artifact", "playbook"}}},
|
||||
&updateDocument{ID: "update-automation-comment-1", Collection: "automations", Key: "comment", Document: model.Automation{Image: "docker.io/python:3", Script: CommentAutomation, Type: []string{"playbook"}}},
|
||||
&updateDocument{ID: "update-automation-hash.sha1-1", Collection: "automations", Key: "hash.sha1", Document: model.Automation{Image: "docker.io/python:3", Script: SHA1HashAutomation, Schema: pointer.String(`{"title":"Input","type":"object","properties":{"default":{"type":"string","title":"Value"}},"required":["default"]}`), Type: []string{"global", "artifact", "playbook"}}},
|
||||
&updateDocument[model.Automation]{ID: "update-automation-vt.hash-1", Collection: "automations", Key: "vt.hash", Document: &model.Automation{Image: "docker.io/python:3", Script: VTHashAutomation, Schema: pointer.String(`{"title":"Input","type":"object","properties":{"default":{"type":"string","title":"Value"}},"required":["default"]}`), Type: []string{"global", "artifact", "playbook"}}},
|
||||
&updateDocument[model.Automation]{ID: "update-automation-comment-1", Collection: "automations", Key: "comment", Document: &model.Automation{Image: "docker.io/python:3", Script: CommentAutomation, Type: []string{"playbook"}}},
|
||||
&updateDocument[model.Automation]{ID: "update-automation-hash.sha1-1", Collection: "automations", Key: "hash.sha1", Document: &model.Automation{Image: "docker.io/python:3", Script: SHA1HashAutomation, Schema: pointer.String(`{"title":"Input","type":"object","properties":{"default":{"type":"string","title":"Value"}},"required":["default"]}`), Type: []string{"global", "artifact", "playbook"}}},
|
||||
|
||||
&createCollection{ID: "create-job-collection", Name: "jobs", DataType: "job", Schema: `{"properties":{"automation":{"type":"string"},"log":{"type":"string"},"payload":{},"origin":{"properties":{"artifact_origin":{"properties":{"artifact":{"type":"string"},"ticket_id":{"format":"int64","type":"integer"}},"required":["artifact","ticket_id"],"type":"object"},"task_origin":{"properties":{"playbook_id":{"type":"string"},"task_id":{"type":"string"},"ticket_id":{"format":"int64","type":"integer"}},"required":["playbook_id","task_id","ticket_id"],"type":"object"}},"type":"object"},"output":{"properties":{},"type":"object"},"running":{"type":"boolean"},"status":{"type":"string"}},"required":["automation","running","status"],"type":"object"}`},
|
||||
|
||||
&createDocument{ID: "create-playbook-simple", Collection: "playbooks", Document: &busdb.Keyed{Key: "simple", Doc: model.PlaybookTemplate{Name: "Simple", Yaml: SimplePlaybook}}},
|
||||
&createDocument[busdb.Keyed[model.PlaybookTemplate]]{ID: "create-playbook-simple", Collection: "playbooks", Document: &busdb.Keyed[model.PlaybookTemplate]{Key: "simple", Doc: &model.PlaybookTemplate{Name: "Simple", Yaml: SimplePlaybook}}},
|
||||
|
||||
&createCollection{ID: "create-settings-collection", Name: "settings", DataType: "settings", Schema: `{"type":"object","properties":{"artifactStates":{"title":"Artifact States","items":{"type":"object","properties":{"color":{"title":"Color","type":"string","enum":["error","info","success","warning"]},"icon":{"title":"Icon (https://materialdesignicons.com)","type":"string"},"id":{"title":"ID","type":"string"},"name":{"title":"Name","type":"string"}},"required":["id","name","icon"]},"type":"array"},"artifactKinds":{"title":"Artifact Kinds","items":{"type":"object","properties":{"color":{"title":"Color","type":"string","enum":["error","info","success","warning"]},"icon":{"title":"Icon (https://materialdesignicons.com)","type":"string"},"id":{"title":"ID","type":"string"},"name":{"title":"Name","type":"string"}},"required":["id","name","icon"]},"type":"array"},"timeformat":{"title":"Time Format","type":"string"}},"required":["timeformat","artifactKinds","artifactStates"]}`},
|
||||
&createDocument{ID: "create-settings-global", Collection: "settings", Document: &busdb.Keyed{Key: "global", Doc: model.Settings{ArtifactStates: []*model.Type{{Icon: "mdi-help-circle-outline", ID: "unknown", Name: "Unknown", Color: pointer.String(model.TypeColorInfo)}, {Icon: "mdi-skull", ID: "malicious", Name: "Malicious", Color: pointer.String(model.TypeColorError)}, {Icon: "mdi-check", ID: "clean", Name: "Clean", Color: pointer.String(model.TypeColorSuccess)}}, ArtifactKinds: []*model.Type{{Icon: "mdi-server", ID: "asset", Name: "Asset"}, {Icon: "mdi-bullseye", ID: "ioc", Name: "IOC"}}, Timeformat: "YYYY-MM-DDThh:mm:ss"}}},
|
||||
&createDocument[busdb.Keyed[model.Settings]]{ID: "create-settings-global", Collection: "settings", Document: &busdb.Keyed[model.Settings]{Key: "global", Doc: &model.Settings{ArtifactStates: []*model.Type{{Icon: "mdi-help-circle-outline", ID: "unknown", Name: "Unknown", Color: pointer.String(model.TypeColorInfo)}, {Icon: "mdi-skull", ID: "malicious", Name: "Malicious", Color: pointer.String(model.TypeColorError)}, {Icon: "mdi-check", ID: "clean", Name: "Clean", Color: pointer.String(model.TypeColorSuccess)}}, ArtifactKinds: []*model.Type{{Icon: "mdi-server", ID: "asset", Name: "Asset"}, {Icon: "mdi-bullseye", ID: "ioc", Name: "IOC"}}, Timeformat: "YYYY-MM-DDThh:mm:ss"}}},
|
||||
|
||||
&updateSchema{ID: "update-ticket-collection", Name: "tickets", DataType: "ticket", Schema: `{"properties":{"artifacts":{"items":{"properties":{"enrichments":{"additionalProperties":{"properties":{"created":{"format":"date-time","type":"string"},"data":{"example":{"hash":"b7a067a742c20d07a7456646de89bc2d408a1153"},"properties":{},"type":"object"},"name":{"example":"hash.sha1","type":"string"}},"required":["created","data","name"],"type":"object"},"type":"object"},"name":{"example":"2.2.2.2","type":"string"},"status":{"example":"Unknown","type":"string"},"type":{"type":"string"},"kind":{"type":"string"}},"required":["name"],"type":"object"},"type":"array"},"comments":{"items":{"properties":{"created":{"format":"date-time","type":"string"},"creator":{"type":"string"},"message":{"type":"string"}},"required":["created","creator","message"],"type":"object"},"type":"array"},"created":{"format":"date-time","type":"string"},"details":{"example":{"description":"my little incident"},"properties":{},"type":"object"},"files":{"items":{"properties":{"key":{"example":"myfile","type":"string"},"name":{"example":"notes.docx","type":"string"}},"required":["key","name"],"type":"object"},"type":"array"},"modified":{"format":"date-time","type":"string"},"name":{"example":"WannyCry","type":"string"},"owner":{"example":"bob","type":"string"},"playbooks":{"additionalProperties":{"properties":{"name":{"example":"Phishing","type":"string"},"tasks":{"additionalProperties":{"properties":{"automation":{"type":"string"},"closed":{"format":"date-time","type":"string"},"created":{"format":"date-time","type":"string"},"data":{"properties":{},"type":"object"},"done":{"type":"boolean"},"join":{"example":false,"type":"boolean"},"payload":{"additionalProperties":{"type":"string"},"type":"object"},"name":{"example":"Inform user","type":"string"},"next":{"additionalProperties":{"type":"string"},"type":"object"},"owner":{"type":"string"},"schema":{"properties":{},"type":"object"},"type":{"enum":["task","input","automation"],"example":"task","type":"string"}},"required":["created","done","name","type"],"type":"object"},"type":"object"}},"required":["name","tasks"],"type":"object"},"type":"object"},"read":{"example":["bob"],"items":{"type":"string"},"type":"array"},"references":{"items":{"properties":{"href":{"example":"https://cve.mitre.org/cgi-bin/cvename.cgi?name=cve-2017-0144","type":"string"},"name":{"example":"CVE-2017-0144","type":"string"}},"required":["href","name"],"type":"object"},"type":"array"},"schema":{"example":"{}","type":"string"},"status":{"example":"open","type":"string"},"type":{"example":"incident","type":"string"},"write":{"example":["alice"],"items":{"type":"string"},"type":"array"}},"required":["created","modified","name","schema","status","type"],"type":"object"}`},
|
||||
|
||||
&createCollection{ID: "create-dashboard-collection", Name: "dashboards", DataType: "dashboards", Schema: `{"type":"object","properties":{"name":{"type":"string"},"widgets":{"items":{"type":"object","properties":{"aggregation":{"type":"string"},"filter":{"type":"string"},"name":{"type":"string"},"type":{"enum":[ "bar", "line", "pie" ]},"width": { "type": "integer", "minimum": 1, "maximum": 12 }},"required":["name","aggregation", "type", "width"]},"type":"array"}},"required":["name","widgets"]}`},
|
||||
|
||||
&updateDocument{ID: "update-settings-global-1", Collection: "settings", Key: "global", Document: &model.Settings{ArtifactStates: []*model.Type{{Icon: "mdi-help-circle-outline", ID: "unknown", Name: "Unknown", Color: pointer.String(model.TypeColorInfo)}, {Icon: "mdi-skull", ID: "malicious", Name: "Malicious", Color: pointer.String(model.TypeColorError)}, {Icon: "mdi-check", ID: "clean", Name: "Clean", Color: pointer.String(model.TypeColorSuccess)}}, ArtifactKinds: []*model.Type{{Icon: "mdi-server", ID: "asset", Name: "Asset"}, {Icon: "mdi-bullseye", ID: "ioc", Name: "IOC"}}, Timeformat: "yyyy-MM-dd hh:mm:ss"}},
|
||||
&updateDocument[model.Settings]{ID: "update-settings-global-1", Collection: "settings", Key: "global", Document: &model.Settings{ArtifactStates: []*model.Type{{Icon: "mdi-help-circle-outline", ID: "unknown", Name: "Unknown", Color: pointer.String(model.TypeColorInfo)}, {Icon: "mdi-skull", ID: "malicious", Name: "Malicious", Color: pointer.String(model.TypeColorError)}, {Icon: "mdi-check", ID: "clean", Name: "Clean", Color: pointer.String(model.TypeColorSuccess)}}, ArtifactKinds: []*model.Type{{Icon: "mdi-server", ID: "asset", Name: "Asset"}, {Icon: "mdi-bullseye", ID: "ioc", Name: "IOC"}}, Timeformat: "yyyy-MM-dd hh:mm:ss"}},
|
||||
}, nil
|
||||
}
|
||||
|
||||
@@ -67,6 +67,7 @@ func loadSchema(dataType, jsonschema string) (*driver.CollectionSchemaOptions, e
|
||||
ticketCollectionSchema := &driver.CollectionSchemaOptions{Level: driver.CollectionSchemaLevelStrict, Message: fmt.Sprintf("Validation of %s failed", dataType)}
|
||||
|
||||
err := ticketCollectionSchema.LoadRule([]byte(jsonschema))
|
||||
|
||||
return ticketCollectionSchema, err
|
||||
}
|
||||
|
||||
@@ -101,6 +102,7 @@ func PerformMigrations(ctx context.Context, db driver.Database) error {
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
@@ -171,41 +173,43 @@ func (m *createGraph) Migrate(ctx context.Context, db driver.Database) error {
|
||||
_, err := db.CreateGraph(ctx, m.Name, &driver.CreateGraphOptions{
|
||||
EdgeDefinitions: m.EdgeDefinitions,
|
||||
})
|
||||
|
||||
return err
|
||||
}
|
||||
|
||||
type createDocument struct {
|
||||
type createDocument[T any] struct {
|
||||
ID string
|
||||
Collection string
|
||||
Document interface{}
|
||||
Document *T
|
||||
}
|
||||
|
||||
func (m *createDocument) MID() string {
|
||||
func (m *createDocument[T]) MID() string {
|
||||
return m.ID
|
||||
}
|
||||
|
||||
func (m *createDocument) Migrate(ctx context.Context, driver driver.Database) error {
|
||||
func (m *createDocument[T]) Migrate(ctx context.Context, driver driver.Database) error {
|
||||
collection, err := driver.Collection(ctx, m.Collection)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
_, err = collection.CreateDocument(ctx, m.Document)
|
||||
|
||||
return err
|
||||
}
|
||||
|
||||
type updateDocument struct {
|
||||
type updateDocument[T any] struct {
|
||||
ID string
|
||||
Collection string
|
||||
Key string
|
||||
Document interface{}
|
||||
Document *T
|
||||
}
|
||||
|
||||
func (m *updateDocument) MID() string {
|
||||
func (m *updateDocument[T]) MID() string {
|
||||
return m.ID
|
||||
}
|
||||
|
||||
func (m *updateDocument) Migrate(ctx context.Context, driver driver.Database) error {
|
||||
func (m *updateDocument[T]) Migrate(ctx context.Context, driver driver.Database) error {
|
||||
collection, err := driver.Collection(ctx, m.Collection)
|
||||
if err != nil {
|
||||
return err
|
||||
@@ -218,9 +222,11 @@ func (m *updateDocument) Migrate(ctx context.Context, driver driver.Database) er
|
||||
|
||||
if !exists {
|
||||
_, err = collection.CreateDocument(ctx, m.Document)
|
||||
|
||||
return err
|
||||
}
|
||||
|
||||
_, err = collection.ReplaceDocument(ctx, m.Key, m.Document)
|
||||
|
||||
return err
|
||||
}
|
||||
|
||||
@@ -22,7 +22,7 @@ type PlaybookYAML struct {
|
||||
type TaskYAML struct {
|
||||
Name string `yaml:"name"`
|
||||
Type string `yaml:"type"`
|
||||
Schema interface{} `yaml:"schema"`
|
||||
Schema any `yaml:"schema"`
|
||||
Automation string `yaml:"automation"`
|
||||
Payload map[string]string `yaml:"payload"`
|
||||
Next map[string]string `yaml:"next"`
|
||||
@@ -42,6 +42,7 @@ func toPlaybooks(docs []*model.PlaybookTemplateForm) (map[string]*model.Playbook
|
||||
playbooks[strcase.ToKebab(playbook.Name)] = playbook
|
||||
}
|
||||
}
|
||||
|
||||
return playbooks, nil
|
||||
}
|
||||
|
||||
@@ -53,11 +54,17 @@ func toPlaybook(doc *model.PlaybookTemplateForm) (*model.Playbook, error) {
|
||||
}
|
||||
for idx, task := range ticketPlaybook.Tasks {
|
||||
if task.Schema != nil {
|
||||
task.Schema = dyno.ConvertMapI2MapS(task.Schema).(map[string]interface{})
|
||||
schema, ok := dyno.ConvertMapI2MapS(task.Schema).(map[string]any)
|
||||
if ok {
|
||||
task.Schema = schema
|
||||
} else {
|
||||
return nil, errors.New("could not convert schema")
|
||||
}
|
||||
}
|
||||
task.Created = time.Now().UTC()
|
||||
ticketPlaybook.Tasks[idx] = task
|
||||
}
|
||||
|
||||
return ticketPlaybook, nil
|
||||
}
|
||||
|
||||
@@ -84,7 +91,7 @@ func (db *Database) PlaybookCreate(ctx context.Context, playbook *model.Playbook
|
||||
var doc model.PlaybookTemplate
|
||||
newctx := driver.WithReturnNew(ctx, &doc)
|
||||
|
||||
meta, err := db.playbookCollection.CreateDocument(ctx, newctx, strcase.ToKebab(playbookYAML.Name), p)
|
||||
meta, err := db.playbookCollection.CreateDocument(ctx, newctx, strcase.ToKebab(playbookYAML.Name), &p)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
@@ -104,6 +111,7 @@ func (db *Database) PlaybookGet(ctx context.Context, id string) (*model.Playbook
|
||||
|
||||
func (db *Database) PlaybookDelete(ctx context.Context, id string) error {
|
||||
_, err := db.playbookCollection.RemoveDocument(ctx, id)
|
||||
|
||||
return err
|
||||
}
|
||||
|
||||
@@ -121,7 +129,7 @@ func (db *Database) PlaybookUpdate(ctx context.Context, id string, playbook *mod
|
||||
var doc model.PlaybookTemplate
|
||||
ctx = driver.WithReturnNew(ctx, &doc)
|
||||
|
||||
meta, err := db.playbookCollection.ReplaceDocument(ctx, id, model.PlaybookTemplate{Name: pb.Name, Yaml: playbook.Yaml})
|
||||
meta, err := db.playbookCollection.ReplaceDocument(ctx, id, &model.PlaybookTemplate{Name: pb.Name, Yaml: playbook.Yaml})
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
@@ -131,7 +139,7 @@ func (db *Database) PlaybookUpdate(ctx context.Context, id string, playbook *mod
|
||||
|
||||
func (db *Database) PlaybookList(ctx context.Context) ([]*model.PlaybookTemplateResponse, error) {
|
||||
query := "FOR d IN @@collection RETURN d"
|
||||
cursor, _, err := db.Query(ctx, query, map[string]interface{}{"@collection": PlaybookCollectionName}, busdb.ReadOperation)
|
||||
cursor, _, err := db.Query(ctx, query, map[string]any{"@collection": PlaybookCollectionName}, busdb.ReadOperation)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
@@ -33,6 +33,7 @@ func playbookGraph(playbook *model.Playbook) (*dag.Graph, error) {
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return d, nil
|
||||
}
|
||||
|
||||
@@ -109,6 +110,7 @@ func active(playbook *model.Playbook, taskID string, d *dag.Graph, task *model.T
|
||||
return false, nil
|
||||
}
|
||||
}
|
||||
|
||||
return true, nil
|
||||
}
|
||||
|
||||
@@ -129,10 +131,11 @@ func active(playbook *model.Playbook, taskID string, d *dag.Graph, task *model.T
|
||||
return true, nil
|
||||
}
|
||||
}
|
||||
|
||||
return false, nil
|
||||
}
|
||||
|
||||
func evalRequirement(aql string, data interface{}) (bool, error) {
|
||||
func evalRequirement(aql string, data any) (bool, error) {
|
||||
if aql == "" {
|
||||
return true, nil
|
||||
}
|
||||
@@ -143,9 +146,9 @@ func evalRequirement(aql string, data interface{}) (bool, error) {
|
||||
return false, err
|
||||
}
|
||||
|
||||
var dataMap map[string]interface{}
|
||||
var dataMap map[string]any
|
||||
if data != nil {
|
||||
if dataMapX, ok := data.(map[string]interface{}); ok {
|
||||
if dataMapX, ok := data.(map[string]any); ok {
|
||||
dataMap = dataMapX
|
||||
} else {
|
||||
log.Println("wrong data type for task data")
|
||||
@@ -160,6 +163,7 @@ func evalRequirement(aql string, data interface{}) (bool, error) {
|
||||
if b, ok := v.(bool); ok {
|
||||
return b, nil
|
||||
}
|
||||
|
||||
return false, err
|
||||
}
|
||||
|
||||
|
||||
@@ -12,11 +12,11 @@ var playbook2 = &model.Playbook{
|
||||
Name: "Phishing",
|
||||
Tasks: map[string]*model.Task{
|
||||
"board": {Next: map[string]string{
|
||||
"escalate": "boardInvolved == true",
|
||||
"aquire-mail": "boardInvolved == false",
|
||||
"escalate": "boardInvolved == true",
|
||||
"acquire-mail": "boardInvolved == false",
|
||||
}},
|
||||
"escalate": {},
|
||||
"aquire-mail": {Next: map[string]string{
|
||||
"acquire-mail": {Next: map[string]string{
|
||||
"extract-iocs": "schemaKey == 'yes'",
|
||||
"block-sender": "schemaKey == 'yes'",
|
||||
"search-email-gateway": "schemaKey == 'no'",
|
||||
@@ -34,11 +34,11 @@ var playbook3 = &model.Playbook{
|
||||
Name: "Phishing",
|
||||
Tasks: map[string]*model.Task{
|
||||
"board": {Next: map[string]string{
|
||||
"escalate": "boardInvolved == true",
|
||||
"aquire-mail": "boardInvolved == false",
|
||||
}, Data: map[string]interface{}{"boardInvolved": true}, Done: true},
|
||||
"escalate": "boardInvolved == true",
|
||||
"acquire-mail": "boardInvolved == false",
|
||||
}, Data: map[string]any{"boardInvolved": true}, Done: true},
|
||||
"escalate": {},
|
||||
"aquire-mail": {Next: map[string]string{
|
||||
"acquire-mail": {Next: map[string]string{
|
||||
"extract-iocs": "schemaKey == 'yes'",
|
||||
"block-sender": "schemaKey == 'yes'",
|
||||
"search-email-gateway": "schemaKey == 'no'",
|
||||
@@ -71,6 +71,8 @@ var playbook4 = &model.Playbook{
|
||||
}
|
||||
|
||||
func Test_canBeCompleted(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
type args struct {
|
||||
playbook *model.Playbook
|
||||
taskID string
|
||||
@@ -83,18 +85,22 @@ func Test_canBeCompleted(t *testing.T) {
|
||||
}{
|
||||
{"playbook2 board", args{playbook: playbook2, taskID: "board"}, true, false},
|
||||
{"playbook2 escalate", args{playbook: playbook2, taskID: "escalate"}, false, false},
|
||||
{"playbook2 aquire-mail", args{playbook: playbook2, taskID: "aquire-mail"}, false, false},
|
||||
{"playbook2 acquire-mail", args{playbook: playbook2, taskID: "acquire-mail"}, false, false},
|
||||
{"playbook2 block-ioc", args{playbook: playbook2, taskID: "block-ioc"}, false, false},
|
||||
{"playbook3 board", args{playbook: playbook3, taskID: "board"}, false, false},
|
||||
{"playbook3 escalate", args{playbook: playbook3, taskID: "escalate"}, true, false},
|
||||
{"playbook3 aquire-mail", args{playbook: playbook3, taskID: "aquire-mail"}, false, false},
|
||||
{"playbook3 acquire-mail", args{playbook: playbook3, taskID: "acquire-mail"}, false, false},
|
||||
{"playbook3 block-ioc", args{playbook: playbook3, taskID: "block-ioc"}, false, false},
|
||||
}
|
||||
for _, tt := range tests {
|
||||
tt := tt
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
got, err := activePlaybook(tt.args.playbook, tt.args.taskID)
|
||||
if (err != nil) != tt.wantErr {
|
||||
t.Errorf("activePlaybook() error = %v, wantErr %v", err, tt.wantErr)
|
||||
|
||||
return
|
||||
}
|
||||
if got != tt.want {
|
||||
@@ -105,6 +111,8 @@ func Test_canBeCompleted(t *testing.T) {
|
||||
}
|
||||
|
||||
func Test_playbookOrder(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
type args struct {
|
||||
playbook *model.Playbook
|
||||
}
|
||||
@@ -117,10 +125,14 @@ func Test_playbookOrder(t *testing.T) {
|
||||
{"playbook4", args{playbook: playbook4}, []string{"file-or-hash", "enter-hash", "upload", "hash", "virustotal"}, false},
|
||||
}
|
||||
for _, tt := range tests {
|
||||
tt := tt
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
got, err := toPlaybookResponse(tt.args.playbook)
|
||||
if (err != nil) != tt.wantErr {
|
||||
t.Errorf("activePlaybook() error = %v, wantErr %v", err, tt.wantErr)
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
|
||||
@@ -20,11 +20,13 @@ func (db *Database) RelatedCreate(ctx context.Context, id, id2 int64) error {
|
||||
From: driver.DocumentID(TicketCollectionName + "/" + strconv.Itoa(int(id))),
|
||||
To: driver.DocumentID(TicketCollectionName + "/" + strconv.Itoa(int(id2))),
|
||||
})
|
||||
|
||||
return err
|
||||
}
|
||||
|
||||
func (db *Database) RelatedBatchCreate(ctx context.Context, edges []*driver.EdgeDocument) error {
|
||||
_, err := db.relatedCollection.CreateEdges(ctx, edges)
|
||||
|
||||
return err
|
||||
}
|
||||
|
||||
@@ -33,7 +35,7 @@ func (db *Database) RelatedRemove(ctx context.Context, id, id2 int64) error {
|
||||
FOR d in @@collection
|
||||
FILTER (d._from == @id && d._to == @id2) || (d._to == @id && d._from == @id2)
|
||||
REMOVE d in @@collection`
|
||||
_, _, err := db.Query(ctx, q, map[string]interface{}{
|
||||
_, _, err := db.Query(ctx, q, map[string]any{
|
||||
"@collection": RelatedTicketsCollectionName,
|
||||
"id": driver.DocumentID(TicketCollectionName + "/" + strconv.Itoa(int(id))),
|
||||
"id2": driver.DocumentID(TicketCollectionName + "/" + strconv.Itoa(int(id2))),
|
||||
@@ -44,5 +46,6 @@ func (db *Database) RelatedRemove(ctx context.Context, id, id2 int64) error {
|
||||
driver.DocumentID(TicketCollectionName + "/" + strconv.Itoa(int(id2))),
|
||||
},
|
||||
})
|
||||
|
||||
return err
|
||||
}
|
||||
|
||||
@@ -44,12 +44,12 @@ func (db *Database) Statistics(ctx context.Context) (*model.Statistics, error) {
|
||||
return &statistics, nil
|
||||
}
|
||||
|
||||
func (db *Database) WidgetData(ctx context.Context, aggregation string, filter *string) (map[string]interface{}, error) {
|
||||
func (db *Database) WidgetData(ctx context.Context, aggregation string, filter *string) (map[string]any, error) {
|
||||
parser := &caql.Parser{Searcher: db.Index, Prefix: "d."}
|
||||
|
||||
queryTree, err := parser.Parse(aggregation)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("invalid aggregation query (%s): syntax error\n", aggregation)
|
||||
return nil, fmt.Errorf("invalid aggregation query (%s): syntax error", aggregation)
|
||||
}
|
||||
aggregationString, err := queryTree.String()
|
||||
if err != nil {
|
||||
@@ -61,7 +61,7 @@ func (db *Database) WidgetData(ctx context.Context, aggregation string, filter *
|
||||
if filter != nil && *filter != "" {
|
||||
queryTree, err := parser.Parse(*filter)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("invalid filter query (%s): syntax error\n", *filter)
|
||||
return nil, fmt.Errorf("invalid filter query (%s): syntax error", *filter)
|
||||
}
|
||||
filterString, err := queryTree.String()
|
||||
if err != nil {
|
||||
@@ -82,7 +82,7 @@ func (db *Database) WidgetData(ctx context.Context, aggregation string, filter *
|
||||
}
|
||||
defer cur.Close()
|
||||
|
||||
statistics := map[string]interface{}{}
|
||||
statistics := map[string]any{}
|
||||
if _, err := cur.ReadDocument(ctx, &statistics); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
@@ -10,10 +10,10 @@ import (
|
||||
)
|
||||
|
||||
type playbookResponse struct {
|
||||
PlaybookId string `json:"playbook_id"`
|
||||
PlaybookID string `json:"playbook_id"`
|
||||
PlaybookName string `json:"playbook_name"`
|
||||
Playbook model.Playbook `json:"playbook"`
|
||||
TicketId int64 `json:"ticket_id"`
|
||||
TicketID int64 `json:"ticket_id"`
|
||||
TicketName string `json:"ticket_name"`
|
||||
}
|
||||
|
||||
@@ -28,7 +28,7 @@ func (db *Database) TaskList(ctx context.Context) ([]*model.TaskWithContext, err
|
||||
FILTER d.status == 'open'
|
||||
FOR playbook IN NOT_NULL(VALUES(d.playbooks), [])
|
||||
RETURN { ticket_id: TO_NUMBER(d._key), ticket_name: d.name, playbook_id: POSITION(d.playbooks, playbook, true), playbook_name: playbook.name, playbook: playbook }`
|
||||
cursor, _, err := db.Query(ctx, query, mergeMaps(ticketFilterVars, map[string]interface{}{
|
||||
cursor, _, err := db.Query(ctx, query, mergeMaps(ticketFilterVars, map[string]any{
|
||||
"@collection": TicketCollectionName,
|
||||
}), busdb.ReadOperation)
|
||||
if err != nil {
|
||||
@@ -53,10 +53,10 @@ func (db *Database) TaskList(ctx context.Context) ([]*model.TaskWithContext, err
|
||||
for _, task := range playbook.Tasks {
|
||||
if task.Active {
|
||||
docs = append(docs, &model.TaskWithContext{
|
||||
PlaybookId: doc.PlaybookId,
|
||||
PlaybookId: doc.PlaybookID,
|
||||
PlaybookName: doc.PlaybookName,
|
||||
Task: task,
|
||||
TicketId: doc.TicketId,
|
||||
TicketId: doc.TicketID,
|
||||
TicketName: doc.TicketName,
|
||||
})
|
||||
}
|
||||
|
||||
@@ -62,12 +62,13 @@ func (db *Database) TemplateUpdate(ctx context.Context, id string, template *mod
|
||||
|
||||
func (db *Database) TemplateDelete(ctx context.Context, id string) error {
|
||||
_, err := db.templateCollection.RemoveDocument(ctx, id)
|
||||
|
||||
return err
|
||||
}
|
||||
|
||||
func (db *Database) TemplateList(ctx context.Context) ([]*model.TicketTemplateResponse, error) {
|
||||
query := "FOR d IN @@collection RETURN d"
|
||||
cursor, _, err := db.Query(ctx, query, map[string]interface{}{"@collection": TemplateCollectionName}, busdb.ReadOperation)
|
||||
cursor, _, err := db.Query(ctx, query, map[string]any{"@collection": TemplateCollectionName}, busdb.ReadOperation)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
@@ -10,16 +10,20 @@ import (
|
||||
"github.com/SecurityBrewery/catalyst/test"
|
||||
)
|
||||
|
||||
var template1 = &model.TicketTemplateForm{
|
||||
Schema: migrations.DefaultTemplateSchema,
|
||||
Name: "Template 1",
|
||||
}
|
||||
var default1 = &model.TicketTemplateForm{
|
||||
Schema: migrations.DefaultTemplateSchema,
|
||||
Name: "Default",
|
||||
}
|
||||
var (
|
||||
template1 = &model.TicketTemplateForm{
|
||||
Schema: migrations.DefaultTemplateSchema,
|
||||
Name: "Template 1",
|
||||
}
|
||||
default1 = &model.TicketTemplateForm{
|
||||
Schema: migrations.DefaultTemplateSchema,
|
||||
Name: "Default",
|
||||
}
|
||||
)
|
||||
|
||||
func TestDatabase_TemplateCreate(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
type args struct {
|
||||
template *model.TicketTemplateForm
|
||||
}
|
||||
@@ -35,7 +39,10 @@ func TestDatabase_TemplateCreate(t *testing.T) {
|
||||
{name: "Only name", args: args{template: &model.TicketTemplateForm{Name: "name"}}, wantErr: false},
|
||||
}
|
||||
for _, tt := range tests {
|
||||
tt := tt
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
_, _, _, _, _, db, cleanup, err := test.DB(t)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
@@ -50,6 +57,8 @@ func TestDatabase_TemplateCreate(t *testing.T) {
|
||||
}
|
||||
|
||||
func TestDatabase_TemplateDelete(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
type args struct {
|
||||
id string
|
||||
}
|
||||
@@ -62,7 +71,10 @@ func TestDatabase_TemplateDelete(t *testing.T) {
|
||||
{name: "Not existing", args: args{"foobar"}, wantErr: true},
|
||||
}
|
||||
for _, tt := range tests {
|
||||
tt := tt
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
_, _, _, _, _, db, cleanup, err := test.DB(t)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
@@ -81,6 +93,8 @@ func TestDatabase_TemplateDelete(t *testing.T) {
|
||||
}
|
||||
|
||||
func TestDatabase_TemplateGet(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
type args struct {
|
||||
id string
|
||||
}
|
||||
@@ -94,7 +108,10 @@ func TestDatabase_TemplateGet(t *testing.T) {
|
||||
{name: "Not existing", args: args{id: "foobar"}, wantErr: true},
|
||||
}
|
||||
for _, tt := range tests {
|
||||
tt := tt
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
_, _, _, _, _, db, cleanup, err := test.DB(t)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
@@ -108,6 +125,7 @@ func TestDatabase_TemplateGet(t *testing.T) {
|
||||
got, err := db.TemplateGet(test.Context(), tt.args.id)
|
||||
if (err != nil) != tt.wantErr {
|
||||
t.Errorf("TemplateGet() error = %v, wantErr %v", err, tt.wantErr)
|
||||
|
||||
return
|
||||
}
|
||||
if err != nil {
|
||||
@@ -120,6 +138,8 @@ func TestDatabase_TemplateGet(t *testing.T) {
|
||||
}
|
||||
|
||||
func TestDatabase_TemplateList(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
want []*model.TicketTemplateResponse
|
||||
@@ -128,7 +148,10 @@ func TestDatabase_TemplateList(t *testing.T) {
|
||||
{name: "Normal", want: []*model.TicketTemplateResponse{{ID: "default", Name: "Default", Schema: migrations.DefaultTemplateSchema}, {ID: "template-1", Name: template1.Name, Schema: template1.Schema}}},
|
||||
}
|
||||
for _, tt := range tests {
|
||||
tt := tt
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
_, _, _, _, _, db, cleanup, err := test.DB(t)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
@@ -142,6 +165,7 @@ func TestDatabase_TemplateList(t *testing.T) {
|
||||
got, err := db.TemplateList(test.Context())
|
||||
if (err != nil) != tt.wantErr {
|
||||
t.Errorf("TemplateList() error = %v, wantErr %v", err, tt.wantErr)
|
||||
|
||||
return
|
||||
}
|
||||
assert.Equal(t, got, tt.want)
|
||||
@@ -150,6 +174,8 @@ func TestDatabase_TemplateList(t *testing.T) {
|
||||
}
|
||||
|
||||
func TestDatabase_TemplateUpdate(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
type args struct {
|
||||
id string
|
||||
template *model.TicketTemplateForm
|
||||
@@ -163,7 +189,10 @@ func TestDatabase_TemplateUpdate(t *testing.T) {
|
||||
{name: "Not existing", args: args{"foobar", template1}, wantErr: true},
|
||||
}
|
||||
for _, tt := range tests {
|
||||
tt := tt
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
_, _, _, _, _, db, cleanup, err := test.DB(t)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
|
||||
@@ -5,6 +5,7 @@ import (
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"fmt"
|
||||
"log"
|
||||
"sort"
|
||||
"strconv"
|
||||
"strings"
|
||||
@@ -21,7 +22,7 @@ import (
|
||||
"github.com/SecurityBrewery/catalyst/index"
|
||||
)
|
||||
|
||||
func toTicket(ticketForm *model.TicketForm) (interface{}, error) {
|
||||
func toTicket(ticketForm *model.TicketForm) (any, error) {
|
||||
playbooks, err := toPlaybooks(ticketForm.Playbooks)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
@@ -65,8 +66,9 @@ func toTicket(ticketForm *model.TicketForm) (interface{}, error) {
|
||||
ticket.Status = "open"
|
||||
}
|
||||
if ticketForm.ID != nil {
|
||||
return &busdb.Keyed{Key: strconv.FormatInt(*ticketForm.ID, 10), Doc: ticket}, nil
|
||||
return &busdb.Keyed[model.Ticket]{Key: strconv.FormatInt(*ticketForm.ID, 10), Doc: ticket}, nil
|
||||
}
|
||||
|
||||
return ticket, nil
|
||||
}
|
||||
|
||||
@@ -79,6 +81,7 @@ func toTicketResponses(tickets []*model.TicketSimpleResponse) ([]*model.TicketRe
|
||||
}
|
||||
extendedTickets = append(extendedTickets, tr)
|
||||
}
|
||||
|
||||
return extendedTickets, nil
|
||||
}
|
||||
|
||||
@@ -167,6 +170,7 @@ func toPlaybookResponses(playbooks map[string]*model.Playbook) (map[string]*mode
|
||||
return nil, err
|
||||
}
|
||||
}
|
||||
|
||||
return pr, nil
|
||||
}
|
||||
|
||||
@@ -195,6 +199,7 @@ func toPlaybookResponse(playbook *model.Playbook) (*model.PlaybookResponse, erro
|
||||
re.Tasks[taskID] = rootTask
|
||||
i++
|
||||
}
|
||||
|
||||
return re, nil
|
||||
}
|
||||
|
||||
@@ -204,7 +209,7 @@ func (db *Database) TicketBatchCreate(ctx context.Context, ticketForms []*model.
|
||||
return nil, err
|
||||
}
|
||||
|
||||
var dbTickets []interface{}
|
||||
var dbTickets []any
|
||||
for _, ticketForm := range ticketForms {
|
||||
ticket, err := toTicket(ticketForm)
|
||||
if err != nil {
|
||||
@@ -231,7 +236,7 @@ func (db *Database) TicketBatchCreate(ctx context.Context, ticketForms []*model.
|
||||
LET noiddoc = UNSET(keyeddoc, "id")
|
||||
INSERT noiddoc INTO @@collection
|
||||
RETURN NEW`
|
||||
apiTickets, _, err := db.ticketListQuery(ctx, query, mergeMaps(map[string]interface{}{
|
||||
apiTickets, _, err := db.ticketListQuery(ctx, query, mergeMaps(map[string]any{
|
||||
"tickets": dbTickets,
|
||||
}, ticketFilterVars), busdb.CreateOperation)
|
||||
if err != nil {
|
||||
@@ -247,7 +252,11 @@ func (db *Database) TicketBatchCreate(ctx context.Context, ticketForms []*model.
|
||||
ids = append(ids, driver.NewDocumentID(TicketCollectionName, fmt.Sprint(apiTicket.ID)))
|
||||
}
|
||||
|
||||
go db.bus.PublishDatabaseUpdate(ids, bus.DatabaseEntryUpdated)
|
||||
go func() {
|
||||
if err := db.bus.PublishDatabaseUpdate(ids, bus.DatabaseEntryUpdated); err != nil {
|
||||
log.Println(err)
|
||||
}
|
||||
}()
|
||||
|
||||
ticketResponses, err := toTicketResponses(apiTickets)
|
||||
if err != nil {
|
||||
@@ -294,6 +303,7 @@ func batchIndex(index *index.Index, tickets []*model.TicketSimpleResponse) error
|
||||
}
|
||||
}
|
||||
wg.Wait()
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
@@ -306,9 +316,9 @@ func (db *Database) TicketGet(ctx context.Context, ticketID int64) (*model.Ticke
|
||||
return db.ticketGetQuery(ctx, ticketID, `LET d = DOCUMENT(@@collection, @ID) `+ticketFilterQuery+` RETURN d`, ticketFilterVars, busdb.ReadOperation)
|
||||
}
|
||||
|
||||
func (db *Database) ticketGetQuery(ctx context.Context, ticketID int64, query string, bindVars map[string]interface{}, operation *busdb.Operation) (*model.TicketWithTickets, error) {
|
||||
func (db *Database) ticketGetQuery(ctx context.Context, ticketID int64, query string, bindVars map[string]any, operation *busdb.Operation) (*model.TicketWithTickets, error) {
|
||||
if bindVars == nil {
|
||||
bindVars = map[string]interface{}{}
|
||||
bindVars = map[string]any{}
|
||||
}
|
||||
bindVars["@collection"] = TicketCollectionName
|
||||
if ticketID != 0 {
|
||||
@@ -350,7 +360,7 @@ func (db *Database) ticketGetQuery(ctx context.Context, ticketID int64, query st
|
||||
` + ticketFilterQuery + `
|
||||
RETURN d`
|
||||
|
||||
outTickets, _, err := db.ticketListQuery(ctx, ticketsQuery, mergeMaps(map[string]interface{}{
|
||||
outTickets, _, err := db.ticketListQuery(ctx, ticketsQuery, mergeMaps(map[string]any{
|
||||
"ID": fmt.Sprint(ticketID),
|
||||
"graph": TicketArtifactsGraphName,
|
||||
"@tickets": TicketCollectionName,
|
||||
@@ -368,7 +378,7 @@ func (db *Database) ticketGetQuery(ctx context.Context, ticketID int64, query st
|
||||
` + ticketFilterQuery + `
|
||||
RETURN d`
|
||||
|
||||
inTickets, _, err := db.ticketListQuery(ctx, ticketsQuery, mergeMaps(map[string]interface{}{
|
||||
inTickets, _, err := db.ticketListQuery(ctx, ticketsQuery, mergeMaps(map[string]any{
|
||||
"ID": fmt.Sprint(ticketID),
|
||||
"graph": TicketArtifactsGraphName,
|
||||
"@tickets": TicketCollectionName,
|
||||
@@ -387,7 +397,7 @@ func (db *Database) ticketGetQuery(ctx context.Context, ticketID int64, query st
|
||||
FOR a IN NOT_NULL(d.artifacts, [])
|
||||
FILTER POSITION(@artifacts, a.name)
|
||||
RETURN d`
|
||||
sameArtifactTickets, _, err := db.ticketListQuery(ctx, ticketsQuery, mergeMaps(map[string]interface{}{
|
||||
sameArtifactTickets, _, err := db.ticketListQuery(ctx, ticketsQuery, mergeMaps(map[string]any{
|
||||
"ID": fmt.Sprint(ticketID),
|
||||
"artifacts": artifactNames,
|
||||
}, ticketFilterVars), busdb.ReadOperation)
|
||||
@@ -395,7 +405,8 @@ func (db *Database) ticketGetQuery(ctx context.Context, ticketID int64, query st
|
||||
return nil, err
|
||||
}
|
||||
|
||||
tickets := append(outTickets, inTickets...)
|
||||
tickets := outTickets
|
||||
tickets = append(tickets, inTickets...)
|
||||
tickets = append(tickets, sameArtifactTickets...)
|
||||
sort.Slice(tickets, func(i, j int) bool {
|
||||
return tickets[i].ID < tickets[j].ID
|
||||
@@ -425,7 +436,8 @@ func (db *Database) TicketUpdate(ctx context.Context, ticketID int64, ticket *mo
|
||||
REPLACE d WITH @ticket IN @@collection
|
||||
RETURN NEW`
|
||||
ticket.Modified = time.Now().UTC() // TODO make setable?
|
||||
return db.ticketGetQuery(ctx, ticketID, query, mergeMaps(map[string]interface{}{"ticket": ticket}, ticketFilterVars), &busdb.Operation{
|
||||
|
||||
return db.ticketGetQuery(ctx, ticketID, query, mergeMaps(map[string]any{"ticket": ticket}, ticketFilterVars), &busdb.Operation{
|
||||
Type: bus.DatabaseEntryUpdated, Ids: []driver.DocumentID{
|
||||
driver.NewDocumentID(TicketCollectionName, strconv.FormatInt(ticketID, 10)),
|
||||
},
|
||||
@@ -447,15 +459,15 @@ func (db *Database) TicketDelete(ctx context.Context, ticketID int64) error {
|
||||
}
|
||||
|
||||
func (db *Database) TicketList(ctx context.Context, ticketType string, query string, sorts []string, desc []bool, offset, count int64) (*model.TicketList, error) {
|
||||
binVars := map[string]interface{}{}
|
||||
binVars := map[string]any{}
|
||||
|
||||
var typeString = ""
|
||||
typeString := ""
|
||||
if ticketType != "" {
|
||||
typeString = "FILTER d.type == @type "
|
||||
binVars["type"] = ticketType
|
||||
}
|
||||
|
||||
var filterString = ""
|
||||
filterString := ""
|
||||
if query != "" {
|
||||
parser := &caql.Parser{Searcher: db.Index, Prefix: "d."}
|
||||
queryTree, err := parser.Parse(query)
|
||||
@@ -493,6 +505,7 @@ func (db *Database) TicketList(ctx context.Context, ticketType string, query str
|
||||
RETURN d`
|
||||
// RETURN KEEP(d, "_key", "id", "name", "type", "created")`
|
||||
ticketList, _, err := db.ticketListQuery(ctx, q, mergeMaps(binVars, ticketFilterVars), busdb.ReadOperation)
|
||||
|
||||
return &model.TicketList{
|
||||
Count: documentCount,
|
||||
Tickets: ticketList,
|
||||
@@ -500,9 +513,9 @@ func (db *Database) TicketList(ctx context.Context, ticketType string, query str
|
||||
// return map[string]interface{}{"tickets": ticketList, "count": documentCount}, err
|
||||
}
|
||||
|
||||
func (db *Database) ticketListQuery(ctx context.Context, query string, bindVars map[string]interface{}, operation *busdb.Operation) ([]*model.TicketSimpleResponse, *model.LogEntry, error) {
|
||||
func (db *Database) ticketListQuery(ctx context.Context, query string, bindVars map[string]any, operation *busdb.Operation) ([]*model.TicketSimpleResponse, *model.LogEntry, error) {
|
||||
if bindVars == nil {
|
||||
bindVars = map[string]interface{}{}
|
||||
bindVars = map[string]any{}
|
||||
}
|
||||
bindVars["@collection"] = TicketCollectionName
|
||||
|
||||
@@ -533,9 +546,9 @@ func (db *Database) ticketListQuery(ctx context.Context, query string, bindVars
|
||||
return docs, logEntry, nil
|
||||
}
|
||||
|
||||
func (db *Database) TicketCount(ctx context.Context, typequery, filterquery string, bindVars map[string]interface{}) (int, error) {
|
||||
func (db *Database) TicketCount(ctx context.Context, typequery, filterquery string, bindVars map[string]any) (int, error) {
|
||||
if bindVars == nil {
|
||||
bindVars = map[string]interface{}{}
|
||||
bindVars = map[string]any{}
|
||||
}
|
||||
bindVars["@collection"] = TicketCollectionName
|
||||
|
||||
@@ -555,10 +568,11 @@ func (db *Database) TicketCount(ctx context.Context, typequery, filterquery stri
|
||||
return 0, err
|
||||
}
|
||||
cursor.Close()
|
||||
|
||||
return documentCount, nil
|
||||
}
|
||||
|
||||
func sortQuery(paramsSort []string, paramsDesc []bool, bindVars map[string]interface{}) string {
|
||||
func sortQuery(paramsSort []string, paramsDesc []bool, bindVars map[string]any) string {
|
||||
sort := ""
|
||||
if len(paramsSort) > 0 {
|
||||
var sorts []string
|
||||
@@ -572,21 +586,23 @@ func sortQuery(paramsSort []string, paramsDesc []bool, bindVars map[string]inter
|
||||
}
|
||||
sort = "SORT " + strings.Join(sorts, ", ")
|
||||
}
|
||||
|
||||
return sort
|
||||
}
|
||||
|
||||
func mergeMaps(a map[string]interface{}, b map[string]interface{}) map[string]interface{} {
|
||||
merged := map[string]interface{}{}
|
||||
func mergeMaps(a map[string]any, b map[string]any) map[string]any {
|
||||
merged := map[string]any{}
|
||||
for k, v := range a {
|
||||
merged[k] = v
|
||||
}
|
||||
for k, v := range b {
|
||||
merged[k] = v
|
||||
}
|
||||
|
||||
return merged
|
||||
}
|
||||
|
||||
func validate(e interface{}, schema *gojsonschema.Schema) error {
|
||||
func validate(e any, schema *gojsonschema.Schema) error {
|
||||
b, err := json.Marshal(e)
|
||||
if err != nil {
|
||||
return err
|
||||
@@ -602,7 +618,9 @@ func validate(e interface{}, schema *gojsonschema.Schema) error {
|
||||
for _, e := range res.Errors() {
|
||||
l = append(l, e.String())
|
||||
}
|
||||
|
||||
return fmt.Errorf("validation failed: %v", strings.Join(l, ", "))
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
@@ -34,7 +34,8 @@ func (db *Database) AddArtifact(ctx context.Context, id int64, artifact *model.A
|
||||
` + ticketFilterQuery + `
|
||||
UPDATE d WITH { "modified": @now, "artifacts": PUSH(NOT_NULL(d.artifacts, []), @artifact) } IN @@collection
|
||||
RETURN NEW`
|
||||
return db.ticketGetQuery(ctx, id, query, mergeMaps(map[string]interface{}{"artifact": artifact, "now": time.Now().UTC()}, ticketFilterVars), &busdb.Operation{
|
||||
|
||||
return db.ticketGetQuery(ctx, id, query, mergeMaps(map[string]any{"artifact": artifact, "now": time.Now().UTC()}, ticketFilterVars), &busdb.Operation{
|
||||
Type: bus.DatabaseEntryUpdated,
|
||||
Ids: []driver.DocumentID{
|
||||
driver.DocumentID(fmt.Sprintf("%s/%d", TicketCollectionName, id)),
|
||||
@@ -57,6 +58,7 @@ func inferType(name string) string {
|
||||
case commonregex.SHA256HexRegex.MatchString(name):
|
||||
return "sha256"
|
||||
}
|
||||
|
||||
return "unknown"
|
||||
}
|
||||
|
||||
@@ -73,7 +75,8 @@ func (db *Database) RemoveArtifact(ctx context.Context, id int64, name string) (
|
||||
LET newartifacts = REMOVE_VALUE(d.artifacts, a)
|
||||
UPDATE d WITH { "modified": @now, "artifacts": newartifacts } IN @@collection
|
||||
RETURN NEW`
|
||||
return db.ticketGetQuery(ctx, id, query, mergeMaps(map[string]interface{}{"name": name, "now": time.Now().UTC()}, ticketFilterVars), &busdb.Operation{
|
||||
|
||||
return db.ticketGetQuery(ctx, id, query, mergeMaps(map[string]any{"name": name, "now": time.Now().UTC()}, ticketFilterVars), &busdb.Operation{
|
||||
Type: bus.DatabaseEntryUpdated,
|
||||
Ids: []driver.DocumentID{
|
||||
driver.DocumentID(fmt.Sprintf("%s/%d", TicketCollectionName, id)),
|
||||
@@ -91,7 +94,8 @@ func (db *Database) SetTemplate(ctx context.Context, id int64, schema string) (*
|
||||
` + ticketFilterQuery + `
|
||||
UPDATE d WITH { "schema": @schema } IN @@collection
|
||||
RETURN NEW`
|
||||
return db.ticketGetQuery(ctx, id, query, mergeMaps(map[string]interface{}{"schema": schema}, ticketFilterVars), &busdb.Operation{
|
||||
|
||||
return db.ticketGetQuery(ctx, id, query, mergeMaps(map[string]any{"schema": schema}, ticketFilterVars), &busdb.Operation{
|
||||
Type: bus.DatabaseEntryUpdated,
|
||||
Ids: []driver.DocumentID{
|
||||
driver.DocumentID(fmt.Sprintf("%s/%d", TicketCollectionName, id)),
|
||||
@@ -122,7 +126,8 @@ func (db *Database) AddComment(ctx context.Context, id int64, comment *model.Com
|
||||
` + ticketFilterQuery + `
|
||||
UPDATE d WITH { "modified": @now, "comments": PUSH(NOT_NULL(d.comments, []), @comment) } IN @@collection
|
||||
RETURN NEW`
|
||||
return db.ticketGetQuery(ctx, id, query, mergeMaps(map[string]interface{}{"comment": comment, "now": time.Now().UTC()}, ticketFilterVars), &busdb.Operation{
|
||||
|
||||
return db.ticketGetQuery(ctx, id, query, mergeMaps(map[string]any{"comment": comment, "now": time.Now().UTC()}, ticketFilterVars), &busdb.Operation{
|
||||
Type: bus.DatabaseEntryUpdated,
|
||||
Ids: []driver.DocumentID{
|
||||
driver.DocumentID(fmt.Sprintf("%s/%d", TicketCollectionName, id)),
|
||||
@@ -140,7 +145,8 @@ func (db *Database) RemoveComment(ctx context.Context, id int64, commentID int64
|
||||
` + ticketFilterQuery + `
|
||||
UPDATE d WITH { "modified": @now, "comments": REMOVE_NTH(d.comments, @commentID) } IN @@collection
|
||||
RETURN NEW`
|
||||
return db.ticketGetQuery(ctx, id, query, mergeMaps(map[string]interface{}{"commentID": commentID, "now": time.Now().UTC()}, ticketFilterVars), &busdb.Operation{
|
||||
|
||||
return db.ticketGetQuery(ctx, id, query, mergeMaps(map[string]any{"commentID": commentID, "now": time.Now().UTC()}, ticketFilterVars), &busdb.Operation{
|
||||
Type: bus.DatabaseEntryUpdated,
|
||||
Ids: []driver.DocumentID{
|
||||
driver.DocumentID(fmt.Sprintf("%s/%d", TicketCollectionName, id)),
|
||||
@@ -158,7 +164,8 @@ func (db *Database) SetReferences(ctx context.Context, id int64, references []*m
|
||||
` + ticketFilterQuery + `
|
||||
UPDATE d WITH { "modified": @now, "references": @references } IN @@collection
|
||||
RETURN NEW`
|
||||
return db.ticketGetQuery(ctx, id, query, mergeMaps(map[string]interface{}{"references": references, "now": time.Now().UTC()}, ticketFilterVars), &busdb.Operation{
|
||||
|
||||
return db.ticketGetQuery(ctx, id, query, mergeMaps(map[string]any{"references": references, "now": time.Now().UTC()}, ticketFilterVars), &busdb.Operation{
|
||||
Type: bus.DatabaseEntryUpdated,
|
||||
Ids: []driver.DocumentID{
|
||||
driver.DocumentID(fmt.Sprintf("%s/%d", TicketCollectionName, id)),
|
||||
@@ -176,7 +183,8 @@ func (db *Database) AddFile(ctx context.Context, id int64, file *model.File) (*m
|
||||
` + ticketFilterQuery + `
|
||||
UPDATE d WITH { "modified": @now, "files": APPEND(NOT_NULL(d.files, []), [@file]) } IN @@collection
|
||||
RETURN NEW`
|
||||
return db.ticketGetQuery(ctx, id, query, mergeMaps(map[string]interface{}{"file": file, "now": time.Now().UTC()}, ticketFilterVars), &busdb.Operation{
|
||||
|
||||
return db.ticketGetQuery(ctx, id, query, mergeMaps(map[string]any{"file": file, "now": time.Now().UTC()}, ticketFilterVars), &busdb.Operation{
|
||||
Type: bus.DatabaseEntryUpdated,
|
||||
Ids: []driver.DocumentID{
|
||||
driver.DocumentID(fmt.Sprintf("%s/%d", TicketCollectionName, id)),
|
||||
@@ -213,7 +221,7 @@ func (db *Database) AddTicketPlaybook(ctx context.Context, id int64, playbookTem
|
||||
LET newticket = MERGE(d, { "modified": @now, "playbooks": newplaybooks })
|
||||
REPLACE d WITH newticket IN @@collection
|
||||
RETURN NEW`
|
||||
ticket, err := db.ticketGetQuery(ctx, id, query, mergeMaps(map[string]interface{}{
|
||||
ticket, err := db.ticketGetQuery(ctx, id, query, mergeMaps(map[string]any{
|
||||
"playbook": pb,
|
||||
"playbookID": findName(parentTicket.Playbooks, playbookID),
|
||||
"now": time.Now().UTC(),
|
||||
@@ -258,6 +266,7 @@ func runRootTask(ticket *model.TicketResponse, playbookID string, db *Database)
|
||||
}
|
||||
|
||||
runNextTasks(ticket.ID, playbookID, root.Next, root.Data, ticket, db)
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
@@ -273,7 +282,8 @@ func (db *Database) RemoveTicketPlaybook(ctx context.Context, id int64, playbook
|
||||
LET newplaybooks = UNSET(d.playbooks, @playbookID)
|
||||
REPLACE d WITH MERGE(d, { "modified": @now, "playbooks": newplaybooks }) IN @@collection
|
||||
RETURN NEW`
|
||||
return db.ticketGetQuery(ctx, id, query, mergeMaps(map[string]interface{}{
|
||||
|
||||
return db.ticketGetQuery(ctx, id, query, mergeMaps(map[string]any{
|
||||
"playbookID": playbookID,
|
||||
"now": time.Now().UTC(),
|
||||
}, ticketFilterVars), &busdb.Operation{
|
||||
|
||||
@@ -41,7 +41,7 @@ func (db *Database) TaskGet(ctx context.Context, id int64, playbookID string, ta
|
||||
}, nil
|
||||
}
|
||||
|
||||
func (db *Database) TaskComplete(ctx context.Context, id int64, playbookID string, taskID string, data interface{}) (*model.TicketWithTickets, error) {
|
||||
func (db *Database) TaskComplete(ctx context.Context, id int64, playbookID string, taskID string, data any) (*model.TicketWithTickets, error) {
|
||||
inc, err := db.TicketGet(ctx, id)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
@@ -68,7 +68,7 @@ func (db *Database) TaskComplete(ctx context.Context, id int64, playbookID strin
|
||||
|
||||
UPDATE d WITH { "modified": @now, "playbooks": newplaybooks } IN @@collection
|
||||
RETURN NEW`
|
||||
ticket, err := db.ticketGetQuery(ctx, id, query, mergeMaps(map[string]interface{}{
|
||||
ticket, err := db.ticketGetQuery(ctx, id, query, mergeMaps(map[string]any{
|
||||
"playbookID": playbookID,
|
||||
"taskID": taskID,
|
||||
"data": data,
|
||||
@@ -130,7 +130,7 @@ func (db *Database) TaskUpdateOwner(ctx context.Context, id int64, playbookID st
|
||||
|
||||
UPDATE d WITH { "modified": @now, "playbooks": newplaybooks } IN @@collection
|
||||
RETURN NEW`
|
||||
ticket, err := db.ticketGetQuery(ctx, id, query, mergeMaps(map[string]interface{}{
|
||||
ticket, err := db.ticketGetQuery(ctx, id, query, mergeMaps(map[string]any{
|
||||
"playbookID": playbookID,
|
||||
"taskID": taskID,
|
||||
"owner": owner,
|
||||
@@ -148,7 +148,7 @@ func (db *Database) TaskUpdateOwner(ctx context.Context, id int64, playbookID st
|
||||
return ticket, nil
|
||||
}
|
||||
|
||||
func (db *Database) TaskUpdateData(ctx context.Context, id int64, playbookID string, taskID string, data map[string]interface{}) (*model.TicketWithTickets, error) {
|
||||
func (db *Database) TaskUpdateData(ctx context.Context, id int64, playbookID string, taskID string, data map[string]any) (*model.TicketWithTickets, error) {
|
||||
ticketFilterQuery, ticketFilterVars, err := db.Hooks.TicketWriteFilter(ctx)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
@@ -165,7 +165,7 @@ func (db *Database) TaskUpdateData(ctx context.Context, id int64, playbookID str
|
||||
|
||||
UPDATE d WITH { "modified": @now, "playbooks": newplaybooks } IN @@collection
|
||||
RETURN NEW`
|
||||
ticket, err := db.ticketGetQuery(ctx, id, query, mergeMaps(map[string]interface{}{
|
||||
ticket, err := db.ticketGetQuery(ctx, id, query, mergeMaps(map[string]any{
|
||||
"playbookID": playbookID,
|
||||
"taskID": taskID,
|
||||
"data": data,
|
||||
@@ -198,7 +198,7 @@ func (db *Database) TaskRun(ctx context.Context, id int64, playbookID string, ta
|
||||
return nil
|
||||
}
|
||||
|
||||
func runNextTasks(id int64, playbookID string, next map[string]string, data interface{}, ticket *model.TicketResponse, db *Database) {
|
||||
func runNextTasks(id int64, playbookID string, next map[string]string, data any, ticket *model.TicketResponse, db *Database) {
|
||||
for nextTaskID, requirement := range next {
|
||||
nextTask := ticket.Playbooks[playbookID].Tasks[nextTaskID]
|
||||
if nextTask.Type == model.TaskTypeAutomation {
|
||||
@@ -220,5 +220,6 @@ func runTask(ticketID int64, playbookID string, taskID string, task *model.TaskR
|
||||
msgContext := &model.Context{Playbook: playbook, Task: task, Ticket: ticket}
|
||||
origin := &model.Origin{TaskOrigin: &model.TaskOrigin{TaskId: taskID, PlaybookId: playbookID, TicketId: ticketID}}
|
||||
jobID := uuid.NewString()
|
||||
|
||||
return publishJobMapping(jobID, *task.Automation, msgContext, origin, task.Payload, db)
|
||||
}
|
||||
|
||||
@@ -75,12 +75,13 @@ func (db *Database) TicketTypeUpdate(ctx context.Context, id string, tickettype
|
||||
|
||||
func (db *Database) TicketTypeDelete(ctx context.Context, id string) error {
|
||||
_, err := db.tickettypeCollection.RemoveDocument(ctx, id)
|
||||
|
||||
return err
|
||||
}
|
||||
|
||||
func (db *Database) TicketTypeList(ctx context.Context) ([]*model.TicketTypeResponse, error) {
|
||||
query := "FOR d IN @@collection RETURN d"
|
||||
cursor, _, err := db.Query(ctx, query, map[string]interface{}{"@collection": TicketTypeCollectionName}, busdb.ReadOperation)
|
||||
cursor, _, err := db.Query(ctx, query, map[string]any{"@collection": TicketTypeCollectionName}, busdb.ReadOperation)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
@@ -28,6 +28,7 @@ func generateKey() string {
|
||||
for i := range b {
|
||||
b[i] = letters[rand.Intn(len(letters))]
|
||||
}
|
||||
|
||||
return string(b)
|
||||
}
|
||||
|
||||
@@ -78,8 +79,10 @@ func (db *Database) UserGetOrCreate(ctx context.Context, newUser *model.UserForm
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return &model.UserResponse{ID: newUser.ID, Roles: newUser.Roles, Blocked: newUser.Blocked}, nil
|
||||
}
|
||||
|
||||
return user, nil
|
||||
}
|
||||
|
||||
@@ -132,12 +135,13 @@ func (db *Database) UserGet(ctx context.Context, id string) (*model.UserResponse
|
||||
|
||||
func (db *Database) UserDelete(ctx context.Context, id string) error {
|
||||
_, err := db.userCollection.RemoveDocument(ctx, id)
|
||||
|
||||
return err
|
||||
}
|
||||
|
||||
func (db *Database) UserList(ctx context.Context) ([]*model.UserResponse, error) {
|
||||
query := "FOR d IN @@collection RETURN d"
|
||||
cursor, _, err := db.Query(ctx, query, map[string]interface{}{"@collection": UserCollectionName}, busdb.ReadOperation)
|
||||
cursor, _, err := db.Query(ctx, query, map[string]any{"@collection": UserCollectionName}, busdb.ReadOperation)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
@@ -163,7 +167,7 @@ func (db *Database) UserByHash(ctx context.Context, sha256 string) (*model.UserR
|
||||
FILTER d.sha256 == @sha256
|
||||
RETURN d`
|
||||
|
||||
cursor, _, err := db.Query(ctx, query, map[string]interface{}{"@collection": UserCollectionName, "sha256": sha256}, busdb.ReadOperation)
|
||||
cursor, _, err := db.Query(ctx, query, map[string]any{"@collection": UserCollectionName, "sha256": sha256}, busdb.ReadOperation)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
@@ -29,6 +29,7 @@ func (db *Database) UserDataCreate(ctx context.Context, id string, userdata *mod
|
||||
}
|
||||
|
||||
_, err := db.userdataCollection.CreateDocument(ctx, ctx, id, userdata)
|
||||
|
||||
return err
|
||||
}
|
||||
|
||||
@@ -37,6 +38,7 @@ func (db *Database) UserDataGetOrCreate(ctx context.Context, id string, newUserD
|
||||
if err != nil {
|
||||
return toUserDataResponse(id, newUserData), db.UserDataCreate(ctx, id, newUserData)
|
||||
}
|
||||
|
||||
return setting, nil
|
||||
}
|
||||
|
||||
@@ -52,7 +54,7 @@ func (db *Database) UserDataGet(ctx context.Context, id string) (*model.UserData
|
||||
|
||||
func (db *Database) UserDataList(ctx context.Context) ([]*model.UserDataResponse, error) {
|
||||
query := "FOR d IN @@collection SORT d.username ASC RETURN d"
|
||||
cursor, _, err := db.Query(ctx, query, map[string]interface{}{"@collection": UserDataCollectionName}, busdb.ReadOperation)
|
||||
cursor, _, err := db.Query(ctx, query, map[string]any{"@collection": UserDataCollectionName}, busdb.ReadOperation)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
@@ -22,6 +22,8 @@ var bobResponse = &model.UserDataResponse{
|
||||
}
|
||||
|
||||
func TestDatabase_UserDataCreate(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
type args struct {
|
||||
id string
|
||||
setting *model.UserData
|
||||
@@ -37,7 +39,10 @@ func TestDatabase_UserDataCreate(t *testing.T) {
|
||||
{name: "Only settingname", args: args{id: "bob"}, wantErr: true},
|
||||
}
|
||||
for _, tt := range tests {
|
||||
tt := tt
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
_, _, _, _, _, db, cleanup, err := test.DB(t)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
@@ -52,6 +57,8 @@ func TestDatabase_UserDataCreate(t *testing.T) {
|
||||
}
|
||||
|
||||
func TestDatabase_UserDataGet(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
type args struct {
|
||||
id string
|
||||
}
|
||||
@@ -65,7 +72,10 @@ func TestDatabase_UserDataGet(t *testing.T) {
|
||||
{name: "Not existing", args: args{id: "foo"}, wantErr: true},
|
||||
}
|
||||
for _, tt := range tests {
|
||||
tt := tt
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
_, _, _, _, _, db, cleanup, err := test.DB(t)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
@@ -79,6 +89,7 @@ func TestDatabase_UserDataGet(t *testing.T) {
|
||||
got, err := db.UserDataGet(test.Context(), tt.args.id)
|
||||
if (err != nil) != tt.wantErr {
|
||||
t.Errorf("UserDataGet() error = %v, wantErr %v", err, tt.wantErr)
|
||||
|
||||
return
|
||||
}
|
||||
if err != nil {
|
||||
@@ -91,6 +102,8 @@ func TestDatabase_UserDataGet(t *testing.T) {
|
||||
}
|
||||
|
||||
func TestDatabase_UserDataList(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
want []*model.UserDataResponse
|
||||
@@ -99,7 +112,10 @@ func TestDatabase_UserDataList(t *testing.T) {
|
||||
{name: "Normal list", want: []*model.UserDataResponse{bobResponse}},
|
||||
}
|
||||
for _, tt := range tests {
|
||||
tt := tt
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
_, _, _, _, _, db, cleanup, err := test.DB(t)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
@@ -113,6 +129,7 @@ func TestDatabase_UserDataList(t *testing.T) {
|
||||
got, err := db.UserDataList(test.Context())
|
||||
if (err != nil) != tt.wantErr {
|
||||
t.Errorf("UserDataList() error = %v, wantErr %v", err, tt.wantErr)
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
@@ -122,6 +139,8 @@ func TestDatabase_UserDataList(t *testing.T) {
|
||||
}
|
||||
|
||||
func TestDatabase_UserDataUpdate(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
type args struct {
|
||||
id string
|
||||
setting *model.UserData
|
||||
@@ -135,7 +154,10 @@ func TestDatabase_UserDataUpdate(t *testing.T) {
|
||||
{name: "Not existing", args: args{id: "foo"}, wantErr: true},
|
||||
}
|
||||
for _, tt := range tests {
|
||||
tt := tt
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
_, _, _, _, _, db, cleanup, err := test.DB(t)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
|
||||
13
file.go
13
file.go
@@ -30,11 +30,13 @@ func tusdUpload(db *database.Database, bus *bus.Bus, client *s3.S3, external str
|
||||
ticketID := chi.URLParam(r, "ticketID")
|
||||
if ticketID == "" {
|
||||
api.JSONErrorStatus(w, http.StatusBadRequest, errors.New("ticketID not given"))
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
if err := storage.CreateBucket(client, ticketID); err != nil {
|
||||
api.JSONErrorStatus(w, http.StatusBadRequest, fmt.Errorf("could not create bucket: %w", err))
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
@@ -50,6 +52,7 @@ func tusdUpload(db *database.Database, bus *bus.Bus, client *s3.S3, external str
|
||||
})
|
||||
if err != nil {
|
||||
api.JSONErrorStatus(w, http.StatusBadRequest, fmt.Errorf("could not create tusd handler: %w", err))
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
@@ -73,6 +76,7 @@ func tusdUpload(db *database.Database, bus *bus.Bus, client *s3.S3, external str
|
||||
doc, err := db.AddFile(ctx, id, file)
|
||||
if err != nil {
|
||||
log.Println(err)
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
@@ -92,7 +96,6 @@ func tusdUpload(db *database.Database, bus *bus.Bus, client *s3.S3, external str
|
||||
default:
|
||||
api.JSONErrorStatus(w, http.StatusInternalServerError, errors.New("unknown method"))
|
||||
}
|
||||
|
||||
}
|
||||
}
|
||||
|
||||
@@ -101,18 +104,21 @@ func upload(db *database.Database, client *s3.S3, uploader *s3manager.Uploader)
|
||||
ticketID := chi.URLParam(r, "ticketID")
|
||||
if ticketID == "" {
|
||||
api.JSONErrorStatus(w, http.StatusBadRequest, errors.New("ticketID not given"))
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
file, header, err := r.FormFile("file")
|
||||
if err != nil {
|
||||
api.JSONErrorStatus(w, http.StatusBadRequest, err)
|
||||
|
||||
return
|
||||
}
|
||||
defer file.Close()
|
||||
|
||||
if err := storage.CreateBucket(client, ticketID); err != nil {
|
||||
api.JSONErrorStatus(w, http.StatusBadRequest, fmt.Errorf("could not create bucket: %w", err))
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
@@ -123,12 +129,14 @@ func upload(db *database.Database, client *s3.S3, uploader *s3manager.Uploader)
|
||||
})
|
||||
if err != nil {
|
||||
api.JSONErrorStatus(w, http.StatusBadRequest, err)
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
id, err := strconv.ParseInt(ticketID, 10, 64)
|
||||
if err != nil {
|
||||
api.JSONErrorStatus(w, http.StatusBadRequest, err)
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
@@ -138,6 +146,7 @@ func upload(db *database.Database, client *s3.S3, uploader *s3manager.Uploader)
|
||||
})
|
||||
if err != nil {
|
||||
api.JSONErrorStatus(w, http.StatusBadRequest, err)
|
||||
|
||||
return
|
||||
}
|
||||
}
|
||||
@@ -148,12 +157,14 @@ func download(downloader *s3manager.Downloader) http.HandlerFunc {
|
||||
ticketID := chi.URLParam(r, "ticketID")
|
||||
if ticketID == "" {
|
||||
api.JSONErrorStatus(w, http.StatusBadRequest, errors.New("ticketID not given"))
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
key := chi.URLParam(r, "key")
|
||||
if key == "" {
|
||||
api.JSONErrorStatus(w, http.StatusBadRequest, errors.New("key not given"))
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
|
||||
@@ -118,7 +118,7 @@ func parseQueryOptionalBoolArray(r *http.Request, key string) ([]bool, error) {
|
||||
return parseQueryBoolArray(r, key)
|
||||
}
|
||||
|
||||
func parseBody(b []byte, i interface{}) error {
|
||||
func parseBody(b []byte, i any) error {
|
||||
dec := json.NewDecoder(bytes.NewBuffer(b))
|
||||
err := dec.Decode(i)
|
||||
if err != nil {
|
||||
@@ -137,7 +137,7 @@ func JSONErrorStatus(w http.ResponseWriter, status int, err error) {
|
||||
w.Write(b)
|
||||
}
|
||||
|
||||
func response(w http.ResponseWriter, v interface{}, err error) {
|
||||
func response(w http.ResponseWriter, v any, err error) {
|
||||
if err != nil {
|
||||
var httpError *HTTPError
|
||||
if errors.As(err, &httpError) {
|
||||
@@ -172,7 +172,7 @@ func validateSchema(body []byte, schema *gojsonschema.Schema, w http.ResponseWri
|
||||
validationErrors = append(validationErrors, valdiationError.String())
|
||||
}
|
||||
|
||||
b, _ := json.Marshal(map[string]interface{}{"error": "wrong input", "errors": validationErrors})
|
||||
b, _ := json.Marshal(map[string]any{"error": "wrong input", "errors": validationErrors})
|
||||
w.Write(b)
|
||||
return true
|
||||
}
|
||||
|
||||
@@ -19,7 +19,7 @@ type Service interface {
|
||||
CurrentUser(context.Context) (*model.UserResponse, error)
|
||||
CurrentUserData(context.Context) (*model.UserDataResponse, error)
|
||||
UpdateCurrentUserData(context.Context, *model.UserData) (*model.UserDataResponse, error)
|
||||
DashboardData(context.Context, string, *string) (map[string]interface{}, error)
|
||||
DashboardData(context.Context, string, *string) (map[string]any, error)
|
||||
ListDashboards(context.Context) ([]*model.DashboardResponse, error)
|
||||
CreateDashboard(context.Context, *model.Dashboard) (*model.DashboardResponse, error)
|
||||
GetDashboard(context.Context, string) (*model.DashboardResponse, error)
|
||||
@@ -60,8 +60,8 @@ type Service interface {
|
||||
RemoveComment(context.Context, int64, int) (*model.TicketWithTickets, error)
|
||||
AddTicketPlaybook(context.Context, int64, *model.PlaybookTemplateForm) (*model.TicketWithTickets, error)
|
||||
RemoveTicketPlaybook(context.Context, int64, string) (*model.TicketWithTickets, error)
|
||||
SetTaskData(context.Context, int64, string, string, map[string]interface{}) (*model.TicketWithTickets, error)
|
||||
CompleteTask(context.Context, int64, string, string, map[string]interface{}) (*model.TicketWithTickets, error)
|
||||
SetTaskData(context.Context, int64, string, string, map[string]any) (*model.TicketWithTickets, error)
|
||||
CompleteTask(context.Context, int64, string, string, map[string]any) (*model.TicketWithTickets, error)
|
||||
SetTaskOwner(context.Context, int64, string, string, string) (*model.TicketWithTickets, error)
|
||||
RunTask(context.Context, int64, string, string) error
|
||||
SetReferences(context.Context, int64, *model.ReferenceArray) (*model.TicketWithTickets, error)
|
||||
@@ -901,7 +901,7 @@ func (s *server) setTaskDataHandler(w http.ResponseWriter, r *http.Request) {
|
||||
return
|
||||
}
|
||||
|
||||
var dataP map[string]interface{}
|
||||
var dataP map[string]any
|
||||
if err := parseBody(body, &dataP); err != nil {
|
||||
JSONError(w, err)
|
||||
return
|
||||
@@ -928,7 +928,7 @@ func (s *server) completeTaskHandler(w http.ResponseWriter, r *http.Request) {
|
||||
return
|
||||
}
|
||||
|
||||
var dataP map[string]interface{}
|
||||
var dataP map[string]any
|
||||
if err := parseBody(body, &dataP); err != nil {
|
||||
JSONError(w, err)
|
||||
return
|
||||
|
||||
File diff suppressed because one or more lines are too long
@@ -251,14 +251,14 @@ type DashboardResponse struct {
|
||||
}
|
||||
|
||||
type Enrichment struct {
|
||||
Created time.Time `json:"created"`
|
||||
Data map[string]interface{} `json:"data"`
|
||||
Name string `json:"name"`
|
||||
Created time.Time `json:"created"`
|
||||
Data map[string]any `json:"data"`
|
||||
Name string `json:"name"`
|
||||
}
|
||||
|
||||
type EnrichmentForm struct {
|
||||
Data map[string]interface{} `json:"data"`
|
||||
Name string `json:"name"`
|
||||
Data map[string]any `json:"data"`
|
||||
Name string `json:"name"`
|
||||
}
|
||||
|
||||
type File struct {
|
||||
@@ -267,39 +267,39 @@ type File struct {
|
||||
}
|
||||
|
||||
type Job struct {
|
||||
Automation string `json:"automation"`
|
||||
Container *string `json:"container,omitempty"`
|
||||
Log *string `json:"log,omitempty"`
|
||||
Origin *Origin `json:"origin,omitempty"`
|
||||
Output map[string]interface{} `json:"output,omitempty"`
|
||||
Payload interface{} `json:"payload,omitempty"`
|
||||
Running bool `json:"running"`
|
||||
Status string `json:"status"`
|
||||
Automation string `json:"automation"`
|
||||
Container *string `json:"container,omitempty"`
|
||||
Log *string `json:"log,omitempty"`
|
||||
Origin *Origin `json:"origin,omitempty"`
|
||||
Output map[string]any `json:"output,omitempty"`
|
||||
Payload any `json:"payload,omitempty"`
|
||||
Running bool `json:"running"`
|
||||
Status string `json:"status"`
|
||||
}
|
||||
|
||||
type JobForm struct {
|
||||
Automation string `json:"automation"`
|
||||
Origin *Origin `json:"origin,omitempty"`
|
||||
Payload interface{} `json:"payload,omitempty"`
|
||||
Automation string `json:"automation"`
|
||||
Origin *Origin `json:"origin,omitempty"`
|
||||
Payload any `json:"payload,omitempty"`
|
||||
}
|
||||
|
||||
type JobResponse struct {
|
||||
Automation string `json:"automation"`
|
||||
Container *string `json:"container,omitempty"`
|
||||
ID string `json:"id"`
|
||||
Log *string `json:"log,omitempty"`
|
||||
Origin *Origin `json:"origin,omitempty"`
|
||||
Output map[string]interface{} `json:"output,omitempty"`
|
||||
Payload interface{} `json:"payload,omitempty"`
|
||||
Status string `json:"status"`
|
||||
Automation string `json:"automation"`
|
||||
Container *string `json:"container,omitempty"`
|
||||
ID string `json:"id"`
|
||||
Log *string `json:"log,omitempty"`
|
||||
Origin *Origin `json:"origin,omitempty"`
|
||||
Output map[string]any `json:"output,omitempty"`
|
||||
Payload any `json:"payload,omitempty"`
|
||||
Status string `json:"status"`
|
||||
}
|
||||
|
||||
type JobUpdate struct {
|
||||
Container *string `json:"container,omitempty"`
|
||||
Log *string `json:"log,omitempty"`
|
||||
Output map[string]interface{} `json:"output,omitempty"`
|
||||
Running bool `json:"running"`
|
||||
Status string `json:"status"`
|
||||
Container *string `json:"container,omitempty"`
|
||||
Log *string `json:"log,omitempty"`
|
||||
Output map[string]any `json:"output,omitempty"`
|
||||
Running bool `json:"running"`
|
||||
Status string `json:"status"`
|
||||
}
|
||||
|
||||
type LogEntry struct {
|
||||
@@ -312,7 +312,7 @@ type LogEntry struct {
|
||||
|
||||
type Message struct {
|
||||
Context *Context `json:"context,omitempty"`
|
||||
Payload interface{} `json:"payload,omitempty"`
|
||||
Payload any `json:"payload,omitempty"`
|
||||
Secrets map[string]string `json:"secrets,omitempty"`
|
||||
}
|
||||
|
||||
@@ -385,18 +385,18 @@ type Statistics struct {
|
||||
}
|
||||
|
||||
type Task struct {
|
||||
Automation *string `json:"automation,omitempty"`
|
||||
Closed *time.Time `json:"closed,omitempty"`
|
||||
Created time.Time `json:"created"`
|
||||
Data map[string]interface{} `json:"data,omitempty"`
|
||||
Done bool `json:"done"`
|
||||
Join *bool `json:"join,omitempty"`
|
||||
Name string `json:"name"`
|
||||
Next map[string]string `json:"next,omitempty"`
|
||||
Owner *string `json:"owner,omitempty"`
|
||||
Payload map[string]string `json:"payload,omitempty"`
|
||||
Schema map[string]interface{} `json:"schema,omitempty"`
|
||||
Type string `json:"type"`
|
||||
Automation *string `json:"automation,omitempty"`
|
||||
Closed *time.Time `json:"closed,omitempty"`
|
||||
Created time.Time `json:"created"`
|
||||
Data map[string]any `json:"data,omitempty"`
|
||||
Done bool `json:"done"`
|
||||
Join *bool `json:"join,omitempty"`
|
||||
Name string `json:"name"`
|
||||
Next map[string]string `json:"next,omitempty"`
|
||||
Owner *string `json:"owner,omitempty"`
|
||||
Payload map[string]string `json:"payload,omitempty"`
|
||||
Schema map[string]any `json:"schema,omitempty"`
|
||||
Type string `json:"type"`
|
||||
}
|
||||
|
||||
type TaskOrigin struct {
|
||||
@@ -406,20 +406,20 @@ type TaskOrigin struct {
|
||||
}
|
||||
|
||||
type TaskResponse struct {
|
||||
Active bool `json:"active"`
|
||||
Automation *string `json:"automation,omitempty"`
|
||||
Closed *time.Time `json:"closed,omitempty"`
|
||||
Created time.Time `json:"created"`
|
||||
Data map[string]interface{} `json:"data,omitempty"`
|
||||
Done bool `json:"done"`
|
||||
Join *bool `json:"join,omitempty"`
|
||||
Name string `json:"name"`
|
||||
Next map[string]string `json:"next,omitempty"`
|
||||
Order int64 `json:"order"`
|
||||
Owner *string `json:"owner,omitempty"`
|
||||
Payload map[string]string `json:"payload,omitempty"`
|
||||
Schema map[string]interface{} `json:"schema,omitempty"`
|
||||
Type string `json:"type"`
|
||||
Active bool `json:"active"`
|
||||
Automation *string `json:"automation,omitempty"`
|
||||
Closed *time.Time `json:"closed,omitempty"`
|
||||
Created time.Time `json:"created"`
|
||||
Data map[string]any `json:"data,omitempty"`
|
||||
Done bool `json:"done"`
|
||||
Join *bool `json:"join,omitempty"`
|
||||
Name string `json:"name"`
|
||||
Next map[string]string `json:"next,omitempty"`
|
||||
Order int64 `json:"order"`
|
||||
Owner *string `json:"owner,omitempty"`
|
||||
Payload map[string]string `json:"payload,omitempty"`
|
||||
Schema map[string]any `json:"schema,omitempty"`
|
||||
Type string `json:"type"`
|
||||
}
|
||||
|
||||
type TaskWithContext struct {
|
||||
@@ -432,28 +432,28 @@ type TaskWithContext struct {
|
||||
}
|
||||
|
||||
type Ticket struct {
|
||||
Artifacts []*Artifact `json:"artifacts,omitempty"`
|
||||
Comments []*Comment `json:"comments,omitempty"`
|
||||
Created time.Time `json:"created"`
|
||||
Details map[string]interface{} `json:"details,omitempty"`
|
||||
Files []*File `json:"files,omitempty"`
|
||||
Modified time.Time `json:"modified"`
|
||||
Name string `json:"name"`
|
||||
Owner *string `json:"owner,omitempty"`
|
||||
Playbooks map[string]*Playbook `json:"playbooks,omitempty"`
|
||||
Read []string `json:"read,omitempty"`
|
||||
References []*Reference `json:"references,omitempty"`
|
||||
Schema string `json:"schema"`
|
||||
Status string `json:"status"`
|
||||
Type string `json:"type"`
|
||||
Write []string `json:"write,omitempty"`
|
||||
Artifacts []*Artifact `json:"artifacts,omitempty"`
|
||||
Comments []*Comment `json:"comments,omitempty"`
|
||||
Created time.Time `json:"created"`
|
||||
Details map[string]any `json:"details,omitempty"`
|
||||
Files []*File `json:"files,omitempty"`
|
||||
Modified time.Time `json:"modified"`
|
||||
Name string `json:"name"`
|
||||
Owner *string `json:"owner,omitempty"`
|
||||
Playbooks map[string]*Playbook `json:"playbooks,omitempty"`
|
||||
Read []string `json:"read,omitempty"`
|
||||
References []*Reference `json:"references,omitempty"`
|
||||
Schema string `json:"schema"`
|
||||
Status string `json:"status"`
|
||||
Type string `json:"type"`
|
||||
Write []string `json:"write,omitempty"`
|
||||
}
|
||||
|
||||
type TicketForm struct {
|
||||
Artifacts []*Artifact `json:"artifacts,omitempty"`
|
||||
Comments []*Comment `json:"comments,omitempty"`
|
||||
Created *time.Time `json:"created,omitempty"`
|
||||
Details map[string]interface{} `json:"details,omitempty"`
|
||||
Details map[string]any `json:"details,omitempty"`
|
||||
Files []*File `json:"files,omitempty"`
|
||||
ID *int64 `json:"id,omitempty"`
|
||||
Modified *time.Time `json:"modified,omitempty"`
|
||||
@@ -479,7 +479,7 @@ type TicketResponse struct {
|
||||
Artifacts []*Artifact `json:"artifacts,omitempty"`
|
||||
Comments []*Comment `json:"comments,omitempty"`
|
||||
Created time.Time `json:"created"`
|
||||
Details map[string]interface{} `json:"details,omitempty"`
|
||||
Details map[string]any `json:"details,omitempty"`
|
||||
Files []*File `json:"files,omitempty"`
|
||||
ID int64 `json:"id"`
|
||||
Modified time.Time `json:"modified"`
|
||||
@@ -495,22 +495,22 @@ type TicketResponse struct {
|
||||
}
|
||||
|
||||
type TicketSimpleResponse struct {
|
||||
Artifacts []*Artifact `json:"artifacts,omitempty"`
|
||||
Comments []*Comment `json:"comments,omitempty"`
|
||||
Created time.Time `json:"created"`
|
||||
Details map[string]interface{} `json:"details,omitempty"`
|
||||
Files []*File `json:"files,omitempty"`
|
||||
ID int64 `json:"id"`
|
||||
Modified time.Time `json:"modified"`
|
||||
Name string `json:"name"`
|
||||
Owner *string `json:"owner,omitempty"`
|
||||
Playbooks map[string]*Playbook `json:"playbooks,omitempty"`
|
||||
Read []string `json:"read,omitempty"`
|
||||
References []*Reference `json:"references,omitempty"`
|
||||
Schema string `json:"schema"`
|
||||
Status string `json:"status"`
|
||||
Type string `json:"type"`
|
||||
Write []string `json:"write,omitempty"`
|
||||
Artifacts []*Artifact `json:"artifacts,omitempty"`
|
||||
Comments []*Comment `json:"comments,omitempty"`
|
||||
Created time.Time `json:"created"`
|
||||
Details map[string]any `json:"details,omitempty"`
|
||||
Files []*File `json:"files,omitempty"`
|
||||
ID int64 `json:"id"`
|
||||
Modified time.Time `json:"modified"`
|
||||
Name string `json:"name"`
|
||||
Owner *string `json:"owner,omitempty"`
|
||||
Playbooks map[string]*Playbook `json:"playbooks,omitempty"`
|
||||
Read []string `json:"read,omitempty"`
|
||||
References []*Reference `json:"references,omitempty"`
|
||||
Schema string `json:"schema"`
|
||||
Status string `json:"status"`
|
||||
Type string `json:"type"`
|
||||
Write []string `json:"write,omitempty"`
|
||||
}
|
||||
|
||||
type TicketTemplate struct {
|
||||
@@ -560,7 +560,7 @@ type TicketWithTickets struct {
|
||||
Artifacts []*Artifact `json:"artifacts,omitempty"`
|
||||
Comments []*Comment `json:"comments,omitempty"`
|
||||
Created time.Time `json:"created"`
|
||||
Details map[string]interface{} `json:"details,omitempty"`
|
||||
Details map[string]any `json:"details,omitempty"`
|
||||
Files []*File `json:"files,omitempty"`
|
||||
ID int64 `json:"id"`
|
||||
Logs []*LogEntry `json:"logs,omitempty"`
|
||||
|
||||
82
go.mod
82
go.mod
@@ -1,53 +1,93 @@
|
||||
module github.com/SecurityBrewery/catalyst
|
||||
|
||||
go 1.16
|
||||
go 1.18
|
||||
|
||||
require (
|
||||
github.com/Azure/go-ansiterm v0.0.0-20210617225240-d185dfc1b5a1 // indirect
|
||||
github.com/Microsoft/go-winio v0.5.1 // indirect
|
||||
github.com/alecthomas/kong v0.2.17
|
||||
github.com/alecthomas/kong-yaml v0.1.1
|
||||
github.com/antlr/antlr4/runtime/Go/antlr v0.0.0-20211101200231-0802afb9c160
|
||||
github.com/arangodb/go-driver v1.2.1
|
||||
github.com/aws/aws-sdk-go v1.41.19
|
||||
github.com/bits-and-blooms/bitset v1.2.1 // indirect
|
||||
github.com/blevesearch/bleve/v2 v2.2.2
|
||||
github.com/bmizerany/pat v0.0.0-20210406213842-e4b6760bdd6f // indirect
|
||||
github.com/containerd/containerd v1.5.8 // indirect
|
||||
github.com/coreos/go-oidc/v3 v3.1.0
|
||||
github.com/docker/docker v17.12.0-ce-rc1.0.20201201034508-7d75c1d40d88+incompatible
|
||||
github.com/docker/go-connections v0.4.0 // indirect
|
||||
github.com/eclipse/paho.mqtt.golang v1.3.5 // indirect
|
||||
github.com/emitter-io/go/v2 v2.0.9
|
||||
github.com/go-chi/chi v1.5.4
|
||||
github.com/go-chi/cors v1.2.0
|
||||
github.com/gobwas/ws v1.1.0
|
||||
github.com/golang/snappy v0.0.4 // indirect
|
||||
github.com/google/uuid v1.3.0
|
||||
github.com/gorilla/mux v1.8.0 // indirect
|
||||
github.com/iancoleman/strcase v0.2.0
|
||||
github.com/icza/dyno v0.0.0-20210726202311-f1bafe5d9996
|
||||
github.com/imdario/mergo v0.3.12
|
||||
github.com/kr/pretty v0.3.0 // indirect
|
||||
github.com/mingrammer/commonregex v1.0.1
|
||||
github.com/morikuni/aec v1.0.0 // indirect
|
||||
github.com/opencontainers/image-spec v1.0.2 // indirect
|
||||
github.com/rogpeppe/go-internal v1.8.0 // indirect
|
||||
github.com/stretchr/testify v1.7.0
|
||||
github.com/tidwall/gjson v1.12.1
|
||||
github.com/tidwall/sjson v1.2.4
|
||||
github.com/tus/tusd v1.8.0
|
||||
github.com/xeipuuv/gojsonpointer v0.0.0-20190905194746-02993c407bfb // indirect
|
||||
github.com/xeipuuv/gojsonschema v1.2.0
|
||||
go.etcd.io/bbolt v1.3.6 // indirect
|
||||
golang.org/x/crypto v0.0.0-20210921155107-089bfa567519 // indirect
|
||||
golang.org/x/net v0.0.0-20211105192438-b53810dc28af // indirect
|
||||
golang.org/x/exp v0.0.0-20220318154914-8dddf5d87bd8
|
||||
golang.org/x/oauth2 v0.0.0-20211104180415-d3ed0bb246c8
|
||||
golang.org/x/sys v0.0.0-20211105183446-c75c47738b0c // indirect
|
||||
golang.org/x/text v0.3.7 // indirect
|
||||
gopkg.in/square/go-jose.v2 v2.6.0 // indirect
|
||||
gopkg.in/yaml.v3 v3.0.0-20210107192922-496545a6307b
|
||||
)
|
||||
|
||||
require (
|
||||
github.com/Azure/go-ansiterm v0.0.0-20210617225240-d185dfc1b5a1 // indirect
|
||||
github.com/Microsoft/go-winio v0.5.1 // indirect
|
||||
github.com/RoaringBitmap/roaring v0.9.4 // indirect
|
||||
github.com/arangodb/go-velocypack v0.0.0-20200318135517-5af53c29c67e // indirect
|
||||
github.com/bits-and-blooms/bitset v1.2.1 // indirect
|
||||
github.com/blevesearch/bleve_index_api v1.0.1 // indirect
|
||||
github.com/blevesearch/go-porterstemmer v1.0.3 // indirect
|
||||
github.com/blevesearch/mmap-go v1.0.3 // indirect
|
||||
github.com/blevesearch/scorch_segment_api/v2 v2.1.0 // indirect
|
||||
github.com/blevesearch/segment v0.9.0 // indirect
|
||||
github.com/blevesearch/snowballstem v0.9.0 // indirect
|
||||
github.com/blevesearch/upsidedown_store_api v1.0.1 // indirect
|
||||
github.com/blevesearch/vellum v1.0.7 // indirect
|
||||
github.com/blevesearch/zapx/v11 v11.3.1 // indirect
|
||||
github.com/blevesearch/zapx/v12 v12.3.1 // indirect
|
||||
github.com/blevesearch/zapx/v13 v13.3.1 // indirect
|
||||
github.com/blevesearch/zapx/v14 v14.3.1 // indirect
|
||||
github.com/blevesearch/zapx/v15 v15.3.1 // indirect
|
||||
github.com/bmizerany/pat v0.0.0-20210406213842-e4b6760bdd6f // indirect
|
||||
github.com/containerd/containerd v1.5.8 // indirect
|
||||
github.com/davecgh/go-spew v1.1.1 // indirect
|
||||
github.com/docker/distribution v2.7.1+incompatible // indirect
|
||||
github.com/docker/go-connections v0.4.0 // indirect
|
||||
github.com/docker/go-units v0.4.0 // indirect
|
||||
github.com/eclipse/paho.mqtt.golang v1.3.5 // indirect
|
||||
github.com/gobwas/httphead v0.1.0 // indirect
|
||||
github.com/gobwas/pool v0.2.1 // indirect
|
||||
github.com/gogo/protobuf v1.3.2 // indirect
|
||||
github.com/golang/protobuf v1.5.2 // indirect
|
||||
github.com/golang/snappy v0.0.4 // indirect
|
||||
github.com/gorilla/mux v1.8.0 // indirect
|
||||
github.com/gorilla/websocket v1.4.2 // indirect
|
||||
github.com/jmespath/go-jmespath v0.4.0 // indirect
|
||||
github.com/kr/pretty v0.3.0 // indirect
|
||||
github.com/morikuni/aec v1.0.0 // indirect
|
||||
github.com/mschoch/smat v0.2.0 // indirect
|
||||
github.com/opencontainers/go-digest v1.0.0 // indirect
|
||||
github.com/opencontainers/image-spec v1.0.2 // indirect
|
||||
github.com/pkg/errors v0.9.1 // indirect
|
||||
github.com/pmezard/go-difflib v1.0.0 // indirect
|
||||
github.com/rogpeppe/go-internal v1.8.0 // indirect
|
||||
github.com/sirupsen/logrus v1.8.1 // indirect
|
||||
github.com/steveyen/gtreap v0.1.0 // indirect
|
||||
github.com/tidwall/match v1.1.1 // indirect
|
||||
github.com/tidwall/pretty v1.2.0 // indirect
|
||||
github.com/xeipuuv/gojsonpointer v0.0.0-20190905194746-02993c407bfb // indirect
|
||||
github.com/xeipuuv/gojsonreference v0.0.0-20180127040603-bd5ef7bd5415 // indirect
|
||||
go.etcd.io/bbolt v1.3.6 // indirect
|
||||
golang.org/x/crypto v0.0.0-20210322153248-0c34fe9e7dc2 // indirect
|
||||
golang.org/x/net v0.0.0-20211105192438-b53810dc28af // indirect
|
||||
golang.org/x/sys v0.0.0-20211105183446-c75c47738b0c // indirect
|
||||
google.golang.org/appengine v1.6.7 // indirect
|
||||
google.golang.org/genproto v0.0.0-20211021150943-2b146023228c // indirect
|
||||
google.golang.org/grpc v1.41.0 // indirect
|
||||
google.golang.org/protobuf v1.27.1 // indirect
|
||||
gopkg.in/square/go-jose.v2 v2.6.0 // indirect
|
||||
gopkg.in/yaml.v2 v2.4.0 // indirect
|
||||
)
|
||||
|
||||
replace github.com/xeipuuv/gojsonschema => github.com/warjiang/gojsonschema v1.2.1-0.20210329105853-aa9f9a8cfec7
|
||||
|
||||
10
go.sum
10
go.sum
@@ -276,7 +276,6 @@ github.com/coreos/go-etcd v2.0.0+incompatible/go.mod h1:Jez6KQU2B/sWsbdaef3ED8Nz
|
||||
github.com/coreos/go-iptables v0.4.3/go.mod h1:/mVI274lEDI2ns62jHCDnCyBF9Iwsmekav8Dbxlm1MU=
|
||||
github.com/coreos/go-iptables v0.4.5/go.mod h1:/mVI274lEDI2ns62jHCDnCyBF9Iwsmekav8Dbxlm1MU=
|
||||
github.com/coreos/go-iptables v0.5.0/go.mod h1:/mVI274lEDI2ns62jHCDnCyBF9Iwsmekav8Dbxlm1MU=
|
||||
github.com/coreos/go-oidc v2.1.0+incompatible h1:sdJrfw8akMnCuUlaZU3tE/uYXFgfqom8DBE9so9EBsM=
|
||||
github.com/coreos/go-oidc v2.1.0+incompatible/go.mod h1:CgnwVTmzoESiwO9qyAFEMiHoZ1nMCKZlZ9V6mm3/LKc=
|
||||
github.com/coreos/go-oidc/v3 v3.1.0 h1:6avEvcdvTa1qYsOZ6I5PRkSYHzpTNWgKYmaJfaYbrRw=
|
||||
github.com/coreos/go-oidc/v3 v3.1.0/go.mod h1:rEJ/idjfUyfkBit1eI1fvyr+64/g9dcKpAm8MJMesvo=
|
||||
@@ -833,9 +832,8 @@ golang.org/x/crypto v0.0.0-20191011191535-87dc89f01550/go.mod h1:yigFU9vqHzYiE8U
|
||||
golang.org/x/crypto v0.0.0-20200622213623-75b288015ac9/go.mod h1:LzIPMQfyMNhhGPhUkYOs5KpL4U8rLKemX1yGLhDgUto=
|
||||
golang.org/x/crypto v0.0.0-20200728195943-123391ffb6de/go.mod h1:LzIPMQfyMNhhGPhUkYOs5KpL4U8rLKemX1yGLhDgUto=
|
||||
golang.org/x/crypto v0.0.0-20201002170205-7f63de1d35b0/go.mod h1:LzIPMQfyMNhhGPhUkYOs5KpL4U8rLKemX1yGLhDgUto=
|
||||
golang.org/x/crypto v0.0.0-20210322153248-0c34fe9e7dc2 h1:It14KIkyBFYkHkwZ7k45minvA9aorojkyjGk9KJ5B/w=
|
||||
golang.org/x/crypto v0.0.0-20210322153248-0c34fe9e7dc2/go.mod h1:T9bdIzuCu7OtxOm1hfPfRQxPLYneinmdGuTeoZ9dtd4=
|
||||
golang.org/x/crypto v0.0.0-20210921155107-089bfa567519 h1:7I4JAnoQBe7ZtJcBaYHi5UtiO8tQHbUSXxL+pnGRANg=
|
||||
golang.org/x/crypto v0.0.0-20210921155107-089bfa567519/go.mod h1:GvvjBRRGRdwPK5ydBHafDWAxML/pGHZbMvKqRZ5+Abc=
|
||||
golang.org/x/exp v0.0.0-20190121172915-509febef88a4/go.mod h1:CJ0aWSM057203Lf6IL+f9T1iT9GByDxfZKAQTCR3kQA=
|
||||
golang.org/x/exp v0.0.0-20190306152737-a1d7652674e8/go.mod h1:CJ0aWSM057203Lf6IL+f9T1iT9GByDxfZKAQTCR3kQA=
|
||||
golang.org/x/exp v0.0.0-20190510132918-efd6b22b2522/go.mod h1:ZjyILWgesfNpC6sMxTJOJm9Kp84zZh5NQWvqDGG3Qr8=
|
||||
@@ -846,6 +844,8 @@ golang.org/x/exp v0.0.0-20191227195350-da58074b4299/go.mod h1:2RIsYlXP63K8oxa1u0
|
||||
golang.org/x/exp v0.0.0-20200119233911-0405dc783f0a/go.mod h1:2RIsYlXP63K8oxa1u096TMicItID8zy7Y6sNkU49FU4=
|
||||
golang.org/x/exp v0.0.0-20200207192155-f17229e696bd/go.mod h1:J/WKrq2StrnmMY6+EHIKF9dgMWnmCNThgcyBT1FY9mM=
|
||||
golang.org/x/exp v0.0.0-20200224162631-6cc2880d07d6/go.mod h1:3jZMyOhIsHpP37uCMkUooju7aAi5cS1Q23tOzKc+0MU=
|
||||
golang.org/x/exp v0.0.0-20220318154914-8dddf5d87bd8 h1:s/+U+w0teGzcoH2mdIlFQ6KfVKGaYpgyGdUefZrn9TU=
|
||||
golang.org/x/exp v0.0.0-20220318154914-8dddf5d87bd8/go.mod h1:lgLbSvA5ygNOMpwM/9anMpWVlVJ7Z+cHWq/eFuinpGE=
|
||||
golang.org/x/image v0.0.0-20190227222117-0694c2d4d067/go.mod h1:kZ7UVZpmo3dzQBMxlp+ypCbDeSB+sBbTgSJuh5dn5js=
|
||||
golang.org/x/image v0.0.0-20190802002840-cff245a6509b/go.mod h1:FeLwcggjj3mMvU+oOTbSwawSJRM1uh48EjtB4UJZlP0=
|
||||
golang.org/x/lint v0.0.0-20181026193005-c67002cb31c3/go.mod h1:UVdnD1Gm6xHRNCYTkRU2/jEulfH38KcIWyp/GAMgvoE=
|
||||
@@ -1040,7 +1040,6 @@ golang.org/x/sys v0.0.0-20210510120138-977fb7262007/go.mod h1:oPkhp1MJrh7nUepCBc
|
||||
golang.org/x/sys v0.0.0-20210514084401-e8d321eab015/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
|
||||
golang.org/x/sys v0.0.0-20210603081109-ebe580a85c40/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
|
||||
golang.org/x/sys v0.0.0-20210603125802-9665404d3644/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
|
||||
golang.org/x/sys v0.0.0-20210615035016-665e8c7367d1/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
|
||||
golang.org/x/sys v0.0.0-20210616094352-59db8d763f22/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
|
||||
golang.org/x/sys v0.0.0-20210630005230-0f9fa26af87c/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
|
||||
golang.org/x/sys v0.0.0-20210806184541-e5e7981a1069/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
|
||||
@@ -1058,9 +1057,8 @@ golang.org/x/text v0.3.2/go.mod h1:bEr9sfX3Q8Zfm5fL9x+3itogRgK3+ptLWKqgva+5dAk=
|
||||
golang.org/x/text v0.3.3/go.mod h1:5Zoc/QRtKVWzQhOtBMvqHzDpF6irO9z98xDceosuGiQ=
|
||||
golang.org/x/text v0.3.4/go.mod h1:5Zoc/QRtKVWzQhOtBMvqHzDpF6irO9z98xDceosuGiQ=
|
||||
golang.org/x/text v0.3.5/go.mod h1:5Zoc/QRtKVWzQhOtBMvqHzDpF6irO9z98xDceosuGiQ=
|
||||
golang.org/x/text v0.3.6 h1:aRYxNxv6iGQlyVaZmk6ZgYEDa+Jg18DxebPSrd6bg1M=
|
||||
golang.org/x/text v0.3.6/go.mod h1:5Zoc/QRtKVWzQhOtBMvqHzDpF6irO9z98xDceosuGiQ=
|
||||
golang.org/x/text v0.3.7 h1:olpwvP2KacW1ZWvsR7uQhoyTYvKAupfQrRGBFM352Gk=
|
||||
golang.org/x/text v0.3.7/go.mod h1:u+2+/6zg+i71rQMx5EYifcz6MCKuco9NR6JIITiCfzQ=
|
||||
golang.org/x/time v0.0.0-20180412165947-fbb02b2291d2/go.mod h1:tRJNPiyCQ0inRvYxbN9jk5I+vvW/OXSQhTDSoE431IQ=
|
||||
golang.org/x/time v0.0.0-20181108054448-85acf8d2951c/go.mod h1:tRJNPiyCQ0inRvYxbN9jk5I+vvW/OXSQhTDSoE431IQ=
|
||||
golang.org/x/time v0.0.0-20190308202827-9d24e82272b4/go.mod h1:tRJNPiyCQ0inRvYxbN9jk5I+vvW/OXSQhTDSoE431IQ=
|
||||
|
||||
@@ -11,8 +11,8 @@ import (
|
||||
type Hooks struct {
|
||||
DatabaseAfterConnectFuncs []func(ctx context.Context, client driver.Client, name string)
|
||||
IngestionFilterFunc func(ctx context.Context, index *index.Index) (string, error)
|
||||
TicketReadFilterFunc func(ctx context.Context) (string, map[string]interface{}, error)
|
||||
TicketWriteFilterFunc func(ctx context.Context) (string, map[string]interface{}, error)
|
||||
TicketReadFilterFunc func(ctx context.Context) (string, map[string]any, error)
|
||||
TicketWriteFilterFunc func(ctx context.Context) (string, map[string]any, error)
|
||||
GetGroupsFunc func(ctx context.Context, username string) ([]string, error)
|
||||
}
|
||||
|
||||
@@ -26,20 +26,23 @@ func (h *Hooks) IngestionFilter(ctx context.Context, index *index.Index) (string
|
||||
if h.IngestionFilterFunc != nil {
|
||||
return h.IngestionFilterFunc(ctx, index)
|
||||
}
|
||||
|
||||
return "[]", nil
|
||||
}
|
||||
|
||||
func (h *Hooks) TicketReadFilter(ctx context.Context) (string, map[string]interface{}, error) {
|
||||
func (h *Hooks) TicketReadFilter(ctx context.Context) (string, map[string]any, error) {
|
||||
if h.TicketReadFilterFunc != nil {
|
||||
return h.TicketReadFilterFunc(ctx)
|
||||
}
|
||||
|
||||
return "", nil, nil
|
||||
}
|
||||
|
||||
func (h *Hooks) TicketWriteFilter(ctx context.Context) (string, map[string]interface{}, error) {
|
||||
func (h *Hooks) TicketWriteFilter(ctx context.Context) (string, map[string]any, error) {
|
||||
if h.TicketWriteFilterFunc != nil {
|
||||
return h.TicketWriteFilterFunc(ctx)
|
||||
}
|
||||
|
||||
return "", nil, nil
|
||||
}
|
||||
|
||||
@@ -47,5 +50,6 @@ func (h *Hooks) GetGroups(ctx context.Context, username string) ([]string, error
|
||||
if h.GetGroupsFunc != nil {
|
||||
return h.GetGroupsFunc(ctx, username)
|
||||
}
|
||||
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
@@ -36,6 +36,7 @@ func (i *Index) Index(incidents []*model.TicketSimpleResponse) {
|
||||
for _, incident := range incidents {
|
||||
if incident.ID == 0 {
|
||||
log.Println(errors.New("no ID"), incident)
|
||||
|
||||
continue
|
||||
}
|
||||
|
||||
@@ -44,8 +45,8 @@ func (i *Index) Index(incidents []*model.TicketSimpleResponse) {
|
||||
log.Println(err)
|
||||
}
|
||||
}
|
||||
err := i.internal.Batch(b)
|
||||
if err != nil {
|
||||
|
||||
if err := i.internal.Batch(b); err != nil {
|
||||
log.Println(err)
|
||||
}
|
||||
}
|
||||
@@ -59,6 +60,7 @@ func (i *Index) Search(term string) (ids []string, err error) {
|
||||
for _, match := range result.Hits {
|
||||
ids = append(ids, match.ID)
|
||||
}
|
||||
|
||||
return ids, nil
|
||||
}
|
||||
|
||||
@@ -76,6 +78,7 @@ func (i *Index) Truncate() error {
|
||||
return err
|
||||
}
|
||||
i.internal = index
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
|
||||
@@ -9,6 +9,8 @@ import (
|
||||
)
|
||||
|
||||
func TestIndex(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
type args struct {
|
||||
term string
|
||||
}
|
||||
@@ -22,7 +24,10 @@ func TestIndex(t *testing.T) {
|
||||
{name: "Not exists", args: args{"bar"}},
|
||||
}
|
||||
for _, tt := range tests {
|
||||
tt := tt
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
i, cleanup, err := test.Index(t)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
@@ -37,6 +42,7 @@ func TestIndex(t *testing.T) {
|
||||
gotIds, err := i.Search(tt.args.term)
|
||||
if (err != nil) != tt.wantErr {
|
||||
t.Errorf("Search() error = %v, wantErr %v", err, tt.wantErr)
|
||||
|
||||
return
|
||||
}
|
||||
if !reflect.DeepEqual(gotIds, tt.wantIds) {
|
||||
@@ -47,6 +53,8 @@ func TestIndex(t *testing.T) {
|
||||
}
|
||||
|
||||
func TestIndex_Truncate(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
wantErr bool
|
||||
@@ -54,7 +62,10 @@ func TestIndex_Truncate(t *testing.T) {
|
||||
{name: "Truncate"},
|
||||
}
|
||||
for _, tt := range tests {
|
||||
tt := tt
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
i, cleanup, err := test.Index(t)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
|
||||
18
restore.go
18
restore.go
@@ -29,11 +29,13 @@ func restoreHandler(catalystStorage *storage.Storage, db *database.Database, c *
|
||||
uf, header, err := r.FormFile("backup")
|
||||
if err != nil {
|
||||
api.JSONError(w, err)
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
if err = Restore(r.Context(), catalystStorage, db, c, uf, header.Size); err != nil {
|
||||
api.JSONError(w, err)
|
||||
|
||||
return
|
||||
}
|
||||
}
|
||||
@@ -52,7 +54,7 @@ func Restore(ctx context.Context, catalystStorage *storage.Storage, db *database
|
||||
}
|
||||
|
||||
if fsys.Comment != GetVersion() {
|
||||
return errors.New(fmt.Sprintf("wrong version, got: %s, want: %s", fsys.Comment, GetVersion()))
|
||||
return fmt.Errorf("wrong version, got: %s, want: %s", fsys.Comment, GetVersion())
|
||||
}
|
||||
|
||||
dir, err := os.MkdirTemp("", "catalyst-restore")
|
||||
@@ -89,17 +91,19 @@ func restoreS3(catalystStorage *storage.Storage, p string) error {
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func restoreBucket(catalystStorage *storage.Storage, entry fs.DirEntry, minioDir fs.FS) error {
|
||||
_, err := catalystStorage.S3().CreateBucket(&s3.CreateBucketInput{Bucket: pointer.String(entry.Name())})
|
||||
if err != nil {
|
||||
awsError, ok := err.(awserr.Error)
|
||||
if !ok || (awsError.Code() != s3.ErrCodeBucketAlreadyExists && awsError.Code() != s3.ErrCodeBucketAlreadyOwnedByYou) {
|
||||
return err
|
||||
var awsError awserr.Error
|
||||
if errors.As(err, &awsError) && (awsError.Code() == s3.ErrCodeBucketAlreadyExists || awsError.Code() == s3.ErrCodeBucketAlreadyOwnedByYou) {
|
||||
return nil
|
||||
}
|
||||
return nil
|
||||
|
||||
return err
|
||||
}
|
||||
|
||||
uploader := catalystStorage.Uploader()
|
||||
@@ -115,11 +119,13 @@ func restoreBucket(catalystStorage *storage.Storage, entry fs.DirEntry, minioDir
|
||||
return nil
|
||||
}
|
||||
_, err = uploader.Upload(&s3manager.UploadInput{Body: f, Bucket: pointer.String(entry.Name()), Key: pointer.String(path)})
|
||||
|
||||
return err
|
||||
})
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
@@ -131,6 +137,7 @@ func unzip(archive *zip.Reader, dir string) error {
|
||||
|
||||
if d.IsDir() {
|
||||
_ = os.MkdirAll(path.Join(dir, p), os.ModePerm)
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
@@ -163,5 +170,6 @@ func arangorestore(dir string, config *database.Config) error {
|
||||
"--server.database", name,
|
||||
}
|
||||
cmd := exec.Command("arangorestore", args...)
|
||||
|
||||
return cmd.Run()
|
||||
}
|
||||
|
||||
34
role/role.go
34
role/role.go
@@ -5,6 +5,8 @@ import (
|
||||
"sort"
|
||||
"strings"
|
||||
|
||||
"golang.org/x/exp/slices"
|
||||
|
||||
"github.com/SecurityBrewery/catalyst/generated/model"
|
||||
)
|
||||
|
||||
@@ -60,23 +62,16 @@ func UserHasRoles(user *model.UserResponse, roles []Role) bool {
|
||||
for _, role := range roles {
|
||||
if !UserHasRole(user, role) {
|
||||
hasRoles = false
|
||||
|
||||
break
|
||||
}
|
||||
}
|
||||
|
||||
return hasRoles
|
||||
}
|
||||
|
||||
func UserHasRole(user *model.UserResponse, role Role) bool {
|
||||
return ContainsRole(FromStrings(user.Roles), role)
|
||||
}
|
||||
|
||||
func ContainsRole(roles []Role, role Role) bool {
|
||||
for _, r := range roles {
|
||||
if r.String() == role.String() { // || strings.HasPrefix(role.String(), r.String()+":")
|
||||
return true
|
||||
}
|
||||
}
|
||||
return false
|
||||
return slices.Contains(FromStrings(user.Roles), role)
|
||||
}
|
||||
|
||||
func Explodes(s []string) []Role {
|
||||
@@ -84,10 +79,10 @@ func Explodes(s []string) []Role {
|
||||
for _, e := range s {
|
||||
roles = append(roles, Explode(e)...)
|
||||
}
|
||||
roles = unique(roles)
|
||||
sort.Slice(roles, func(i, j int) bool {
|
||||
return roles[i].String() < roles[j].String()
|
||||
})
|
||||
roles = slices.Compact(roles)
|
||||
|
||||
return roles
|
||||
}
|
||||
@@ -98,12 +93,15 @@ func Explode(s string) []Role {
|
||||
switch s {
|
||||
case Admin:
|
||||
roles = append(roles, listPrefix(Admin)...)
|
||||
|
||||
fallthrough
|
||||
case Engineer:
|
||||
roles = append(roles, listPrefix(Engineer)...)
|
||||
|
||||
fallthrough
|
||||
case Analyst:
|
||||
roles = append(roles, listPrefix(Analyst)...)
|
||||
|
||||
return roles
|
||||
}
|
||||
|
||||
@@ -128,18 +126,6 @@ func listPrefix(s string) []Role {
|
||||
return roles
|
||||
}
|
||||
|
||||
func unique(l []Role) []Role {
|
||||
keys := make(map[Role]bool)
|
||||
var list []Role
|
||||
for _, entry := range l {
|
||||
if _, value := keys[entry]; !value {
|
||||
keys[entry] = true
|
||||
list = append(list, entry)
|
||||
}
|
||||
}
|
||||
return list
|
||||
}
|
||||
|
||||
func List() []Role {
|
||||
return []Role{
|
||||
AutomationRead, CurrentuserdataRead, CurrentuserdataWrite,
|
||||
@@ -167,6 +153,7 @@ func Strings(roles []Role) []string {
|
||||
for _, role := range roles {
|
||||
s = append(s, role.String())
|
||||
}
|
||||
|
||||
return s
|
||||
}
|
||||
|
||||
@@ -179,5 +166,6 @@ func FromStrings(s []string) []Role {
|
||||
}
|
||||
roles = append(roles, role)
|
||||
}
|
||||
|
||||
return roles
|
||||
}
|
||||
|
||||
@@ -14,6 +14,7 @@ func automationResponseID(automation *model.AutomationResponse) []driver.Documen
|
||||
if automation == nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
return automationID(automation.ID)
|
||||
}
|
||||
|
||||
@@ -27,6 +28,7 @@ func (s *Service) ListAutomations(ctx context.Context) ([]*model.AutomationRespo
|
||||
|
||||
func (s *Service) CreateAutomation(ctx context.Context, form *model.AutomationForm) (doc *model.AutomationResponse, err error) {
|
||||
defer s.publishRequest(ctx, err, "CreateAutomation", automationResponseID(doc))
|
||||
|
||||
return s.database.AutomationCreate(ctx, form)
|
||||
}
|
||||
|
||||
@@ -36,10 +38,12 @@ func (s *Service) GetAutomation(ctx context.Context, id string) (*model.Automati
|
||||
|
||||
func (s *Service) UpdateAutomation(ctx context.Context, id string, form *model.AutomationForm) (doc *model.AutomationResponse, err error) {
|
||||
defer s.publishRequest(ctx, err, "UpdateAutomation", automationResponseID(doc))
|
||||
|
||||
return s.database.AutomationUpdate(ctx, id, form)
|
||||
}
|
||||
|
||||
func (s *Service) DeleteAutomation(ctx context.Context, id string) (err error) {
|
||||
defer s.publishRequest(ctx, err, "DeleteAutomation", automationID(id))
|
||||
|
||||
return s.database.AutomationDelete(ctx, id)
|
||||
}
|
||||
|
||||
@@ -14,6 +14,7 @@ func dashboardResponseID(doc *model.DashboardResponse) []driver.DocumentID {
|
||||
if doc == nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
return templateID(doc.ID)
|
||||
}
|
||||
|
||||
@@ -27,6 +28,7 @@ func (s *Service) ListDashboards(ctx context.Context) ([]*model.DashboardRespons
|
||||
|
||||
func (s *Service) CreateDashboard(ctx context.Context, dashboard *model.Dashboard) (doc *model.DashboardResponse, err error) {
|
||||
defer s.publishRequest(ctx, err, "CreateDashboard", dashboardResponseID(doc))
|
||||
|
||||
return s.database.DashboardCreate(ctx, dashboard)
|
||||
}
|
||||
|
||||
@@ -36,14 +38,16 @@ func (s *Service) GetDashboard(ctx context.Context, id string) (*model.Dashboard
|
||||
|
||||
func (s *Service) UpdateDashboard(ctx context.Context, id string, form *model.Dashboard) (doc *model.DashboardResponse, err error) {
|
||||
defer s.publishRequest(ctx, err, "UpdateDashboard", dashboardResponseID(doc))
|
||||
|
||||
return s.database.DashboardUpdate(ctx, id, form)
|
||||
}
|
||||
|
||||
func (s *Service) DeleteDashboard(ctx context.Context, id string) (err error) {
|
||||
defer s.publishRequest(ctx, err, "DeleteDashboard", dashboardID(id))
|
||||
|
||||
return s.database.DashboardDelete(ctx, id)
|
||||
}
|
||||
|
||||
func (s *Service) DashboardData(ctx context.Context, aggregation string, filter *string) (map[string]interface{}, error) {
|
||||
func (s *Service) DashboardData(ctx context.Context, aggregation string, filter *string) (map[string]any, error) {
|
||||
return s.database.WidgetData(ctx, aggregation, filter)
|
||||
}
|
||||
|
||||
@@ -15,6 +15,7 @@ func jobResponseID(job *model.JobResponse) []driver.DocumentID {
|
||||
if job == nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
return jobID(job.ID)
|
||||
}
|
||||
|
||||
@@ -48,5 +49,6 @@ func (s *Service) GetJob(ctx context.Context, id string) (*model.JobResponse, er
|
||||
|
||||
func (s *Service) UpdateJob(ctx context.Context, id string, job *model.JobUpdate) (doc *model.JobResponse, err error) {
|
||||
defer s.publishRequest(ctx, err, "UpdateJob", jobResponseID(doc))
|
||||
|
||||
return s.database.JobUpdate(ctx, id, job)
|
||||
}
|
||||
|
||||
@@ -9,5 +9,6 @@ import (
|
||||
|
||||
func (s *Service) GetLogs(ctx context.Context, reference string) ([]*model.LogEntry, error) {
|
||||
id, _ := url.QueryUnescape(reference)
|
||||
|
||||
return s.database.LogList(ctx, id)
|
||||
}
|
||||
|
||||
@@ -14,6 +14,7 @@ func playbookResponseID(playbook *model.PlaybookTemplateResponse) []driver.Docum
|
||||
if playbook == nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
return playbookID(playbook.ID)
|
||||
}
|
||||
|
||||
@@ -27,6 +28,7 @@ func (s *Service) ListPlaybooks(ctx context.Context) ([]*model.PlaybookTemplateR
|
||||
|
||||
func (s *Service) CreatePlaybook(ctx context.Context, form *model.PlaybookTemplateForm) (doc *model.PlaybookTemplateResponse, err error) {
|
||||
defer s.publishRequest(ctx, err, "CreatePlaybook", playbookResponseID(doc))
|
||||
|
||||
return s.database.PlaybookCreate(ctx, form)
|
||||
}
|
||||
|
||||
@@ -36,10 +38,12 @@ func (s *Service) GetPlaybook(ctx context.Context, id string) (*model.PlaybookTe
|
||||
|
||||
func (s *Service) UpdatePlaybook(ctx context.Context, id string, form *model.PlaybookTemplateForm) (doc *model.PlaybookTemplateResponse, err error) {
|
||||
defer s.publishRequest(ctx, err, "UpdatePlaybook", playbookResponseID(doc))
|
||||
|
||||
return s.database.PlaybookUpdate(ctx, id, form)
|
||||
}
|
||||
|
||||
func (s *Service) DeletePlaybook(ctx context.Context, id string) (err error) {
|
||||
defer s.publishRequest(ctx, err, "DeletePlaybook", playbookID(id))
|
||||
|
||||
return s.database.PlaybookDelete(ctx, id)
|
||||
}
|
||||
|
||||
@@ -2,6 +2,7 @@ package service
|
||||
|
||||
import (
|
||||
"context"
|
||||
"log"
|
||||
|
||||
"github.com/arangodb/go-driver"
|
||||
|
||||
@@ -33,6 +34,10 @@ func (s *Service) publishRequest(ctx context.Context, err error, function string
|
||||
userID = user.ID
|
||||
}
|
||||
|
||||
go s.bus.PublishRequest(userID, function, ids)
|
||||
go func() {
|
||||
if err := s.bus.PublishRequest(userID, function, ids); err != nil {
|
||||
log.Println(err)
|
||||
}
|
||||
}()
|
||||
}
|
||||
}
|
||||
|
||||
@@ -14,6 +14,7 @@ func templateResponseID(template *model.TicketTemplateResponse) []driver.Documen
|
||||
if template == nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
return templateID(template.ID)
|
||||
}
|
||||
|
||||
@@ -27,6 +28,7 @@ func (s *Service) ListTemplates(ctx context.Context) ([]*model.TicketTemplateRes
|
||||
|
||||
func (s *Service) CreateTemplate(ctx context.Context, form *model.TicketTemplateForm) (doc *model.TicketTemplateResponse, err error) {
|
||||
defer s.publishRequest(ctx, err, "CreateTemplate", templateResponseID(doc))
|
||||
|
||||
return s.database.TemplateCreate(ctx, form)
|
||||
}
|
||||
|
||||
@@ -36,10 +38,12 @@ func (s *Service) GetTemplate(ctx context.Context, id string) (*model.TicketTemp
|
||||
|
||||
func (s *Service) UpdateTemplate(ctx context.Context, id string, form *model.TicketTemplateForm) (doc *model.TicketTemplateResponse, err error) {
|
||||
defer s.publishRequest(ctx, err, "UpdateTemplate", templateResponseID(doc))
|
||||
|
||||
return s.database.TemplateUpdate(ctx, id, form)
|
||||
}
|
||||
|
||||
func (s *Service) DeleteTemplate(ctx context.Context, id string) (err error) {
|
||||
defer s.publishRequest(ctx, err, "DeleteTemplate", templateID(id))
|
||||
|
||||
return s.database.TemplateDelete(ctx, id)
|
||||
}
|
||||
|
||||
@@ -18,11 +18,13 @@ func ticketWithTicketsID(ticketResponse *model.TicketWithTickets) []driver.Docum
|
||||
if ticketResponse == nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
return ticketID(ticketResponse.ID)
|
||||
}
|
||||
|
||||
func ticketID(ticketID int64) []driver.DocumentID {
|
||||
id := fmt.Sprintf("%s/%d", database.TicketCollectionName, ticketID)
|
||||
|
||||
return []driver.DocumentID{driver.DocumentID(id)}
|
||||
}
|
||||
|
||||
@@ -31,6 +33,7 @@ func ticketIDs(ticketResponses []*model.TicketResponse) []driver.DocumentID {
|
||||
for _, ticketResponse := range ticketResponses {
|
||||
ids = append(ids, ticketID(ticketResponse.ID)...)
|
||||
}
|
||||
|
||||
return ids
|
||||
}
|
||||
|
||||
@@ -63,6 +66,7 @@ func (s *Service) CreateTicket(ctx context.Context, form *model.TicketForm) (doc
|
||||
if len(createdTickets) > 0 {
|
||||
return createdTickets[0], err
|
||||
}
|
||||
|
||||
return nil, err
|
||||
}
|
||||
|
||||
@@ -72,6 +76,7 @@ func (s *Service) CreateTicketBatch(ctx context.Context, ticketFormArray *model.
|
||||
}
|
||||
createdTickets, err := s.database.TicketBatchCreate(ctx, *ticketFormArray)
|
||||
defer s.publishRequest(ctx, err, "CreateTicket", ticketIDs(createdTickets))
|
||||
|
||||
return err
|
||||
}
|
||||
|
||||
@@ -81,16 +86,19 @@ func (s *Service) GetTicket(ctx context.Context, i int64) (*model.TicketWithTick
|
||||
|
||||
func (s *Service) UpdateTicket(ctx context.Context, i int64, ticket *model.Ticket) (doc *model.TicketWithTickets, err error) {
|
||||
defer s.publishRequest(ctx, err, "UpdateTicket", ticketWithTicketsID(doc))
|
||||
|
||||
return s.database.TicketUpdate(ctx, i, ticket)
|
||||
}
|
||||
|
||||
func (s *Service) DeleteTicket(ctx context.Context, i int64) (err error) {
|
||||
defer s.publishRequest(ctx, err, "DeleteTicket", ticketID(i))
|
||||
|
||||
return s.database.TicketDelete(ctx, i)
|
||||
}
|
||||
|
||||
func (s *Service) AddArtifact(ctx context.Context, i int64, artifact *model.Artifact) (doc *model.TicketWithTickets, err error) {
|
||||
defer s.publishRequest(ctx, err, "AddArtifact", ticketWithTicketsID(doc))
|
||||
|
||||
return s.database.AddArtifact(ctx, i, artifact)
|
||||
}
|
||||
|
||||
@@ -100,16 +108,19 @@ func (s *Service) GetArtifact(ctx context.Context, i int64, s2 string) (*model.A
|
||||
|
||||
func (s *Service) SetArtifact(ctx context.Context, i int64, s2 string, artifact *model.Artifact) (doc *model.TicketWithTickets, err error) {
|
||||
defer s.publishRequest(ctx, err, "SetArtifact", ticketWithTicketsID(doc))
|
||||
|
||||
return s.database.ArtifactUpdate(ctx, i, s2, artifact)
|
||||
}
|
||||
|
||||
func (s *Service) RemoveArtifact(ctx context.Context, i int64, s2 string) (doc *model.TicketWithTickets, err error) {
|
||||
defer s.publishRequest(ctx, err, "RemoveArtifact", ticketWithTicketsID(doc))
|
||||
|
||||
return s.database.RemoveArtifact(ctx, i, s2)
|
||||
}
|
||||
|
||||
func (s *Service) EnrichArtifact(ctx context.Context, i int64, s2 string, form *model.EnrichmentForm) (doc *model.TicketWithTickets, err error) {
|
||||
defer s.publishRequest(ctx, err, "EnrichArtifact", ticketWithTicketsID(doc))
|
||||
|
||||
return s.database.EnrichArtifact(ctx, i, s2, form)
|
||||
}
|
||||
|
||||
@@ -123,46 +134,55 @@ func (s *Service) RunArtifact(ctx context.Context, id int64, name string, automa
|
||||
|
||||
jobID := uuid.NewString()
|
||||
origin := &model.Origin{ArtifactOrigin: &model.ArtifactOrigin{TicketId: id, Artifact: name}}
|
||||
|
||||
return s.bus.PublishJob(jobID, automation, map[string]string{"default": name}, &model.Context{Artifact: artifact}, origin)
|
||||
}
|
||||
|
||||
func (s *Service) AddComment(ctx context.Context, i int64, form *model.CommentForm) (doc *model.TicketWithTickets, err error) {
|
||||
defer s.publishRequest(ctx, err, "AddComment", ticketWithTicketsID(doc))
|
||||
|
||||
return s.database.AddComment(ctx, i, form)
|
||||
}
|
||||
|
||||
func (s *Service) RemoveComment(ctx context.Context, i int64, i2 int) (doc *model.TicketWithTickets, err error) {
|
||||
defer s.publishRequest(ctx, err, "RemoveComment", ticketWithTicketsID(doc))
|
||||
|
||||
return s.database.RemoveComment(ctx, i, int64(i2))
|
||||
}
|
||||
|
||||
func (s *Service) AddTicketPlaybook(ctx context.Context, i int64, form *model.PlaybookTemplateForm) (doc *model.TicketWithTickets, err error) {
|
||||
defer s.publishRequest(ctx, err, "AddTicketPlaybook", ticketWithTicketsID(doc))
|
||||
|
||||
return s.database.AddTicketPlaybook(ctx, i, form)
|
||||
}
|
||||
|
||||
func (s *Service) RemoveTicketPlaybook(ctx context.Context, i int64, s2 string) (doc *model.TicketWithTickets, err error) {
|
||||
defer s.publishRequest(ctx, err, "RemoveTicketPlaybook", ticketWithTicketsID(doc))
|
||||
|
||||
return s.database.RemoveTicketPlaybook(ctx, i, s2)
|
||||
}
|
||||
|
||||
func (s *Service) SetTaskData(ctx context.Context, i int64, s3 string, s2 string, data map[string]interface{}) (doc *model.TicketWithTickets, err error) {
|
||||
func (s *Service) SetTaskData(ctx context.Context, i int64, s3 string, s2 string, data map[string]any) (doc *model.TicketWithTickets, err error) {
|
||||
defer s.publishRequest(ctx, err, "SetTask", ticketWithTicketsID(doc))
|
||||
|
||||
return s.database.TaskUpdateData(ctx, i, s3, s2, data)
|
||||
}
|
||||
|
||||
func (s *Service) SetTaskOwner(ctx context.Context, i int64, s3 string, s2 string, owner string) (doc *model.TicketWithTickets, err error) {
|
||||
defer s.publishRequest(ctx, err, "SetTask", ticketWithTicketsID(doc))
|
||||
|
||||
return s.database.TaskUpdateOwner(ctx, i, s3, s2, owner)
|
||||
}
|
||||
|
||||
func (s *Service) CompleteTask(ctx context.Context, i int64, s3 string, s2 string, m map[string]interface{}) (doc *model.TicketWithTickets, err error) {
|
||||
func (s *Service) CompleteTask(ctx context.Context, i int64, s3 string, s2 string, m map[string]any) (doc *model.TicketWithTickets, err error) {
|
||||
defer s.publishRequest(ctx, err, "CompleteTask", ticketWithTicketsID(doc))
|
||||
|
||||
return s.database.TaskComplete(ctx, i, s3, s2, m)
|
||||
}
|
||||
|
||||
func (s *Service) RunTask(ctx context.Context, i int64, s3 string, s2 string) (err error) {
|
||||
defer s.publishRequest(ctx, err, "RunTask", ticketID(i))
|
||||
|
||||
return s.database.TaskRun(ctx, i, s3, s2)
|
||||
}
|
||||
|
||||
@@ -171,11 +191,13 @@ func (s *Service) SetReferences(ctx context.Context, i int64, references *model.
|
||||
return nil, &api.HTTPError{Status: http.StatusUnprocessableEntity, Internal: errors.New("no references given")}
|
||||
}
|
||||
defer s.publishRequest(ctx, err, "SetReferences", ticketID(i))
|
||||
|
||||
return s.database.SetReferences(ctx, i, *references)
|
||||
}
|
||||
|
||||
func (s *Service) SetSchema(ctx context.Context, i int64, s2 string) (doc *model.TicketWithTickets, err error) {
|
||||
defer s.publishRequest(ctx, err, "SetSchema", ticketID(i))
|
||||
|
||||
return s.database.SetTemplate(ctx, i, s2)
|
||||
}
|
||||
|
||||
|
||||
@@ -14,6 +14,7 @@ func ticketTypeResponseID(ticketType *model.TicketTypeResponse) []driver.Documen
|
||||
if ticketType == nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
return userDataID(ticketType.ID)
|
||||
}
|
||||
|
||||
@@ -27,6 +28,7 @@ func (s *Service) ListTicketTypes(ctx context.Context) ([]*model.TicketTypeRespo
|
||||
|
||||
func (s *Service) CreateTicketType(ctx context.Context, form *model.TicketTypeForm) (doc *model.TicketTypeResponse, err error) {
|
||||
defer s.publishRequest(ctx, err, "CreateTicketType", ticketTypeResponseID(doc))
|
||||
|
||||
return s.database.TicketTypeCreate(ctx, form)
|
||||
}
|
||||
|
||||
@@ -36,10 +38,12 @@ func (s *Service) GetTicketType(ctx context.Context, id string) (*model.TicketTy
|
||||
|
||||
func (s *Service) UpdateTicketType(ctx context.Context, id string, form *model.TicketTypeForm) (doc *model.TicketTypeResponse, err error) {
|
||||
defer s.publishRequest(ctx, err, "UpdateTicketType", ticketTypeResponseID(doc))
|
||||
|
||||
return s.database.TicketTypeUpdate(ctx, id, form)
|
||||
}
|
||||
|
||||
func (s *Service) DeleteTicketType(ctx context.Context, id string) (err error) {
|
||||
defer s.publishRequest(ctx, err, "DeleteTicketType", ticketTypeID(id))
|
||||
|
||||
return s.database.TicketTypeDelete(ctx, id)
|
||||
}
|
||||
|
||||
@@ -16,6 +16,7 @@ func newUserResponseID(user *model.NewUserResponse) []driver.DocumentID {
|
||||
if user == nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
return userID(user.ID)
|
||||
}
|
||||
|
||||
@@ -23,6 +24,7 @@ func userResponseID(user *model.UserResponse) []driver.DocumentID {
|
||||
if user == nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
return userID(user.ID)
|
||||
}
|
||||
|
||||
@@ -36,6 +38,7 @@ func (s *Service) ListUsers(ctx context.Context) ([]*model.UserResponse, error)
|
||||
|
||||
func (s *Service) CreateUser(ctx context.Context, form *model.UserForm) (doc *model.NewUserResponse, err error) {
|
||||
defer s.publishRequest(ctx, err, "CreateUser", newUserResponseID(doc))
|
||||
|
||||
return s.database.UserCreate(ctx, form)
|
||||
}
|
||||
|
||||
@@ -45,11 +48,13 @@ func (s *Service) GetUser(ctx context.Context, s2 string) (*model.UserResponse,
|
||||
|
||||
func (s *Service) UpdateUser(ctx context.Context, s2 string, form *model.UserForm) (doc *model.UserResponse, err error) {
|
||||
defer s.publishRequest(ctx, err, "UpdateUser", userID(s2))
|
||||
|
||||
return s.database.UserUpdate(ctx, s2, form)
|
||||
}
|
||||
|
||||
func (s *Service) DeleteUser(ctx context.Context, s2 string) (err error) {
|
||||
defer s.publishRequest(ctx, err, "DeleteUser", userID(s2))
|
||||
|
||||
return s.database.UserDelete(ctx, s2)
|
||||
}
|
||||
|
||||
@@ -59,5 +64,6 @@ func (s *Service) CurrentUser(ctx context.Context) (*model.UserResponse, error)
|
||||
return nil, errors.New("no user in context")
|
||||
}
|
||||
s.publishRequest(ctx, nil, "CurrentUser", userResponseID(user))
|
||||
|
||||
return user, nil
|
||||
}
|
||||
|
||||
@@ -16,6 +16,7 @@ func userDataResponseID(userData *model.UserDataResponse) []driver.DocumentID {
|
||||
if userData == nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
return userDataID(userData.ID)
|
||||
}
|
||||
|
||||
@@ -33,6 +34,7 @@ func (s *Service) GetUserData(ctx context.Context, id string) (*model.UserDataRe
|
||||
|
||||
func (s *Service) UpdateUserData(ctx context.Context, id string, data *model.UserData) (doc *model.UserDataResponse, err error) {
|
||||
defer s.publishRequest(ctx, err, "CreateUser", userDataResponseID(doc))
|
||||
|
||||
return s.database.UserDataUpdate(ctx, id, data)
|
||||
}
|
||||
|
||||
@@ -52,5 +54,6 @@ func (s *Service) UpdateCurrentUserData(ctx context.Context, data *model.UserDat
|
||||
}
|
||||
|
||||
defer s.publishRequest(ctx, err, "UpdateCurrentUserData", userDataResponseID(doc))
|
||||
|
||||
return s.database.UserDataUpdate(ctx, user.ID, data)
|
||||
}
|
||||
|
||||
@@ -1,6 +1,8 @@
|
||||
package storage
|
||||
|
||||
import (
|
||||
"errors"
|
||||
|
||||
"github.com/aws/aws-sdk-go/aws"
|
||||
"github.com/aws/aws-sdk-go/aws/awserr"
|
||||
"github.com/aws/aws-sdk-go/aws/credentials"
|
||||
@@ -29,6 +31,7 @@ func New(config *Config) (*Storage, error) {
|
||||
DisableSSL: aws.Bool(true),
|
||||
S3ForcePathStyle: aws.Bool(true),
|
||||
})
|
||||
|
||||
return &Storage{s}, err
|
||||
}
|
||||
|
||||
@@ -39,17 +42,20 @@ func (s *Storage) S3() *s3.S3 {
|
||||
func (s *Storage) Downloader() *s3manager.Downloader {
|
||||
d := s3manager.NewDownloader(s.session)
|
||||
d.Concurrency = 1
|
||||
|
||||
return d
|
||||
}
|
||||
|
||||
func (s *Storage) Uploader() *s3manager.Uploader {
|
||||
d := s3manager.NewUploader(s.session)
|
||||
d.Concurrency = 1
|
||||
|
||||
return d
|
||||
}
|
||||
|
||||
func (s *Storage) DeleteBucket(name string) error {
|
||||
_, err := s.S3().DeleteBucket(&s3.DeleteBucketInput{Bucket: pointer.String("catalyst-" + name)})
|
||||
|
||||
return err
|
||||
}
|
||||
|
||||
@@ -61,11 +67,13 @@ func CreateBucket(client *s3.S3, ticketID string) error {
|
||||
return err
|
||||
}
|
||||
} else {
|
||||
awsError, ok := err.(awserr.Error)
|
||||
if !ok || (awsError.Code() != s3.ErrCodeBucketAlreadyExists && awsError.Code() != s3.ErrCodeBucketAlreadyOwnedByYou) {
|
||||
return err
|
||||
var awsError awserr.Error
|
||||
if errors.As(err, &awsError) && (awsError.Code() == s3.ErrCodeBucketAlreadyExists || awsError.Code() == s3.ErrCodeBucketAlreadyOwnedByYou) {
|
||||
return nil
|
||||
}
|
||||
return nil
|
||||
|
||||
return err
|
||||
}
|
||||
|
||||
return err
|
||||
}
|
||||
|
||||
@@ -25,6 +25,8 @@ import (
|
||||
)
|
||||
|
||||
func TestBackupAndRestore(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
log.SetFlags(log.LstdFlags | log.Lshortfile)
|
||||
|
||||
if runtime.GOARCH == "arm64" {
|
||||
@@ -41,7 +43,10 @@ func TestBackupAndRestore(t *testing.T) {
|
||||
{name: "Backup", want: want{status: http.StatusOK}},
|
||||
}
|
||||
for _, tt := range tests {
|
||||
tt := tt
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
ctx, _, server, err := Catalyst(t)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
@@ -75,6 +80,8 @@ func TestBackupAndRestore(t *testing.T) {
|
||||
}
|
||||
|
||||
func assertBackup(t *testing.T, server *catalyst.Server) []byte {
|
||||
t.Helper()
|
||||
|
||||
// setup request
|
||||
req := httptest.NewRequest(http.MethodGet, "/api/backup/create", nil)
|
||||
req.Header.Set("PRIVATE-TOKEN", "test")
|
||||
@@ -97,6 +104,8 @@ func assertBackup(t *testing.T, server *catalyst.Server) []byte {
|
||||
}
|
||||
|
||||
func assertZipFile(t *testing.T, r *zip.Reader) {
|
||||
t.Helper()
|
||||
|
||||
var names []string
|
||||
for _, f := range r.File {
|
||||
names = append(names, f.Name)
|
||||
@@ -120,9 +129,11 @@ func clearAllDatabases(server *catalyst.Server) {
|
||||
}
|
||||
|
||||
func deleteAllBuckets(t *testing.T, server *catalyst.Server) {
|
||||
t.Helper()
|
||||
|
||||
buckets, err := server.Storage.S3().ListBuckets(&s3.ListBucketsInput{})
|
||||
for _, bucket := range buckets.Buckets {
|
||||
server.Storage.S3().DeleteBucket(&s3.DeleteBucketInput{
|
||||
_, _ = server.Storage.S3().DeleteBucket(&s3.DeleteBucketInput{
|
||||
Bucket: bucket.Name,
|
||||
})
|
||||
}
|
||||
@@ -133,6 +144,8 @@ func deleteAllBuckets(t *testing.T, server *catalyst.Server) {
|
||||
}
|
||||
|
||||
func assertRestore(t *testing.T, zipB []byte, server *catalyst.Server) {
|
||||
t.Helper()
|
||||
|
||||
bodyBuf := &bytes.Buffer{}
|
||||
bodyWriter := multipart.NewWriter(bodyBuf)
|
||||
fileWriter, err := bodyWriter.CreateFormFile("backup", "backup.zip")
|
||||
@@ -166,7 +179,7 @@ func assertRestore(t *testing.T, zipB []byte, server *catalyst.Server) {
|
||||
func createFile(ctx context.Context, server *catalyst.Server) {
|
||||
buf := bytes.NewBufferString("test text")
|
||||
|
||||
server.Storage.S3().CreateBucket(&s3.CreateBucketInput{Bucket: pointer.String("catalyst-8125")})
|
||||
_, _ = server.Storage.S3().CreateBucket(&s3.CreateBucketInput{Bucket: pointer.String("catalyst-8125")})
|
||||
|
||||
if _, err := server.Storage.Uploader().Upload(&s3manager.UploadInput{Body: buf, Bucket: pointer.String("catalyst-8125"), Key: pointer.String("test.txt")}); err != nil {
|
||||
log.Fatal(err)
|
||||
@@ -178,6 +191,8 @@ func createFile(ctx context.Context, server *catalyst.Server) {
|
||||
}
|
||||
|
||||
func assertTicketExists(t *testing.T, server *catalyst.Server) {
|
||||
t.Helper()
|
||||
|
||||
req := httptest.NewRequest(http.MethodGet, "/api/tickets/8125", nil)
|
||||
req.Header.Set("PRIVATE-TOKEN", "test")
|
||||
|
||||
@@ -202,6 +217,8 @@ func assertTicketExists(t *testing.T, server *catalyst.Server) {
|
||||
}
|
||||
|
||||
func assertFileExists(t *testing.T, server *catalyst.Server) {
|
||||
t.Helper()
|
||||
|
||||
obj, err := server.Storage.S3().GetObject(&s3.GetObjectInput{
|
||||
Bucket: aws.String("catalyst-8125"),
|
||||
Key: aws.String("test.txt"),
|
||||
@@ -215,6 +232,8 @@ func assertFileExists(t *testing.T, server *catalyst.Server) {
|
||||
}
|
||||
|
||||
func includes(t *testing.T, names []string, s string) bool {
|
||||
t.Helper()
|
||||
|
||||
for _, name := range names {
|
||||
match, err := regexp.MatchString(s, name)
|
||||
if err != nil {
|
||||
@@ -225,10 +244,13 @@ func includes(t *testing.T, names []string, s string) bool {
|
||||
return true
|
||||
}
|
||||
}
|
||||
|
||||
return false
|
||||
}
|
||||
|
||||
func readZipFile(t *testing.T, b []byte) *zip.Reader {
|
||||
t.Helper()
|
||||
|
||||
buf := bytes.NewReader(b)
|
||||
|
||||
zr, err := zip.NewReader(buf, int64(buf.Len()))
|
||||
|
||||
@@ -10,9 +10,11 @@ import (
|
||||
"github.com/SecurityBrewery/catalyst/generated/pointer"
|
||||
)
|
||||
|
||||
var bobSetting = &model.UserData{Email: pointer.String("bob@example.org"), Name: pointer.String("Bob Bad")}
|
||||
var bobForm = &model.UserForm{ID: "bob", Blocked: false, Roles: []string{"admin"}}
|
||||
var Bob = &model.UserResponse{ID: "bob", Blocked: false, Roles: []string{"admin"}}
|
||||
var (
|
||||
bobSetting = &model.UserData{Email: pointer.String("bob@example.org"), Name: pointer.String("Bob Bad")}
|
||||
bobForm = &model.UserForm{ID: "bob", Blocked: false, Roles: []string{"admin"}}
|
||||
Bob = &model.UserResponse{ID: "bob", Blocked: false, Roles: []string{"admin"}}
|
||||
)
|
||||
|
||||
func SetupTestData(ctx context.Context, db *database.Database) error {
|
||||
if err := db.UserDataCreate(ctx, "bob", bobSetting); err != nil {
|
||||
@@ -109,5 +111,6 @@ func parse(s string) *time.Time {
|
||||
}
|
||||
|
||||
utc := modified.UTC()
|
||||
|
||||
return &utc
|
||||
}
|
||||
|
||||
@@ -18,6 +18,8 @@ import (
|
||||
)
|
||||
|
||||
func TestJob(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
log.SetFlags(log.LstdFlags | log.Lshortfile)
|
||||
|
||||
_, _, server, err := Catalyst(t)
|
||||
@@ -27,7 +29,7 @@ func TestJob(t *testing.T) {
|
||||
|
||||
b, err := json.Marshal(model.JobForm{
|
||||
Automation: "hash.sha1",
|
||||
Payload: map[string]interface{}{"default": "test"},
|
||||
Payload: map[string]any{"default": "test"},
|
||||
})
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
@@ -52,11 +54,14 @@ func TestJob(t *testing.T) {
|
||||
|
||||
output := gjson.GetBytes(job, "output.hash").String()
|
||||
assert.Equal(t, "a94a8fe5ccb19ba61c4c0873d391e987982fbbd3", output)
|
||||
|
||||
break
|
||||
}
|
||||
}
|
||||
|
||||
func request(t *testing.T, server chi.Router, method, url string, data io.Reader) []byte {
|
||||
t.Helper()
|
||||
|
||||
w := httptest.NewRecorder()
|
||||
|
||||
// setup request
|
||||
|
||||
@@ -25,10 +25,15 @@ func (testClock) Now() time.Time {
|
||||
}
|
||||
|
||||
func TestServer(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
ctime.DefaultClock = testClock{}
|
||||
|
||||
for _, tt := range api.Tests {
|
||||
tt := tt
|
||||
t.Run(tt.Name, func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
ctx, _, _, _, _, db, _, server, cleanup, err := Server(t)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
@@ -73,8 +78,10 @@ func TestServer(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
func jsonEqual(t *testing.T, name string, got io.Reader, want interface{}) {
|
||||
var gotObject, wantObject interface{}
|
||||
func jsonEqual(t *testing.T, name string, got io.Reader, want any) {
|
||||
t.Helper()
|
||||
|
||||
var gotObject, wantObject any
|
||||
|
||||
// load bytes
|
||||
wantBytes, err := json.Marshal(want)
|
||||
|
||||
18
test/test.go
18
test/test.go
@@ -75,6 +75,8 @@ func Config(ctx context.Context) (*catalyst.Config, error) {
|
||||
}
|
||||
|
||||
func Index(t *testing.T) (*index.Index, func(), error) {
|
||||
t.Helper()
|
||||
|
||||
dir, err := os.MkdirTemp("", "catalyst-test-"+cleanName(t))
|
||||
if err != nil {
|
||||
return nil, nil, err
|
||||
@@ -84,10 +86,13 @@ func Index(t *testing.T) (*index.Index, func(), error) {
|
||||
if err != nil {
|
||||
return nil, nil, err
|
||||
}
|
||||
|
||||
return catalystIndex, func() { catalystIndex.Close(); os.RemoveAll(dir) }, nil
|
||||
}
|
||||
|
||||
func Bus(t *testing.T) (context.Context, *catalyst.Config, *bus.Bus, error) {
|
||||
t.Helper()
|
||||
|
||||
ctx := Context()
|
||||
|
||||
config, err := Config(ctx)
|
||||
@@ -99,10 +104,13 @@ func Bus(t *testing.T) (context.Context, *catalyst.Config, *bus.Bus, error) {
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
return ctx, config, catalystBus, err
|
||||
}
|
||||
|
||||
func DB(t *testing.T) (context.Context, *catalyst.Config, *bus.Bus, *index.Index, *storage.Storage, *database.Database, func(), error) {
|
||||
t.Helper()
|
||||
|
||||
ctx, config, rbus, err := Bus(t)
|
||||
if err != nil {
|
||||
return nil, nil, nil, nil, nil, nil, nil, err
|
||||
@@ -146,6 +154,8 @@ func DB(t *testing.T) (context.Context, *catalyst.Config, *bus.Bus, *index.Index
|
||||
}
|
||||
|
||||
func Service(t *testing.T) (context.Context, *catalyst.Config, *bus.Bus, *index.Index, *storage.Storage, *database.Database, *service.Service, func(), error) {
|
||||
t.Helper()
|
||||
|
||||
ctx, config, rbus, catalystIndex, catalystStorage, db, cleanup, err := DB(t)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
@@ -160,6 +170,8 @@ func Service(t *testing.T) (context.Context, *catalyst.Config, *bus.Bus, *index.
|
||||
}
|
||||
|
||||
func Server(t *testing.T) (context.Context, *catalyst.Config, *bus.Bus, *index.Index, *storage.Storage, *database.Database, *service.Service, chi.Router, func(), error) {
|
||||
t.Helper()
|
||||
|
||||
ctx, config, rbus, catalystIndex, catalystStorage, db, catalystService, cleanup, err := Service(t)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
@@ -177,6 +189,8 @@ func Server(t *testing.T) (context.Context, *catalyst.Config, *bus.Bus, *index.I
|
||||
}
|
||||
|
||||
func Catalyst(t *testing.T) (context.Context, *catalyst.Config, *catalyst.Server, error) {
|
||||
t.Helper()
|
||||
|
||||
ctx := Context()
|
||||
|
||||
config, err := Config(ctx)
|
||||
@@ -189,13 +203,17 @@ func Catalyst(t *testing.T) (context.Context, *catalyst.Config, *catalyst.Server
|
||||
c, err := catalyst.New(&hooks.Hooks{
|
||||
DatabaseAfterConnectFuncs: []func(ctx context.Context, client driver.Client, name string){Clear},
|
||||
}, config)
|
||||
|
||||
return ctx, config, c, err
|
||||
}
|
||||
|
||||
func cleanName(t *testing.T) string {
|
||||
t.Helper()
|
||||
|
||||
name := t.Name()
|
||||
name = strings.ReplaceAll(name, " ", "")
|
||||
name = strings.ReplaceAll(name, "/", "_")
|
||||
|
||||
return strings.ReplaceAll(name, "#", "_")
|
||||
}
|
||||
|
||||
|
||||
@@ -9,14 +9,16 @@ import (
|
||||
)
|
||||
|
||||
func TestUser(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
type args struct {
|
||||
method string
|
||||
url string
|
||||
data interface{}
|
||||
data any
|
||||
}
|
||||
type want struct {
|
||||
status int
|
||||
body interface{}
|
||||
body any
|
||||
}
|
||||
tests := []struct {
|
||||
name string
|
||||
@@ -27,7 +29,10 @@ func TestUser(t *testing.T) {
|
||||
{name: "ListUsers", args: args{method: http.MethodGet, url: "/users"}, want: want{status: http.StatusOK}},
|
||||
}
|
||||
for _, tt := range tests {
|
||||
tt := tt
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
_, _, _, _, _, _, _, server, cleanup, err := Server(t)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
|
||||
@@ -3,15 +3,23 @@ package ui
|
||||
import "testing"
|
||||
|
||||
func TestUI(t *testing.T) {
|
||||
requiredFiles := []string{
|
||||
"dist/index.html",
|
||||
"dist/favicon.ico",
|
||||
"dist/manifest.json",
|
||||
"dist/img",
|
||||
t.Parallel()
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
path string
|
||||
}{
|
||||
{"index.html", "dist/index.html"},
|
||||
{"favicon.ico", "dist/favicon.ico"},
|
||||
{"manifest.json", "dist/manifest.json"},
|
||||
{"img", "dist/img"},
|
||||
}
|
||||
for _, requiredFile := range requiredFiles {
|
||||
t.Run("Require "+requiredFile, func(t *testing.T) {
|
||||
f, err := UI.Open(requiredFile)
|
||||
for _, tt := range tests {
|
||||
tt := tt
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
f, err := UI.Open(tt.path)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
@@ -41,6 +41,7 @@ func (wb *websocketBroker) NewWebsocket() (string, chan []byte) {
|
||||
wb.mu.Lock()
|
||||
wb.clients[id] = channel
|
||||
wb.mu.Unlock()
|
||||
|
||||
return id, channel
|
||||
}
|
||||
|
||||
@@ -49,7 +50,7 @@ func handleWebSocket(catalystBus *bus.Bus) http.HandlerFunc {
|
||||
|
||||
// send all messages from bus to websocket
|
||||
err := catalystBus.SubscribeDatabaseUpdate(func(msg *bus.DatabaseUpdateMsg) {
|
||||
b, err := json.Marshal(map[string]interface{}{
|
||||
b, err := json.Marshal(map[string]any{
|
||||
"action": "update",
|
||||
"ids": msg.IDs,
|
||||
})
|
||||
@@ -67,6 +68,7 @@ func handleWebSocket(catalystBus *bus.Bus) http.HandlerFunc {
|
||||
conn, _, _, err := ws.UpgradeHTTP(r, w)
|
||||
if err != nil {
|
||||
api.JSONError(w, errors.New("upgrade failed"))
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
|
||||
Reference in New Issue
Block a user