Migrate to Go 1.18 (#45)

* Migrate to Go 1.18 and add linters
This commit is contained in:
Jonas Plum
2022-03-20 03:17:18 +01:00
committed by GitHub
parent 03a4806d45
commit 2bad1f5f28
88 changed files with 1430 additions and 868 deletions
+16 -4
View File
@@ -9,13 +9,25 @@ env:
IMAGE_NAME: ${{ github.repository }} IMAGE_NAME: ${{ github.repository }}
jobs: jobs:
lint:
name: Lint
runs-on: ubuntu-latest
steps:
- uses: actions/setup-go@v3
with: { go-version: '1.18' }
- uses: actions/checkout@v2
- run: |
mkdir -p ui/dist/img
touch ui/dist/index.html ui/dist/favicon.ico ui/dist/manifest.json ui/dist/img/fake.png
- uses: golangci/golangci-lint-action@v3
test: test:
name: Test name: Test
runs-on: ubuntu-latest runs-on: ubuntu-latest
env: { GIN_MODE: test } env: { GIN_MODE: test }
steps: steps:
- uses: actions/setup-go@v2 - uses: actions/setup-go@v3
with: { go-version: '1.17' } with: { go-version: '1.18' }
- uses: actions/setup-node@v2 - uses: actions/setup-node@v2
with: { node-version: '14' } with: { node-version: '14' }
- uses: actions/checkout@v2 - uses: actions/checkout@v2
@@ -51,8 +63,8 @@ jobs:
runs-on: ubuntu-latest runs-on: ubuntu-latest
needs: [ build-npm, test ] needs: [ build-npm, test ]
steps: steps:
- uses: actions/setup-go@v2 - uses: actions/setup-go@v3
with: { go-version: '1.17' } with: { go-version: '1.18' }
- uses: actions/checkout@v2 - uses: actions/checkout@v2
- uses: actions/download-artifact@v2 - uses: actions/download-artifact@v2
with: { name: ui, path: ui/dist } with: { name: ui, path: ui/dist }
+116
View File
@@ -0,0 +1,116 @@
run:
go: "1.18"
skip-dirs:
- generated
linters:
enable:
- asciicheck
- containedctx
- decorder
- depguard
- dogsled
- durationcheck
- errchkjson
- errname
- errorlint
- exhaustive
- exportloopref
- forbidigo
- forcetypeassert
- gci
- gocritic
- godot
- gofmt
- gofumpt
- goheader
- goimports
- gomodguard
- goprintffuncname
- gosec
- grouper
- ifshort
- importas
- ireturn
- misspell
- nakedret
- nilnil
- nlreturn
- nolintlint
- paralleltest
- predeclared
- promlinter
- revive
- tenv
- thelper
- unconvert
- whitespace
disable:
# go 1.18
- bodyclose
- contextcheck
- gosimple
- nilerr
- noctx
- rowserrcheck
- sqlclosecheck
- staticcheck
- stylecheck
- tparallel
- unparam
- unused
- wastedassign
# complexity
- cyclop
- gocognit
- gocyclo
- maintidx
- nestif
# disable
- dupl
- exhaustivestruct
- funlen
- gochecknoglobals
- gochecknoinits
- goconst
- godox
- goerr113
- gomnd
- gomoddirectives
- lll
- makezero
- prealloc
- structcheck
- tagliatelle
- testpackage
- varnamelen
- wrapcheck
- wsl
linters-settings:
gci:
sections:
- standard
- default
- prefix(github.com/SecurityBrewery/catalyst)
ireturn:
allow:
- error
- context.Context
- go-driver.Cursor
- go-driver.Collection
- go-driver.Database
- chi.Router
issues:
exclude-rules:
- path: caql
text: "var-naming: don't use underscores"
- path: database/user.go
text: "G404"
linters: [ gosec ]
- path: caql/function.go
text: "G404"
linters: [ gosec ]
- path: caql
linters: [ forcetypeassert ]
+31 -17
View File
@@ -2,15 +2,16 @@ package catalyst
import ( import (
"context" "context"
"crypto/rand"
"crypto/sha256" "crypto/sha256"
"encoding/base64" "encoding/base64"
"errors" "errors"
"fmt" "fmt"
"math/rand"
"net/http" "net/http"
"strings" "strings"
"github.com/coreos/go-oidc/v3/oidc" "github.com/coreos/go-oidc/v3/oidc"
"golang.org/x/exp/slices"
"golang.org/x/oauth2" "golang.org/x/oauth2"
"github.com/SecurityBrewery/catalyst/database" "github.com/SecurityBrewery/catalyst/database"
@@ -43,6 +44,7 @@ func (c *AuthConfig) Verifier(ctx context.Context) (*oidc.IDTokenVerifier, error
return nil, err return nil, err
} }
} }
return c.provider.Verifier(&oidc.Config{SkipClientIDCheck: true}), nil return c.provider.Verifier(&oidc.Config{SkipClientIDCheck: true}), nil
} }
@@ -81,12 +83,14 @@ func bearerAuth(db *database.Database, authHeader string, iss string, config *Au
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
if !strings.HasPrefix(authHeader, "Bearer ") { if !strings.HasPrefix(authHeader, "Bearer ") {
api.JSONErrorStatus(w, http.StatusUnauthorized, errors.New("no bearer token")) api.JSONErrorStatus(w, http.StatusUnauthorized, errors.New("no bearer token"))
return return
} }
claims, apiError := verifyClaims(r, config, authHeader[7:]) claims, apiError := verifyClaims(r, config, authHeader[7:])
if apiError != nil { if apiError != nil {
api.JSONErrorStatus(w, apiError.Status, apiError.Internal) api.JSONErrorStatus(w, apiError.Status, apiError.Internal)
return return
} }
@@ -100,6 +104,7 @@ func bearerAuth(db *database.Database, authHeader string, iss string, config *Au
r, err := setContextClaims(r, db, claims, config) r, err := setContextClaims(r, db, claims, config)
if err != nil { if err != nil {
api.JSONErrorStatus(w, http.StatusInternalServerError, fmt.Errorf("could not load user: %w", err)) api.JSONErrorStatus(w, http.StatusInternalServerError, fmt.Errorf("could not load user: %w", err))
return return
} }
@@ -116,6 +121,7 @@ func keyAuth(db *database.Database, keyHeader string) func(next http.Handler) ht
key, err := db.UserByHash(r.Context(), h) key, err := db.UserByHash(r.Context(), h)
if err != nil { if err != nil {
api.JSONErrorStatus(w, http.StatusInternalServerError, fmt.Errorf("could not verify private token: %w", err)) api.JSONErrorStatus(w, http.StatusInternalServerError, fmt.Errorf("could not verify private token: %w", err))
return return
} }
@@ -132,16 +138,19 @@ func sessionAuth(db *database.Database, config *AuthConfig) func(next http.Handl
claims, noCookie, err := claimsCookie(r) claims, noCookie, err := claimsCookie(r)
if err != nil { if err != nil {
api.JSONError(w, err) api.JSONError(w, err)
return return
} }
if noCookie { if noCookie {
redirectToLogin(w, r, config.OAuth2) redirectToLogin(w, r, config.OAuth2)
return return
} }
r, err = setContextClaims(r, db, claims, config) r, err = setContextClaims(r, db, claims, config)
if err != nil { if err != nil {
api.JSONErrorStatus(w, http.StatusInternalServerError, fmt.Errorf("could not load user: %w", err)) api.JSONErrorStatus(w, http.StatusInternalServerError, fmt.Errorf("could not load user: %w", err))
return return
} }
@@ -150,7 +159,7 @@ func sessionAuth(db *database.Database, config *AuthConfig) func(next http.Handl
} }
} }
func setContextClaims(r *http.Request, db *database.Database, claims map[string]interface{}, config *AuthConfig) (*http.Request, error) { func setContextClaims(r *http.Request, db *database.Database, claims map[string]any, config *AuthConfig) (*http.Request, error) {
newUser, newSetting, err := mapUserAndSettings(claims, config) newUser, newSetting, err := mapUserAndSettings(claims, config)
if err != nil { if err != nil {
return nil, err return nil, err
@@ -182,7 +191,7 @@ func setContextUser(r *http.Request, user *model.UserResponse, hooks *hooks.Hook
return busdb.SetContext(r, user) return busdb.SetContext(r, user)
} }
func mapUserAndSettings(claims map[string]interface{}, config *AuthConfig) (*model.UserForm, *model.UserData, error) { func mapUserAndSettings(claims map[string]any, config *AuthConfig) (*model.UserForm, *model.UserData, error) {
// handle Bearer tokens // handle Bearer tokens
// if typ, ok := claims["typ"]; ok && typ == "Bearer" { // if typ, ok := claims["typ"]; ok && typ == "Bearer" {
// return &model.User{ // return &model.User{
@@ -208,8 +217,8 @@ func mapUserAndSettings(claims map[string]interface{}, config *AuthConfig) (*mod
name = "" name = ""
} }
var roles = role.Strings(config.AuthDefaultRoles) roles := role.Strings(config.AuthDefaultRoles)
if contains(config.AuthAdminUsers, username) { if slices.Contains(config.AuthAdminUsers, username) {
roles = append(roles, role.Admin) roles = append(roles, role.Admin)
} }
@@ -223,20 +232,12 @@ func mapUserAndSettings(claims map[string]interface{}, config *AuthConfig) (*mod
}, nil }, nil
} }
func contains(l []string, s string) bool { func getString(m map[string]any, key string) (string, error) {
for _, e := range l {
if e == s {
return true
}
}
return false
}
func getString(m map[string]interface{}, key string) (string, error) {
if v, ok := m[key]; ok { if v, ok := m[key]; ok {
if s, ok := v.(string); ok { if s, ok := v.(string); ok {
return s, nil return s, nil
} }
return "", fmt.Errorf("mapping of %s failed, wrong type (%T)", key, v) return "", fmt.Errorf("mapping of %s failed, wrong type (%T)", key, v)
} }
@@ -247,12 +248,14 @@ func redirectToLogin(w http.ResponseWriter, r *http.Request, oauth2Config *oauth
state, err := state() state, err := state()
if err != nil { if err != nil {
api.JSONErrorStatus(w, http.StatusInternalServerError, errors.New("generating state failed")) api.JSONErrorStatus(w, http.StatusInternalServerError, errors.New("generating state failed"))
return return
} }
setStateCookie(w, state) setStateCookie(w, state)
http.Redirect(w, r, oauth2Config.AuthCodeURL(state), http.StatusFound) http.Redirect(w, r, oauth2Config.AuthCodeURL(state), http.StatusFound)
return return
} }
@@ -262,11 +265,13 @@ func AuthorizeBlockedUser() func(http.Handler) http.Handler {
user, ok := busdb.UserFromContext(r.Context()) user, ok := busdb.UserFromContext(r.Context())
if !ok { if !ok {
api.JSONErrorStatus(w, http.StatusInternalServerError, errors.New("no user in context")) api.JSONErrorStatus(w, http.StatusInternalServerError, errors.New("no user in context"))
return return
} }
if user.Blocked { if user.Blocked {
api.JSONErrorStatus(w, http.StatusForbidden, errors.New("user is blocked")) api.JSONErrorStatus(w, http.StatusForbidden, errors.New("user is blocked"))
return return
} }
@@ -281,11 +286,13 @@ func AuthorizeRole(roles []string) func(http.Handler) http.Handler {
user, ok := busdb.UserFromContext(r.Context()) user, ok := busdb.UserFromContext(r.Context())
if !ok { if !ok {
api.JSONErrorStatus(w, http.StatusInternalServerError, errors.New("no user in context")) api.JSONErrorStatus(w, http.StatusInternalServerError, errors.New("no user in context"))
return return
} }
if !role.UserHasRoles(user, role.FromStrings(roles)) { if !role.UserHasRoles(user, role.FromStrings(roles)) {
api.JSONErrorStatus(w, http.StatusForbidden, fmt.Errorf("missing role %s has %s", roles, user.Roles)) api.JSONErrorStatus(w, http.StatusForbidden, fmt.Errorf("missing role %s has %s", roles, user.Roles))
return return
} }
@@ -299,17 +306,20 @@ func callback(config *AuthConfig) http.HandlerFunc {
state, err := stateCookie(r) state, err := stateCookie(r)
if err != nil || state == "" { if err != nil || state == "" {
api.JSONErrorStatus(w, http.StatusInternalServerError, errors.New("state missing")) api.JSONErrorStatus(w, http.StatusInternalServerError, errors.New("state missing"))
return return
} }
if state != r.URL.Query().Get("state") { if state != r.URL.Query().Get("state") {
api.JSONErrorStatus(w, http.StatusInternalServerError, errors.New("state mismatch")) api.JSONErrorStatus(w, http.StatusInternalServerError, errors.New("state mismatch"))
return return
} }
oauth2Token, err := config.OAuth2.Exchange(r.Context(), r.URL.Query().Get("code")) oauth2Token, err := config.OAuth2.Exchange(r.Context(), r.URL.Query().Get("code"))
if err != nil { if err != nil {
api.JSONErrorStatus(w, http.StatusInternalServerError, fmt.Errorf("oauth2 exchange failed: %w", err)) api.JSONErrorStatus(w, http.StatusInternalServerError, fmt.Errorf("oauth2 exchange failed: %w", err))
return return
} }
@@ -317,12 +327,14 @@ func callback(config *AuthConfig) http.HandlerFunc {
rawIDToken, ok := oauth2Token.Extra("id_token").(string) rawIDToken, ok := oauth2Token.Extra("id_token").(string)
if !ok { if !ok {
api.JSONErrorStatus(w, http.StatusInternalServerError, errors.New("missing id token")) api.JSONErrorStatus(w, http.StatusInternalServerError, errors.New("missing id token"))
return return
} }
claims, apiError := verifyClaims(r, config, rawIDToken) claims, apiError := verifyClaims(r, config, rawIDToken)
if apiError != nil { if apiError != nil {
api.JSONErrorStatus(w, apiError.Status, apiError.Internal) api.JSONErrorStatus(w, apiError.Status, apiError.Internal)
return return
} }
@@ -337,10 +349,11 @@ func state() (string, error) {
if _, err := rand.Read(rnd); err != nil { if _, err := rand.Read(rnd); err != nil {
return "", err return "", err
} }
return base64.URLEncoding.EncodeToString(rnd), nil return base64.URLEncoding.EncodeToString(rnd), nil
} }
func verifyClaims(r *http.Request, config *AuthConfig, rawIDToken string) (map[string]interface{}, *api.HTTPError) { func verifyClaims(r *http.Request, config *AuthConfig, rawIDToken string) (map[string]any, *api.HTTPError) {
verifier, err := config.Verifier(r.Context()) verifier, err := config.Verifier(r.Context())
if err != nil { if err != nil {
return nil, &api.HTTPError{Status: http.StatusUnauthorized, Internal: fmt.Errorf("could not verify: %w", err)} return nil, &api.HTTPError{Status: http.StatusUnauthorized, Internal: fmt.Errorf("could not verify: %w", err)}
@@ -350,9 +363,10 @@ func verifyClaims(r *http.Request, config *AuthConfig, rawIDToken string) (map[s
return nil, &api.HTTPError{Status: http.StatusInternalServerError, Internal: fmt.Errorf("could not verify bearer token: %w", err)} return nil, &api.HTTPError{Status: http.StatusInternalServerError, Internal: fmt.Errorf("could not verify bearer token: %w", err)}
} }
var claims map[string]interface{} var claims map[string]any
if err := authToken.Claims(&claims); err != nil { if err := authToken.Claims(&claims); err != nil {
return nil, &api.HTTPError{Status: http.StatusInternalServerError, Internal: fmt.Errorf("failed to parse claims: %w", err)} return nil, &api.HTTPError{Status: http.StatusInternalServerError, Internal: fmt.Errorf("failed to parse claims: %w", err)}
} }
return claims, nil return claims, nil
} }
+8 -1
View File
@@ -41,7 +41,10 @@ func Backup(catalystStorage *storage.Storage, c *database.Config, writer io.Writ
archive := zip.NewWriter(writer) archive := zip.NewWriter(writer)
defer archive.Close() defer archive.Close()
archive.SetComment(GetVersion()) err := archive.SetComment(GetVersion())
if err != nil {
return err
}
// S3 // S3
if err := backupS3(catalystStorage, archive); err != nil { if err := backupS3(catalystStorage, archive); err != nil {
@@ -86,6 +89,7 @@ func backupS3(catalystStorage *storage.Storage, archive *zip.Writer) error {
} }
} }
} }
return nil return nil
} }
@@ -105,6 +109,7 @@ func backupArango(c *database.Config, archive *zip.Writer) error {
func zipDump(dir string, archive *zip.Writer) error { func zipDump(dir string, archive *zip.Writer) error {
fsys := os.DirFS(dir) fsys := os.DirFS(dir)
return fs.WalkDir(fsys, ".", func(p string, d fs.DirEntry, err error) error { return fs.WalkDir(fsys, ".", func(p string, d fs.DirEntry, err error) error {
if err != nil { if err != nil {
return err return err
@@ -127,6 +132,7 @@ func zipDump(dir string, archive *zip.Writer) error {
if _, err := io.Copy(a, f); err != nil { if _, err := io.Copy(a, f); err != nil {
return err return err
} }
return nil return nil
}) })
} }
@@ -144,5 +150,6 @@ func arangodump(dir string, config *database.Config) error {
"--server.database", name, "--server.database", name,
} }
cmd := exec.Command("arangodump", args...) cmd := exec.Command("arangodump", args...)
return cmd.Run() return cmd.Run()
} }
+2 -1
View File
@@ -50,7 +50,7 @@ func New(c *Config) (*Bus, error) {
return &Bus{config: c, client: client}, err return &Bus{config: c, client: client}, err
} }
func (b *Bus) jsonPublish(msg interface{}, channel, key string) error { func (b *Bus) jsonPublish(msg any, channel, key string) error {
payload, err := json.Marshal(msg) payload, err := json.Marshal(msg)
if err != nil { if err != nil {
return err return err
@@ -65,5 +65,6 @@ func (b *Bus) safeSubscribe(key, channel string, handler func(c *emitter.Client,
log.Printf("Recovered %s in channel %s\n", r, channel) log.Printf("Recovered %s in channel %s\n", r, channel)
} }
}() }()
return b.client.Subscribe(key, channel, handler) return b.client.Subscribe(key, channel, handler)
} }
+1
View File
@@ -35,6 +35,7 @@ func (b *Bus) SubscribeDatabaseUpdate(f func(msg *DatabaseUpdateMsg)) error {
var msg DatabaseUpdateMsg var msg DatabaseUpdateMsg
if err := json.Unmarshal(m.Payload(), &msg); err != nil { if err := json.Unmarshal(m.Payload(), &msg); err != nil {
log.Println(err) log.Println(err)
return return
} }
go f(&msg) go f(&msg)
+2 -1
View File
@@ -18,7 +18,7 @@ type JobMsg struct {
Message *model.Message `json:"message"` Message *model.Message `json:"message"`
} }
func (b *Bus) PublishJob(id, automation string, payload interface{}, context *model.Context, origin *model.Origin) error { func (b *Bus) PublishJob(id, automation string, payload any, context *model.Context, origin *model.Origin) error {
return b.jsonPublish(&JobMsg{ return b.jsonPublish(&JobMsg{
ID: id, ID: id,
Automation: automation, Automation: automation,
@@ -35,6 +35,7 @@ func (b *Bus) SubscribeJob(f func(msg *JobMsg)) error {
var msg JobMsg var msg JobMsg
if err := json.Unmarshal(m.Payload(), &msg); err != nil { if err := json.Unmarshal(m.Payload(), &msg); err != nil {
log.Println(err) log.Println(err)
return return
} }
go f(&msg) go f(&msg)
+1
View File
@@ -29,6 +29,7 @@ func (b *Bus) SubscribeRequest(f func(msg *RequestMsg)) error {
msg := &RequestMsg{} msg := &RequestMsg{}
if err := json.Unmarshal(m.Payload(), msg); err != nil { if err := json.Unmarshal(m.Payload(), msg); err != nil {
log.Println(err) log.Println(err)
return return
} }
go f(msg) go f(msg)
+5 -4
View File
@@ -12,12 +12,12 @@ import (
const channelResult = "result" const channelResult = "result"
type ResultMsg struct { type ResultMsg struct {
Automation string `json:"automation"` Automation string `json:"automation"`
Data map[string]interface{} `json:"data,omitempty"` Data map[string]any `json:"data,omitempty"`
Target *model.Origin `json:"target"` Target *model.Origin `json:"target"`
} }
func (b *Bus) PublishResult(automation string, data map[string]interface{}, target *model.Origin) error { func (b *Bus) PublishResult(automation string, data map[string]any, target *model.Origin) error {
return b.jsonPublish(&ResultMsg{Automation: automation, Data: data, Target: target}, channelResult, b.config.resultBusKey) return b.jsonPublish(&ResultMsg{Automation: automation, Data: data, Target: target}, channelResult, b.config.resultBusKey)
} }
@@ -26,6 +26,7 @@ func (b *Bus) SubscribeResult(f func(msg *ResultMsg)) error {
msg := &ResultMsg{} msg := &ResultMsg{}
if err := json.Unmarshal(m.Payload(), msg); err != nil { if err := json.Unmarshal(m.Payload(), msg); err != nil {
log.Println(err) log.Println(err)
return return
} }
go f(msg) go f(msg)
+1 -1
View File
@@ -21,7 +21,6 @@ type busService struct {
} }
func New(apiURL, apikey, network string, catalystBus *bus.Bus, db *database.Database) error { func New(apiURL, apikey, network string, catalystBus *bus.Bus, db *database.Database) error {
h := &busService{db: db, apiURL: apiURL, apiKey: apikey, network: network, catalystBus: catalystBus} h := &busService{db: db, apiURL: apiURL, apiKey: apikey, network: network, catalystBus: catalystBus}
if err := catalystBus.SubscribeRequest(h.logRequest); err != nil { if err := catalystBus.SubscribeRequest(h.logRequest); err != nil {
@@ -40,6 +39,7 @@ func New(apiURL, apikey, network string, catalystBus *bus.Bus, db *database.Data
func busContext() context.Context { func busContext() context.Context {
// TODO: change roles? // TODO: change roles?
bot := &model.UserResponse{ID: "bot", Roles: []string{role.Admin}} bot := &model.UserResponse{ID: "bot", Roles: []string{role.Admin}}
return busdb.UserContext(context.Background(), bot) return busdb.UserContext(context.Background(), bot)
} }
+19 -4
View File
@@ -59,13 +59,15 @@ func pullImage(ctx context.Context, cli *client.Client, image string) (string, e
buf := &bytes.Buffer{} buf := &bytes.Buffer{}
_, err = io.Copy(buf, reader) _, err = io.Copy(buf, reader)
return buf.String(), err return buf.String(), err
} }
func copyFile(ctx context.Context, cli *client.Client, path string, contentString string, id string) error { func copyFile(ctx context.Context, cli *client.Client, path string, contentString string, id string) error {
tarBuf := &bytes.Buffer{} tarBuf := &bytes.Buffer{}
tw := tar.NewWriter(tarBuf) tw := tar.NewWriter(tarBuf)
if err := tw.WriteHeader(&tar.Header{Name: path, Mode: 0755, Size: int64(len(contentString))}); err != nil { header := &tar.Header{Name: path, Mode: 0o755, Size: int64(len(contentString))}
if err := tw.WriteHeader(header); err != nil {
return err return err
} }
@@ -90,7 +92,12 @@ func runDocker(ctx context.Context, jobID, containerID string, db *database.Data
return nil, nil, err return nil, nil, err
} }
defer cli.ContainerRemove(ctx, containerID, types.ContainerRemoveOptions{Force: true}) defer func(cli *client.Client, ctx context.Context, containerID string, options types.ContainerRemoveOptions) {
err := cli.ContainerRemove(ctx, containerID, options)
if err != nil {
log.Println(err)
}
}(cli, ctx, containerID, types.ContainerRemoveOptions{Force: true})
if err := cli.ContainerStart(ctx, containerID, types.ContainerStartOptions{}); err != nil { if err := cli.ContainerStart(ctx, containerID, types.ContainerStartOptions{}); err != nil {
return nil, nil, err return nil, nil, err
@@ -123,13 +130,16 @@ func streamStdErr(ctx context.Context, cli *client.Client, jobID, containerID st
err := scanLines(ctx, jobID, containerLogs, stderrBuf, db) err := scanLines(ctx, jobID, containerLogs, stderrBuf, db)
if err != nil { if err != nil {
log.Println(err) log.Println(err)
return return
} }
if err := containerLogs.Close(); err != nil { if err := containerLogs.Close(); err != nil {
log.Println(err) log.Println(err)
return return
} }
}() }()
return stderrBuf, nil return stderrBuf, nil
} }
@@ -139,24 +149,28 @@ func scanLines(ctx context.Context, jobID string, input io.ReadCloser, output io
_, err := stdcopy.StdCopy(w, w, input) _, err := stdcopy.StdCopy(w, w, input)
if err != nil { if err != nil {
log.Println(err) log.Println(err)
return return
} }
if err := w.Close(); err != nil { if err := w.Close(); err != nil {
log.Println(err) log.Println(err)
return return
} }
}() }()
s := bufio.NewScanner(r) s := bufio.NewScanner(r)
for s.Scan() { for s.Scan() {
b := s.Bytes() b := s.Bytes()
output.Write(b) _, _ = output.Write(b)
output.Write([]byte("\n")) _, _ = output.Write([]byte("\n"))
if err := db.JobLogAppend(ctx, jobID, string(b)+"\n"); err != nil { if err := db.JobLogAppend(ctx, jobID, string(b)+"\n"); err != nil {
log.Println(err) log.Println(err)
continue continue
} }
} }
return s.Err() return s.Err()
} }
@@ -172,6 +186,7 @@ func waitForContainer(ctx context.Context, cli *client.Client, containerID strin
return fmt.Errorf("container returned status code %d: stderr: %s", exitStatus.StatusCode, stderrBuf.String()) return fmt.Errorf("container returned status code %d: stderr: %s", exitStatus.StatusCode, stderrBuf.String())
} }
} }
return nil return nil
} }
+16 -5
View File
@@ -19,17 +19,20 @@ func (h *busService) handleJob(automationMsg *bus.JobMsg) {
}) })
if err != nil { if err != nil {
log.Println(err) log.Println(err)
return return
} }
automation, err := h.db.AutomationGet(ctx, automationMsg.Automation) automation, err := h.db.AutomationGet(ctx, automationMsg.Automation)
if err != nil { if err != nil {
log.Println(err) log.Println(err)
return return
} }
if automation.Script == "" { if automation.Script == "" {
log.Println("automation is empty") log.Println("automation is empty")
return return
} }
@@ -39,11 +42,17 @@ func (h *busService) handleJob(automationMsg *bus.JobMsg) {
automationMsg.Message.Secrets["catalyst_apikey"] = h.apiKey automationMsg.Message.Secrets["catalyst_apikey"] = h.apiKey
automationMsg.Message.Secrets["catalyst_apiurl"] = h.apiURL automationMsg.Message.Secrets["catalyst_apiurl"] = h.apiURL
scriptMessage, _ := json.Marshal(automationMsg.Message) scriptMessage, err := json.Marshal(automationMsg.Message)
if err != nil {
log.Println(err)
return
}
containerID, logs, err := createContainer(ctx, automation.Image, automation.Script, string(scriptMessage), h.network) containerID, logs, err := createContainer(ctx, automation.Image, automation.Script, string(scriptMessage), h.network)
if err != nil { if err != nil {
log.Println(err) log.Println(err)
return return
} }
@@ -55,18 +64,19 @@ func (h *busService) handleJob(automationMsg *bus.JobMsg) {
Status: job.Status, Status: job.Status,
}); err != nil { }); err != nil {
log.Println(err) log.Println(err)
return return
} }
var result map[string]interface{} var result map[string]any
stdout, _, err := runDocker(ctx, automationMsg.ID, containerID, h.db) stdout, _, err := runDocker(ctx, automationMsg.ID, containerID, h.db)
if err != nil { if err != nil {
result = map[string]interface{}{"error": fmt.Sprintf("error running script %s %s", err, string(stdout))} result = map[string]any{"error": fmt.Sprintf("error running script %s %s", err, string(stdout))}
} else { } else {
var data map[string]interface{} var data map[string]any
if err := json.Unmarshal(stdout, &data); err != nil { if err := json.Unmarshal(stdout, &data); err != nil {
result = map[string]interface{}{"error": string(stdout)} result = map[string]any{"error": string(stdout)}
} else { } else {
result = data result = data
} }
@@ -78,6 +88,7 @@ func (h *busService) handleJob(automationMsg *bus.JobMsg) {
if err := h.db.JobComplete(ctx, automationMsg.ID, result); err != nil { if err := h.db.JobComplete(ctx, automationMsg.ID, result); err != nil {
log.Println(err) log.Println(err)
return return
} }
} }
+14 -41
View File
@@ -8,7 +8,7 @@ import (
"github.com/SecurityBrewery/catalyst/generated/caql/parser" "github.com/SecurityBrewery/catalyst/generated/caql/parser"
) )
var TooComplexError = errors.New("unsupported features for index queries, use advanced search instead") var ErrTooComplex = errors.New("unsupported features for index queries, use advanced search instead")
type bleveBuilder struct { type bleveBuilder struct {
*parser.BaseCAQLParserListener *parser.BaseCAQLParserListener
@@ -35,8 +35,9 @@ func (s *bleveBuilder) pop() (n string) {
return return
} }
func (s *bleveBuilder) binaryPop() (interface{}, interface{}) { func (s *bleveBuilder) binaryPop() (any, any) {
right, left := s.pop(), s.pop() right, left := s.pop(), s.pop()
return left, right return left, right
} }
@@ -48,9 +49,7 @@ func (s *bleveBuilder) ExitExpression(ctx *parser.ExpressionContext) {
case ctx.Reference() != nil: case ctx.Reference() != nil:
// pass // pass
case ctx.Operator_unary() != nil: case ctx.Operator_unary() != nil:
s.err = TooComplexError s.err = ErrTooComplex
return
case ctx.T_PLUS() != nil: case ctx.T_PLUS() != nil:
fallthrough fallthrough
case ctx.T_MINUS() != nil: case ctx.T_MINUS() != nil:
@@ -60,13 +59,9 @@ func (s *bleveBuilder) ExitExpression(ctx *parser.ExpressionContext) {
case ctx.T_DIV() != nil: case ctx.T_DIV() != nil:
fallthrough fallthrough
case ctx.T_MOD() != nil: case ctx.T_MOD() != nil:
s.err = TooComplexError s.err = ErrTooComplex
return
case ctx.T_RANGE() != nil: case ctx.T_RANGE() != nil:
s.err = TooComplexError s.err = ErrTooComplex
return
case ctx.T_LT() != nil && ctx.GetEq_op() == nil: case ctx.T_LT() != nil && ctx.GetEq_op() == nil:
left, right := s.binaryPop() left, right := s.binaryPop()
s.push(fmt.Sprintf("%s:<%s", left, right)) s.push(fmt.Sprintf("%s:<%s", left, right))
@@ -79,64 +74,46 @@ func (s *bleveBuilder) ExitExpression(ctx *parser.ExpressionContext) {
case ctx.T_GE() != nil && ctx.GetEq_op() == nil: case ctx.T_GE() != nil && ctx.GetEq_op() == nil:
left, right := s.binaryPop() left, right := s.binaryPop()
s.push(fmt.Sprintf("%s:>=%s", left, right)) s.push(fmt.Sprintf("%s:>=%s", left, right))
case ctx.T_IN() != nil && ctx.GetEq_op() == nil: case ctx.T_IN() != nil && ctx.GetEq_op() == nil:
s.err = TooComplexError s.err = ErrTooComplex
return
case ctx.T_EQ() != nil && ctx.GetEq_op() == nil: case ctx.T_EQ() != nil && ctx.GetEq_op() == nil:
left, right := s.binaryPop() left, right := s.binaryPop()
s.push(fmt.Sprintf("%s:%s", left, right)) s.push(fmt.Sprintf("%s:%s", left, right))
case ctx.T_NE() != nil && ctx.GetEq_op() == nil: case ctx.T_NE() != nil && ctx.GetEq_op() == nil:
left, right := s.binaryPop() left, right := s.binaryPop()
s.push(fmt.Sprintf("-%s:%s", left, right)) s.push(fmt.Sprintf("-%s:%s", left, right))
case ctx.T_ALL() != nil && ctx.GetEq_op() != nil: case ctx.T_ALL() != nil && ctx.GetEq_op() != nil:
fallthrough fallthrough
case ctx.T_ANY() != nil && ctx.GetEq_op() != nil: case ctx.T_ANY() != nil && ctx.GetEq_op() != nil:
fallthrough fallthrough
case ctx.T_NONE() != nil && ctx.GetEq_op() != nil: case ctx.T_NONE() != nil && ctx.GetEq_op() != nil:
s.err = TooComplexError s.err = ErrTooComplex
return
case ctx.T_ALL() != nil && ctx.T_NOT() != nil && ctx.T_IN() != nil: case ctx.T_ALL() != nil && ctx.T_NOT() != nil && ctx.T_IN() != nil:
fallthrough fallthrough
case ctx.T_ANY() != nil && ctx.T_NOT() != nil && ctx.T_IN() != nil: case ctx.T_ANY() != nil && ctx.T_NOT() != nil && ctx.T_IN() != nil:
fallthrough fallthrough
case ctx.T_NONE() != nil && ctx.T_NOT() != nil && ctx.T_IN() != nil: case ctx.T_NONE() != nil && ctx.T_NOT() != nil && ctx.T_IN() != nil:
s.err = TooComplexError s.err = ErrTooComplex
return
case ctx.T_LIKE() != nil: case ctx.T_LIKE() != nil:
s.err = errors.New("index queries are like queries by default") s.err = errors.New("index queries are like queries by default")
return
case ctx.T_REGEX_MATCH() != nil: case ctx.T_REGEX_MATCH() != nil:
left, right := s.binaryPop() left, right := s.binaryPop()
if ctx.T_NOT() != nil { if ctx.T_NOT() != nil {
s.err = TooComplexError s.err = ErrTooComplex
return
} else { } else {
s.push(fmt.Sprintf("%s:/%s/", left, right)) s.push(fmt.Sprintf("%s:/%s/", left, right))
} }
case ctx.T_REGEX_NON_MATCH() != nil: case ctx.T_REGEX_NON_MATCH() != nil:
s.err = errors.New("index query cannot contain regex non matches, use advanced search instead") s.err = errors.New("index query cannot contain regex non matches, use advanced search instead")
return
case ctx.T_AND() != nil: case ctx.T_AND() != nil:
left, right := s.binaryPop() left, right := s.binaryPop()
s.push(fmt.Sprintf("%s %s", left, right)) s.push(fmt.Sprintf("%s %s", left, right))
case ctx.T_OR() != nil: case ctx.T_OR() != nil:
s.err = errors.New("index query cannot contain OR, use advanced search instead") s.err = errors.New("index query cannot contain OR, use advanced search instead")
return
case ctx.T_QUESTION() != nil && len(ctx.AllExpression()) == 3: case ctx.T_QUESTION() != nil && len(ctx.AllExpression()) == 3:
s.err = errors.New("index query cannot contain ternary operations, use advanced search instead") s.err = errors.New("index query cannot contain ternary operations, use advanced search instead")
return
case ctx.T_QUESTION() != nil && len(ctx.AllExpression()) == 2: case ctx.T_QUESTION() != nil && len(ctx.AllExpression()) == 2:
s.err = errors.New("index query cannot contain ternary operations, use advanced search instead") s.err = errors.New("index query cannot contain ternary operations, use advanced search instead")
return
default: default:
panic("unknown expression") panic("unknown expression")
} }
@@ -152,17 +129,13 @@ func (s *bleveBuilder) ExitReference(ctx *parser.ReferenceContext) {
case ctx.T_STRING() != nil: case ctx.T_STRING() != nil:
s.push(ctx.T_STRING().GetText()) s.push(ctx.T_STRING().GetText())
case ctx.Compound_value() != nil: case ctx.Compound_value() != nil:
s.err = TooComplexError s.err = ErrTooComplex
return
case ctx.Function_call() != nil: case ctx.Function_call() != nil:
s.err = TooComplexError s.err = ErrTooComplex
return
case ctx.T_OPEN() != nil: case ctx.T_OPEN() != nil:
s.err = TooComplexError s.err = ErrTooComplex
return
case ctx.T_ARRAY_OPEN() != nil: case ctx.T_ARRAY_OPEN() != nil:
s.err = TooComplexError s.err = ErrTooComplex
return
default: default:
panic(fmt.Sprintf("unexpected value: %s", ctx.GetText())) panic(fmt.Sprintf("unexpected value: %s", ctx.GetText()))
} }
+12 -2
View File
@@ -1,10 +1,14 @@
package caql package caql_test
import ( import (
"testing" "testing"
"github.com/SecurityBrewery/catalyst/caql"
) )
func TestBleveBuilder(t *testing.T) { func TestBleveBuilder(t *testing.T) {
t.Parallel()
tests := []struct { tests := []struct {
name string name string
saql string saql string
@@ -18,15 +22,20 @@ func TestBleveBuilder(t *testing.T) {
{name: "Search 4", saql: `title == 'malware' AND 'wannacry'`, wantBleve: `title:"malware" "wannacry"`}, {name: "Search 4", saql: `title == 'malware' AND 'wannacry'`, wantBleve: `title:"malware" "wannacry"`},
} }
for _, tt := range tests { for _, tt := range tests {
parser := &Parser{} tt := tt
parser := &caql.Parser{}
t.Run(tt.name, func(t *testing.T) { t.Run(tt.name, func(t *testing.T) {
t.Parallel()
expr, err := parser.Parse(tt.saql) expr, err := parser.Parse(tt.saql)
if (err != nil) != tt.wantParseErr { if (err != nil) != tt.wantParseErr {
t.Errorf("Parse() error = %v, wantErr %v", err, tt.wantParseErr) t.Errorf("Parse() error = %v, wantErr %v", err, tt.wantParseErr)
if expr != nil { if expr != nil {
t.Error(expr.String()) t.Error(expr.String())
} }
return return
} }
if err != nil { if err != nil {
@@ -37,6 +46,7 @@ func TestBleveBuilder(t *testing.T) {
if (err != nil) != tt.wantRebuildErr { if (err != nil) != tt.wantRebuildErr {
t.Error(expr.String()) t.Error(expr.String())
t.Errorf("String() error = %v, wantErr %v", err, tt.wantParseErr) t.Errorf("String() error = %v, wantErr %v", err, tt.wantParseErr)
return return
} }
if err != nil { if err != nil {
+6 -1
View File
@@ -5,6 +5,8 @@ import (
"strconv" "strconv"
"strings" "strings"
"golang.org/x/exp/slices"
"github.com/SecurityBrewery/catalyst/generated/caql/parser" "github.com/SecurityBrewery/catalyst/generated/caql/parser"
) )
@@ -40,6 +42,7 @@ func (s *aqlBuilder) pop() (n string) {
func (s *aqlBuilder) binaryPop() (string, string) { func (s *aqlBuilder) binaryPop() (string, string) {
right, left := s.pop(), s.pop() right, left := s.pop(), s.pop()
return left, right return left, right
} }
@@ -181,8 +184,10 @@ func (s *aqlBuilder) toBoolString(v string) string {
if err != nil { if err != nil {
panic("invalid search " + err.Error()) panic("invalid search " + err.Error())
} }
return fmt.Sprintf(`d._key IN ["%s"]`, strings.Join(ids, `","`)) return fmt.Sprintf(`d._key IN ["%s"]`, strings.Join(ids, `","`))
} }
return v return v
} }
@@ -246,7 +251,7 @@ func (s *aqlBuilder) ExitFunction_call(ctx *parser.Function_callContext) {
} }
parameter := strings.Join(array, ", ") parameter := strings.Join(array, ", ")
if !stringSliceContains(functionNames, strings.ToUpper(ctx.T_STRING().GetText())) { if !slices.Contains(functionNames, strings.ToUpper(ctx.T_STRING().GetText())) {
panic("unknown function") panic("unknown function")
} }
+64 -72
View File
@@ -16,7 +16,6 @@ import (
func (s *aqlInterpreter) function(ctx *parser.Function_callContext) { func (s *aqlInterpreter) function(ctx *parser.Function_callContext) {
switch strings.ToUpper(ctx.T_STRING().GetText()) { switch strings.ToUpper(ctx.T_STRING().GetText()) {
default: default:
s.appendErrors(errors.New("unknown function")) s.appendErrors(errors.New("unknown function"))
@@ -26,8 +25,8 @@ func (s *aqlInterpreter) function(ctx *parser.Function_callContext) {
if len(ctx.AllExpression()) == 3 { if len(ctx.AllExpression()) == 3 {
u = s.pop().(bool) u = s.pop().(bool)
} }
seen := map[interface{}]bool{} seen := map[any]bool{}
values, anyArray := s.pop().([]interface{}), s.pop().([]interface{}) values, anyArray := s.pop().([]any), s.pop().([]any)
if u { if u {
for _, e := range anyArray { for _, e := range anyArray {
@@ -45,18 +44,18 @@ func (s *aqlInterpreter) function(ctx *parser.Function_callContext) {
s.push(anyArray) s.push(anyArray)
case "COUNT_DISTINCT", "COUNT_UNIQUE": case "COUNT_DISTINCT", "COUNT_UNIQUE":
count := 0 count := 0
seen := map[interface{}]bool{} seen := map[any]bool{}
array := s.pop().([]interface{}) array := s.pop().([]any)
for _, e := range array { for _, e := range array {
_, ok := seen[e] _, ok := seen[e]
if !ok { if !ok {
seen[e] = true seen[e] = true
count += 1 count++
} }
} }
s.push(float64(count)) s.push(float64(count))
case "FIRST": case "FIRST":
array := s.pop().([]interface{}) array := s.pop().([]any)
if len(array) == 0 { if len(array) == 0 {
s.push(nil) s.push(nil)
} else { } else {
@@ -65,16 +64,16 @@ func (s *aqlInterpreter) function(ctx *parser.Function_callContext) {
// case "FLATTEN": // case "FLATTEN":
// case "INTERLEAVE": // case "INTERLEAVE":
case "INTERSECTION": case "INTERSECTION":
iset := New(s.pop().([]interface{})...) iset := NewSet(s.pop().([]any)...)
for i := 1; i < len(ctx.AllExpression()); i++ { for i := 1; i < len(ctx.AllExpression()); i++ {
iset = iset.Intersection(New(s.pop().([]interface{})...)) iset = iset.Intersection(NewSet(s.pop().([]any)...))
} }
s.push(iset.Values()) s.push(iset.Values())
// case "JACCARD": // case "JACCARD":
case "LAST": case "LAST":
array := s.pop().([]interface{}) array := s.pop().([]any)
if len(array) == 0 { if len(array) == 0 {
s.push(nil) s.push(nil)
} else { } else {
@@ -94,9 +93,9 @@ func (s *aqlInterpreter) function(ctx *parser.Function_callContext) {
s.push(float64(len(fmt.Sprint(v)))) s.push(float64(len(fmt.Sprint(v))))
case string: case string:
s.push(float64(utf8.RuneCountInString(v))) s.push(float64(utf8.RuneCountInString(v)))
case []interface{}: case []any:
s.push(float64(len(v))) s.push(float64(len(v)))
case map[string]interface{}: case map[string]any:
s.push(float64(len(v))) s.push(float64(len(v)))
default: default:
panic("unknown type") panic("unknown type")
@@ -104,7 +103,7 @@ func (s *aqlInterpreter) function(ctx *parser.Function_callContext) {
case "MINUS": case "MINUS":
var sets []*Set var sets []*Set
for i := 0; i < len(ctx.AllExpression()); i++ { for i := 0; i < len(ctx.AllExpression()); i++ {
sets = append(sets, New(s.pop().([]interface{})...)) sets = append(sets, NewSet(s.pop().([]any)...))
} }
iset := sets[len(sets)-1] iset := sets[len(sets)-1]
@@ -116,7 +115,7 @@ func (s *aqlInterpreter) function(ctx *parser.Function_callContext) {
s.push(iset.Values()) s.push(iset.Values())
case "NTH": case "NTH":
pos := s.pop().(float64) pos := s.pop().(float64)
array := s.pop().([]interface{}) array := s.pop().([]any)
if int(pos) >= len(array) || pos < 0 { if int(pos) >= len(array) || pos < 0 {
s.push(nil) s.push(nil)
} else { } else {
@@ -124,16 +123,16 @@ func (s *aqlInterpreter) function(ctx *parser.Function_callContext) {
} }
// case "OUTERSECTION": // case "OUTERSECTION":
// array := s.pop().([]interface{}) // array := s.pop().([]interface{})
// union := New(array...) // union := NewSet(array...)
// intersection := New(s.pop().([]interface{})...) // intersection := NewSet(s.pop().([]interface{})...)
// for i := 1; i < len(ctx.AllExpression()); i++ { // for i := 1; i < len(ctx.AllExpression()); i++ {
// array = s.pop().([]interface{}) // array = s.pop().([]interface{})
// union = union.Union(New(array...)) // union = union.Union(NewSet(array...))
// intersection = intersection.Intersection(New(array...)) // intersection = intersection.Intersection(NewSet(array...))
// } // }
// s.push(union.Minus(intersection).Values()) // s.push(union.Minus(intersection).Values())
case "POP": case "POP":
array := s.pop().([]interface{}) array := s.pop().([]any)
s.push(array[:len(array)-1]) s.push(array[:len(array)-1])
case "POSITION", "CONTAINS_ARRAY": case "POSITION", "CONTAINS_ARRAY":
returnIndex := false returnIndex := false
@@ -141,7 +140,7 @@ func (s *aqlInterpreter) function(ctx *parser.Function_callContext) {
returnIndex = s.pop().(bool) returnIndex = s.pop().(bool)
} }
search := s.pop() search := s.pop()
array := s.pop().([]interface{}) array := s.pop().([]any)
for idx, e := range array { for idx, e := range array {
if e == search { if e == search {
@@ -164,7 +163,7 @@ func (s *aqlInterpreter) function(ctx *parser.Function_callContext) {
u = s.pop().(bool) u = s.pop().(bool)
} }
element := s.pop() element := s.pop()
array := s.pop().([]interface{}) array := s.pop().([]any)
if u && contains(array, element) { if u && contains(array, element) {
s.push(array) s.push(array)
@@ -173,13 +172,13 @@ func (s *aqlInterpreter) function(ctx *parser.Function_callContext) {
} }
case "REMOVE_NTH": case "REMOVE_NTH":
position := s.pop().(float64) position := s.pop().(float64)
anyArray := s.pop().([]interface{}) anyArray := s.pop().([]any)
if position < 0 { if position < 0 {
position = float64(len(anyArray) + int(position)) position = float64(len(anyArray) + int(position))
} }
result := []interface{}{} result := []any{}
for idx, e := range anyArray { for idx, e := range anyArray {
if idx != int(position) { if idx != int(position) {
result = append(result, e) result = append(result, e)
@@ -193,7 +192,7 @@ func (s *aqlInterpreter) function(ctx *parser.Function_callContext) {
} }
replaceValue := s.pop().(string) replaceValue := s.pop().(string)
position := s.pop().(float64) position := s.pop().(float64)
anyArray := s.pop().([]interface{}) anyArray := s.pop().([]any)
if position < 0 { if position < 0 {
position = float64(len(anyArray) + int(position)) position = float64(len(anyArray) + int(position))
@@ -224,8 +223,8 @@ func (s *aqlInterpreter) function(ctx *parser.Function_callContext) {
limit = s.pop().(float64) limit = s.pop().(float64)
} }
value := s.pop() value := s.pop()
array := s.pop().([]interface{}) array := s.pop().([]any)
result := []interface{}{} result := []any{}
for idx, e := range array { for idx, e := range array {
if e != value || float64(idx) > limit { if e != value || float64(idx) > limit {
result = append(result, e) result = append(result, e)
@@ -233,9 +232,9 @@ func (s *aqlInterpreter) function(ctx *parser.Function_callContext) {
} }
s.push(result) s.push(result)
case "REMOVE_VALUES": case "REMOVE_VALUES":
values := s.pop().([]interface{}) values := s.pop().([]any)
array := s.pop().([]interface{}) array := s.pop().([]any)
result := []interface{}{} result := []any{}
for _, e := range array { for _, e := range array {
if !contains(values, e) { if !contains(values, e) {
result = append(result, e) result = append(result, e)
@@ -243,14 +242,14 @@ func (s *aqlInterpreter) function(ctx *parser.Function_callContext) {
} }
s.push(result) s.push(result)
case "REVERSE": case "REVERSE":
array := s.pop().([]interface{}) array := s.pop().([]any)
var reverse []interface{} var reverse []any
for _, e := range array { for _, e := range array {
reverse = append([]interface{}{e}, reverse...) reverse = append([]any{e}, reverse...)
} }
s.push(reverse) s.push(reverse)
case "SHIFT": case "SHIFT":
s.push(s.pop().([]interface{})[1:]) s.push(s.pop().([]any)[1:])
case "SLICE": case "SLICE":
length := float64(-1) length := float64(-1)
full := true full := true
@@ -259,7 +258,7 @@ func (s *aqlInterpreter) function(ctx *parser.Function_callContext) {
full = false full = false
} }
start := int64(s.pop().(float64)) start := int64(s.pop().(float64))
array := s.pop().([]interface{}) array := s.pop().([]any)
if start < 0 { if start < 0 {
start = int64(len(array)) + start start = int64(len(array)) + start
@@ -276,43 +275,43 @@ func (s *aqlInterpreter) function(ctx *parser.Function_callContext) {
} }
s.push(array[start:end]) s.push(array[start:end])
case "SORTED": case "SORTED":
array := s.pop().([]interface{}) array := s.pop().([]any)
sort.Slice(array, func(i, j int) bool { return lt(array[i], array[j]) }) sort.Slice(array, func(i, j int) bool { return lt(array[i], array[j]) })
s.push(array) s.push(array)
case "SORTED_UNIQUE": case "SORTED_UNIQUE":
array := s.pop().([]interface{}) array := s.pop().([]any)
sort.Slice(array, func(i, j int) bool { return lt(array[i], array[j]) }) sort.Slice(array, func(i, j int) bool { return lt(array[i], array[j]) })
s.push(unique(array)) s.push(unique(array))
case "UNION": case "UNION":
array := s.pop().([]interface{}) array := s.pop().([]any)
for i := 1; i < len(ctx.AllExpression()); i++ { for i := 1; i < len(ctx.AllExpression()); i++ {
array = append(array, s.pop().([]interface{})...) array = append(array, s.pop().([]any)...)
} }
sort.Slice(array, func(i, j int) bool { return lt(array[i], array[j]) }) sort.Slice(array, func(i, j int) bool { return lt(array[i], array[j]) })
s.push(array) s.push(array)
case "UNION_DISTINCT": case "UNION_DISTINCT":
iset := New(s.pop().([]interface{})...) iset := NewSet(s.pop().([]any)...)
for i := 1; i < len(ctx.AllExpression()); i++ { for i := 1; i < len(ctx.AllExpression()); i++ {
iset = iset.Union(New(s.pop().([]interface{})...)) iset = iset.Union(NewSet(s.pop().([]any)...))
} }
s.push(unique(iset.Values())) s.push(unique(iset.Values()))
case "UNIQUE": case "UNIQUE":
s.push(unique(s.pop().([]interface{}))) s.push(unique(s.pop().([]any)))
case "UNSHIFT": case "UNSHIFT":
u := false u := false
if len(ctx.AllExpression()) == 3 { if len(ctx.AllExpression()) == 3 {
u = s.pop().(bool) u = s.pop().(bool)
} }
element := s.pop() element := s.pop()
array := s.pop().([]interface{}) array := s.pop().([]any)
if u && contains(array, element) { if u && contains(array, element) {
s.push(array) s.push(array)
} else { } else {
s.push(append([]interface{}{element}, array...)) s.push(append([]any{element}, array...))
} }
// Bit https://www.arangodb.com/docs/stable/aql/functions-bit.html // Bit https://www.arangodb.com/docs/stable/aql/functions-bit.html
@@ -367,8 +366,8 @@ func (s *aqlInterpreter) function(ctx *parser.Function_callContext) {
if len(ctx.AllExpression()) >= 2 { if len(ctx.AllExpression()) >= 2 {
removeInternal = s.pop().(bool) removeInternal = s.pop().(bool)
} }
var keys []interface{} var keys []any
for k := range s.pop().(map[string]interface{}) { for k := range s.pop().(map[string]any) {
isInternalKey := strings.HasPrefix(k, "_") isInternalKey := strings.HasPrefix(k, "_")
if !removeInternal || !isInternalKey { if !removeInternal || !isInternalKey {
keys = append(keys, k) keys = append(keys, k)
@@ -379,20 +378,20 @@ func (s *aqlInterpreter) function(ctx *parser.Function_callContext) {
// case "COUNT": // case "COUNT":
case "HAS": case "HAS":
right, left := s.pop(), s.pop() right, left := s.pop(), s.pop()
_, ok := left.(map[string]interface{})[right.(string)] _, ok := left.(map[string]any)[right.(string)]
s.push(ok) s.push(ok)
// case "KEEP": // case "KEEP":
// case "LENGTH": // case "LENGTH":
// case "MATCHES": // case "MATCHES":
case "MERGE": case "MERGE":
var docs []map[string]interface{} var docs []map[string]any
if len(ctx.AllExpression()) == 1 { if len(ctx.AllExpression()) == 1 {
for _, doc := range s.pop().([]interface{}) { for _, doc := range s.pop().([]any) {
docs = append([]map[string]interface{}{doc.(map[string]interface{})}, docs...) docs = append([]map[string]any{doc.(map[string]any)}, docs...)
} }
} else { } else {
for i := 0; i < len(ctx.AllExpression()); i++ { for i := 0; i < len(ctx.AllExpression()); i++ {
docs = append(docs, s.pop().(map[string]interface{})) docs = append(docs, s.pop().(map[string]any))
} }
} }
@@ -404,9 +403,9 @@ func (s *aqlInterpreter) function(ctx *parser.Function_callContext) {
} }
s.push(doc) s.push(doc)
case "MERGE_RECURSIVE": case "MERGE_RECURSIVE":
var doc map[string]interface{} var doc map[string]any
for i := 0; i < len(ctx.AllExpression()); i++ { for i := 0; i < len(ctx.AllExpression()); i++ {
err := mergo.Merge(&doc, s.pop().(map[string]interface{})) err := mergo.Merge(&doc, s.pop().(map[string]any))
if err != nil { if err != nil {
panic(err) panic(err)
} }
@@ -421,8 +420,8 @@ func (s *aqlInterpreter) function(ctx *parser.Function_callContext) {
if len(ctx.AllExpression()) == 2 { if len(ctx.AllExpression()) == 2 {
removeInternal = s.pop().(bool) removeInternal = s.pop().(bool)
} }
var values []interface{} var values []any
for k, v := range s.pop().(map[string]interface{}) { for k, v := range s.pop().(map[string]any) {
isInternalKey := strings.HasPrefix(k, "_") isInternalKey := strings.HasPrefix(k, "_")
if !removeInternal || !isInternalKey { if !removeInternal || !isInternalKey {
values = append(values, v) values = append(values, v)
@@ -458,10 +457,10 @@ func (s *aqlInterpreter) function(ctx *parser.Function_callContext) {
case "AVERAGE", "AVG": case "AVERAGE", "AVG":
count := 0 count := 0
sum := float64(0) sum := float64(0)
array := s.pop().([]interface{}) array := s.pop().([]any)
for _, element := range array { for _, element := range array {
if element != nil { if element != nil {
count += 1 count++
sum += toNumber(element) sum += toNumber(element)
} }
} }
@@ -506,7 +505,7 @@ func (s *aqlInterpreter) function(ctx *parser.Function_callContext) {
case "MAX": case "MAX":
var set bool var set bool
var max float64 var max float64
array := s.pop().([]interface{}) array := s.pop().([]any)
for _, element := range array { for _, element := range array {
if element != nil { if element != nil {
if !set || toNumber(element) > max { if !set || toNumber(element) > max {
@@ -521,7 +520,7 @@ func (s *aqlInterpreter) function(ctx *parser.Function_callContext) {
s.push(nil) s.push(nil)
} }
case "MEDIAN": case "MEDIAN":
array := s.pop().([]interface{}) array := s.pop().([]any)
var numbers []float64 var numbers []float64
for _, element := range array { for _, element := range array {
if f, ok := element.(float64); ok { if f, ok := element.(float64); ok {
@@ -544,7 +543,7 @@ func (s *aqlInterpreter) function(ctx *parser.Function_callContext) {
case "MIN": case "MIN":
var set bool var set bool
var min float64 var min float64
array := s.pop().([]interface{}) array := s.pop().([]any)
for _, element := range array { for _, element := range array {
if element != nil { if element != nil {
if !set || toNumber(element) < min { if !set || toNumber(element) < min {
@@ -566,7 +565,7 @@ func (s *aqlInterpreter) function(ctx *parser.Function_callContext) {
s.push(math.Pow(left.(float64), right.(float64))) s.push(math.Pow(left.(float64), right.(float64)))
case "PRODUCT": case "PRODUCT":
product := float64(1) product := float64(1)
array := s.pop().([]interface{}) array := s.pop().([]any)
for _, element := range array { for _, element := range array {
if element != nil { if element != nil {
product *= toNumber(element) product *= toNumber(element)
@@ -578,7 +577,7 @@ func (s *aqlInterpreter) function(ctx *parser.Function_callContext) {
case "RAND": case "RAND":
s.push(rand.Float64()) s.push(rand.Float64())
case "RANGE": case "RANGE":
var array []interface{} var array []any
var start, end, step float64 var start, end, step float64
if len(ctx.AllExpression()) == 2 { if len(ctx.AllExpression()) == 2 {
right, left := s.pop(), s.pop() right, left := s.pop(), s.pop()
@@ -612,7 +611,7 @@ func (s *aqlInterpreter) function(ctx *parser.Function_callContext) {
// case "STDDEV": // case "STDDEV":
case "SUM": case "SUM":
sum := float64(0) sum := float64(0)
array := s.pop().([]interface{}) array := s.pop().([]any)
for _, element := range array { for _, element := range array {
sum += toNumber(element) sum += toNumber(element)
} }
@@ -691,7 +690,6 @@ func (s *aqlInterpreter) function(ctx *parser.Function_callContext) {
// case "IS_IPV4": // case "IS_IPV4":
// case "IS_KEY": // case "IS_KEY":
// case "TYPENAME": // case "TYPENAME":
} }
} }
@@ -705,6 +703,7 @@ func unique(array []interface{}) []interface{} {
filtered = append(filtered, e) filtered = append(filtered, e)
} }
} }
return filtered return filtered
} }
@@ -714,15 +713,7 @@ func contains(values []interface{}, e interface{}) bool {
return true return true
} }
} }
return false
}
func stringSliceContains(values []string, e string) bool {
for _, v := range values {
if e == v {
return true
}
}
return false return false
} }
@@ -747,4 +738,5 @@ var functionNames = []string{
"REGEX_REPLACE", "REVERSE", "RIGHT", "RTRIM", "SHA1", "SHA512", "SOUNDEX", "SPLIT", "STARTS_WITH", "SUBSTITUTE", "REGEX_REPLACE", "REVERSE", "RIGHT", "RTRIM", "SHA1", "SHA512", "SOUNDEX", "SPLIT", "STARTS_WITH", "SUBSTITUTE",
"SUBSTRING", "TOKENS", "TO_BASE64", "TO_HEX", "TRIM", "UPPER", "UUID", "TO_BOOL", "TO_NUMBER", "TO_STRING", "SUBSTRING", "TOKENS", "TO_BASE64", "TO_HEX", "TRIM", "UPPER", "UUID", "TO_BOOL", "TO_NUMBER", "TO_STRING",
"TO_ARRAY", "TO_LIST", "IS_NULL", "IS_BOOL", "IS_NUMBER", "IS_STRING", "IS_ARRAY", "IS_LIST", "IS_OBJECT", "TO_ARRAY", "TO_LIST", "IS_NULL", "IS_BOOL", "IS_NUMBER", "IS_STRING", "IS_ARRAY", "IS_LIST", "IS_OBJECT",
"IS_DOCUMENT", "IS_DATESTRING", "IS_IPV4", "IS_KEY", "TYPENAME"} "IS_DOCUMENT", "IS_DATESTRING", "IS_IPV4", "IS_KEY", "TYPENAME",
}
+27 -15
View File
@@ -1,18 +1,22 @@
package caql package caql_test
import ( import (
"encoding/json" "encoding/json"
"math" "math"
"reflect" "reflect"
"testing" "testing"
"github.com/SecurityBrewery/catalyst/caql"
) )
func TestFunctions(t *testing.T) { func TestFunctions(t *testing.T) {
t.Parallel()
tests := []struct { tests := []struct {
name string name string
saql string saql string
wantRebuild string wantRebuild string
wantValue interface{} wantValue any
wantParseErr bool wantParseErr bool
wantRebuildErr bool wantRebuildErr bool
wantEvalErr bool wantEvalErr bool
@@ -266,13 +270,13 @@ func TestFunctions(t *testing.T) {
{name: "RADIANS", saql: `RADIANS(0)`, wantRebuild: `RADIANS(0)`, wantValue: 0}, {name: "RADIANS", saql: `RADIANS(0)`, wantRebuild: `RADIANS(0)`, wantValue: 0},
// {name: "RAND", saql: `RAND()`, wantRebuild: `RAND()`, wantValue: 0.3503170117504508}, // {name: "RAND", saql: `RAND()`, wantRebuild: `RAND()`, wantValue: 0.3503170117504508},
// {name: "RAND", saql: `RAND()`, wantRebuild: `RAND()`, wantValue: 0.6138226173882478}, // {name: "RAND", saql: `RAND()`, wantRebuild: `RAND()`, wantValue: 0.6138226173882478},
{name: "RANGE", saql: `RANGE(1, 4)`, wantRebuild: `RANGE(1, 4)`, wantValue: []interface{}{float64(1), float64(2), float64(3), float64(4)}}, {name: "RANGE", saql: `RANGE(1, 4)`, wantRebuild: `RANGE(1, 4)`, wantValue: []any{float64(1), float64(2), float64(3), float64(4)}},
{name: "RANGE", saql: `RANGE(1, 4, 2)`, wantRebuild: `RANGE(1, 4, 2)`, wantValue: []interface{}{float64(1), float64(3)}}, {name: "RANGE", saql: `RANGE(1, 4, 2)`, wantRebuild: `RANGE(1, 4, 2)`, wantValue: []any{float64(1), float64(3)}},
{name: "RANGE", saql: `RANGE(1, 4, 3)`, wantRebuild: `RANGE(1, 4, 3)`, wantValue: []interface{}{float64(1), float64(4)}}, {name: "RANGE", saql: `RANGE(1, 4, 3)`, wantRebuild: `RANGE(1, 4, 3)`, wantValue: []any{float64(1), float64(4)}},
{name: "RANGE", saql: `RANGE(1.5, 2.5)`, wantRebuild: `RANGE(1.5, 2.5)`, wantValue: []interface{}{float64(1), float64(2)}}, {name: "RANGE", saql: `RANGE(1.5, 2.5)`, wantRebuild: `RANGE(1.5, 2.5)`, wantValue: []any{float64(1), float64(2)}},
{name: "RANGE", saql: `RANGE(1.5, 2.5, 1)`, wantRebuild: `RANGE(1.5, 2.5, 1)`, wantValue: []interface{}{1.5, 2.5}}, {name: "RANGE", saql: `RANGE(1.5, 2.5, 1)`, wantRebuild: `RANGE(1.5, 2.5, 1)`, wantValue: []any{1.5, 2.5}},
{name: "RANGE", saql: `RANGE(1.5, 2.5, 0.5)`, wantRebuild: `RANGE(1.5, 2.5, 0.5)`, wantValue: []interface{}{1.5, 2.0, 2.5}}, {name: "RANGE", saql: `RANGE(1.5, 2.5, 0.5)`, wantRebuild: `RANGE(1.5, 2.5, 0.5)`, wantValue: []any{1.5, 2.0, 2.5}},
{name: "RANGE", saql: `RANGE(-0.75, 1.1, 0.5)`, wantRebuild: `RANGE(-0.75, 1.1, 0.5)`, wantValue: []interface{}{-0.75, -0.25, 0.25, 0.75}}, {name: "RANGE", saql: `RANGE(-0.75, 1.1, 0.5)`, wantRebuild: `RANGE(-0.75, 1.1, 0.5)`, wantValue: []any{-0.75, -0.25, 0.25, 0.75}},
{name: "ROUND", saql: `ROUND(2.49)`, wantRebuild: `ROUND(2.49)`, wantValue: 2}, {name: "ROUND", saql: `ROUND(2.49)`, wantRebuild: `ROUND(2.49)`, wantValue: 2},
{name: "ROUND", saql: `ROUND(2.50)`, wantRebuild: `ROUND(2.50)`, wantValue: 3}, {name: "ROUND", saql: `ROUND(2.50)`, wantRebuild: `ROUND(2.50)`, wantValue: 3},
{name: "ROUND", saql: `ROUND(-2.50)`, wantRebuild: `ROUND(-2.50)`, wantValue: -2}, {name: "ROUND", saql: `ROUND(-2.50)`, wantRebuild: `ROUND(-2.50)`, wantValue: -2},
@@ -299,15 +303,20 @@ func TestFunctions(t *testing.T) {
{name: "Function Error 3", saql: `ABS("abs")`, wantRebuild: `ABS("abs")`, wantEvalErr: true}, {name: "Function Error 3", saql: `ABS("abs")`, wantRebuild: `ABS("abs")`, wantEvalErr: true},
} }
for _, tt := range tests { for _, tt := range tests {
parser := &Parser{} tt := tt
parser := &caql.Parser{}
t.Run(tt.name, func(t *testing.T) { t.Run(tt.name, func(t *testing.T) {
t.Parallel()
expr, err := parser.Parse(tt.saql) expr, err := parser.Parse(tt.saql)
if (err != nil) != tt.wantParseErr { if (err != nil) != tt.wantParseErr {
t.Errorf("Parse() error = %v, wantErr %v", err, tt.wantParseErr) t.Errorf("Parse() error = %v, wantErr %v", err, tt.wantParseErr)
if expr != nil { if expr != nil {
t.Error(expr.String()) t.Error(expr.String())
} }
return return
} }
if err != nil { if err != nil {
@@ -318,6 +327,7 @@ func TestFunctions(t *testing.T) {
if (err != nil) != tt.wantRebuildErr { if (err != nil) != tt.wantRebuildErr {
t.Error(expr.String()) t.Error(expr.String())
t.Errorf("String() error = %v, wantErr %v", err, tt.wantParseErr) t.Errorf("String() error = %v, wantErr %v", err, tt.wantParseErr)
return return
} }
if err != nil { if err != nil {
@@ -327,18 +337,19 @@ func TestFunctions(t *testing.T) {
t.Errorf("String() got = %v, want %v", got, tt.wantRebuild) t.Errorf("String() got = %v, want %v", got, tt.wantRebuild)
} }
var myJson map[string]interface{} var myJSON map[string]any
if tt.values != "" { if tt.values != "" {
err = json.Unmarshal([]byte(tt.values), &myJson) err = json.Unmarshal([]byte(tt.values), &myJSON)
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
} }
} }
value, err := expr.Eval(myJson) value, err := expr.Eval(myJSON)
if (err != nil) != tt.wantEvalErr { if (err != nil) != tt.wantEvalErr {
t.Error(expr.String()) t.Error(expr.String())
t.Errorf("Parse() error = %v, wantErr %v", err, tt.wantParseErr) t.Errorf("Parse() error = %v, wantErr %v", err, tt.wantParseErr)
return return
} }
if err != nil { if err != nil {
@@ -367,14 +378,15 @@ func TestFunctions(t *testing.T) {
} }
} }
func jsonParse(s string) interface{} { func jsonParse(s string) any {
if s == "" { if s == "" {
return nil return nil
} }
var j interface{} var j any
err := json.Unmarshal([]byte(s), &j) err := json.Unmarshal([]byte(s), &j)
if err != nil { if err != nil {
panic(s + err.Error()) panic(s + err.Error())
} }
return j return j
} }
+32 -36
View File
@@ -10,22 +10,23 @@ import (
type aqlInterpreter struct { type aqlInterpreter struct {
*parser.BaseCAQLParserListener *parser.BaseCAQLParserListener
values map[string]interface{} values map[string]any
stack []interface{} stack []any
errs []error errs []error
} }
// push is a helper function for pushing new node to the listener Stack. // push is a helper function for pushing new node to the listener Stack.
func (s *aqlInterpreter) push(i interface{}) { func (s *aqlInterpreter) push(i any) {
s.stack = append(s.stack, i) s.stack = append(s.stack, i)
} }
// pop is a helper function for poping a node from the listener Stack. // pop is a helper function for poping a node from the listener Stack.
func (s *aqlInterpreter) pop() (n interface{}) { func (s *aqlInterpreter) pop() (n any) {
// Check that we have nodes in the stack. // Check that we have nodes in the stack.
size := len(s.stack) size := len(s.stack)
if size < 1 { if size < 1 {
s.appendErrors(ErrStack) s.appendErrors(ErrStack)
return return
} }
@@ -35,8 +36,9 @@ func (s *aqlInterpreter) pop() (n interface{}) {
return return
} }
func (s *aqlInterpreter) binaryPop() (interface{}, interface{}) { func (s *aqlInterpreter) binaryPop() (any, any) {
right, left := s.pop(), s.pop() right, left := s.pop(), s.pop()
return left, right return left, right
} }
@@ -54,17 +56,14 @@ func (s *aqlInterpreter) ExitExpression(ctx *parser.ExpressionContext) {
s.push(plus(s.binaryPop())) s.push(plus(s.binaryPop()))
case ctx.T_MINUS() != nil: case ctx.T_MINUS() != nil:
s.push(minus(s.binaryPop())) s.push(minus(s.binaryPop()))
case ctx.T_TIMES() != nil: case ctx.T_TIMES() != nil:
s.push(times(s.binaryPop())) s.push(times(s.binaryPop()))
case ctx.T_DIV() != nil: case ctx.T_DIV() != nil:
s.push(div(s.binaryPop())) s.push(div(s.binaryPop()))
case ctx.T_MOD() != nil: case ctx.T_MOD() != nil:
s.push(mod(s.binaryPop())) s.push(mod(s.binaryPop()))
case ctx.T_RANGE() != nil: case ctx.T_RANGE() != nil:
s.push(aqlrange(s.binaryPop())) s.push(aqlrange(s.binaryPop()))
case ctx.T_LT() != nil && ctx.GetEq_op() == nil: case ctx.T_LT() != nil && ctx.GetEq_op() == nil:
s.push(lt(s.binaryPop())) s.push(lt(s.binaryPop()))
case ctx.T_GT() != nil && ctx.GetEq_op() == nil: case ctx.T_GT() != nil && ctx.GetEq_op() == nil:
@@ -73,35 +72,30 @@ func (s *aqlInterpreter) ExitExpression(ctx *parser.ExpressionContext) {
s.push(le(s.binaryPop())) s.push(le(s.binaryPop()))
case ctx.T_GE() != nil && ctx.GetEq_op() == nil: case ctx.T_GE() != nil && ctx.GetEq_op() == nil:
s.push(ge(s.binaryPop())) s.push(ge(s.binaryPop()))
case ctx.T_IN() != nil && ctx.GetEq_op() == nil: case ctx.T_IN() != nil && ctx.GetEq_op() == nil:
s.push(maybeNot(ctx, in(s.binaryPop()))) s.push(maybeNot(ctx, in(s.binaryPop())))
case ctx.T_EQ() != nil && ctx.GetEq_op() == nil: case ctx.T_EQ() != nil && ctx.GetEq_op() == nil:
s.push(eq(s.binaryPop())) s.push(eq(s.binaryPop()))
case ctx.T_NE() != nil && ctx.GetEq_op() == nil: case ctx.T_NE() != nil && ctx.GetEq_op() == nil:
s.push(ne(s.binaryPop())) s.push(ne(s.binaryPop()))
case ctx.T_ALL() != nil && ctx.GetEq_op() != nil: case ctx.T_ALL() != nil && ctx.GetEq_op() != nil:
right, left := s.pop(), s.pop() right, left := s.pop(), s.pop()
s.push(all(left.([]interface{}), getOp(ctx.GetEq_op().GetTokenType()), right)) s.push(all(left.([]any), getOp(ctx.GetEq_op().GetTokenType()), right))
case ctx.T_ANY() != nil && ctx.GetEq_op() != nil: case ctx.T_ANY() != nil && ctx.GetEq_op() != nil:
right, left := s.pop(), s.pop() right, left := s.pop(), s.pop()
s.push(any(left.([]interface{}), getOp(ctx.GetEq_op().GetTokenType()), right)) s.push(anyElement(left.([]any), getOp(ctx.GetEq_op().GetTokenType()), right))
case ctx.T_NONE() != nil && ctx.GetEq_op() != nil: case ctx.T_NONE() != nil && ctx.GetEq_op() != nil:
right, left := s.pop(), s.pop() right, left := s.pop(), s.pop()
s.push(none(left.([]interface{}), getOp(ctx.GetEq_op().GetTokenType()), right)) s.push(none(left.([]any), getOp(ctx.GetEq_op().GetTokenType()), right))
case ctx.T_ALL() != nil && ctx.T_NOT() != nil && ctx.T_IN() != nil: case ctx.T_ALL() != nil && ctx.T_NOT() != nil && ctx.T_IN() != nil:
right, left := s.pop(), s.pop() right, left := s.pop(), s.pop()
s.push(all(left.([]interface{}), in, right)) s.push(all(left.([]any), in, right))
case ctx.T_ANY() != nil && ctx.T_NOT() != nil && ctx.T_IN() != nil: case ctx.T_ANY() != nil && ctx.T_NOT() != nil && ctx.T_IN() != nil:
right, left := s.pop(), s.pop() right, left := s.pop(), s.pop()
s.push(any(left.([]interface{}), in, right)) s.push(anyElement(left.([]any), in, right))
case ctx.T_NONE() != nil && ctx.T_NOT() != nil && ctx.T_IN() != nil: case ctx.T_NONE() != nil && ctx.T_NOT() != nil && ctx.T_IN() != nil:
right, left := s.pop(), s.pop() right, left := s.pop(), s.pop()
s.push(none(left.([]interface{}), in, right)) s.push(none(left.([]any), in, right))
case ctx.T_LIKE() != nil: case ctx.T_LIKE() != nil:
m, err := like(s.binaryPop()) m, err := like(s.binaryPop())
s.appendErrors(err) s.appendErrors(err)
@@ -114,21 +108,18 @@ func (s *aqlInterpreter) ExitExpression(ctx *parser.ExpressionContext) {
m, err := regexNonMatch(s.binaryPop()) m, err := regexNonMatch(s.binaryPop())
s.appendErrors(err) s.appendErrors(err)
s.push(maybeNot(ctx, m)) s.push(maybeNot(ctx, m))
case ctx.T_AND() != nil: case ctx.T_AND() != nil:
s.push(and(s.binaryPop())) s.push(and(s.binaryPop()))
case ctx.T_OR() != nil: case ctx.T_OR() != nil:
s.push(or(s.binaryPop())) s.push(or(s.binaryPop()))
case ctx.T_QUESTION() != nil && len(ctx.AllExpression()) == 3: case ctx.T_QUESTION() != nil && len(ctx.AllExpression()) == 3:
right, middle, left := s.pop(), s.pop(), s.pop() right, middle, left := s.pop(), s.pop(), s.pop()
s.push(ternary(left, middle, right)) s.push(ternary(left, middle, right))
case ctx.T_QUESTION() != nil && len(ctx.AllExpression()) == 2: case ctx.T_QUESTION() != nil && len(ctx.AllExpression()) == 2:
right, left := s.pop(), s.pop() right, left := s.pop(), s.pop()
s.push(ternary(left, nil, right)) s.push(ternary(left, nil, right))
default: default:
panic("unkown expression") panic("unknown expression")
} }
} }
@@ -159,7 +150,7 @@ func (s *aqlInterpreter) ExitReference(ctx *parser.ReferenceContext) {
case ctx.DOT() != nil: case ctx.DOT() != nil:
reference := s.pop() reference := s.pop()
s.push(reference.(map[string]interface{})[ctx.T_STRING().GetText()]) s.push(reference.(map[string]any)[ctx.T_STRING().GetText()])
case ctx.T_STRING() != nil: case ctx.T_STRING() != nil:
s.push(s.getVar(ctx.T_STRING().GetText())) s.push(s.getVar(ctx.T_STRING().GetText()))
case ctx.Compound_value() != nil: case ctx.Compound_value() != nil:
@@ -175,14 +166,15 @@ func (s *aqlInterpreter) ExitReference(ctx *parser.ReferenceContext) {
if f, ok := key.(float64); ok { if f, ok := key.(float64); ok {
index := int(f) index := int(f)
if index < 0 { if index < 0 {
index = len(reference.([]interface{})) + index index = len(reference.([]any)) + index
} }
s.push(reference.([]interface{})[index]) s.push(reference.([]any)[index])
return return
} }
s.push(reference.(map[string]interface{})[key.(string)]) s.push(reference.(map[string]any)[key.(string)])
default: default:
panic(fmt.Sprintf("unexpected value: %s", ctx.GetText())) panic(fmt.Sprintf("unexpected value: %s", ctx.GetText()))
} }
@@ -239,17 +231,17 @@ func (s *aqlInterpreter) ExitValue_literal(ctx *parser.Value_literalContext) {
// ExitArray is called when production array is exited. // ExitArray is called when production array is exited.
func (s *aqlInterpreter) ExitArray(ctx *parser.ArrayContext) { func (s *aqlInterpreter) ExitArray(ctx *parser.ArrayContext) {
array := []interface{}{} array := []any{}
for range ctx.AllExpression() { for range ctx.AllExpression() {
// prepend element // prepend element
array = append([]interface{}{s.pop()}, array...) array = append([]any{s.pop()}, array...)
} }
s.push(array) s.push(array)
} }
// ExitObject is called when production object is exited. // ExitObject is called when production object is exited.
func (s *aqlInterpreter) ExitObject(ctx *parser.ObjectContext) { func (s *aqlInterpreter) ExitObject(ctx *parser.ObjectContext) {
object := map[string]interface{}{} object := map[string]any{}
for range ctx.AllObject_element() { for range ctx.AllObject_element() {
key, value := s.pop(), s.pop() key, value := s.pop(), s.pop()
@@ -290,7 +282,7 @@ func (s *aqlInterpreter) ExitObject_element_name(ctx *parser.Object_element_name
} }
} }
func (s *aqlInterpreter) getVar(identifier string) interface{} { func (s *aqlInterpreter) getVar(identifier string) any {
v, ok := s.values[identifier] v, ok := s.values[identifier]
if !ok { if !ok {
s.appendErrors(ErrUndefined) s.appendErrors(ErrUndefined)
@@ -303,10 +295,11 @@ func maybeNot(ctx *parser.ExpressionContext, m bool) bool {
if ctx.T_NOT() != nil { if ctx.T_NOT() != nil {
return !m return !m
} }
return m return m
} }
func getOp(tokenType int) func(left, right interface{}) bool { func getOp(tokenType int) func(left, right any) bool {
switch tokenType { switch tokenType {
case parser.CAQLLexerT_EQ: case parser.CAQLLexerT_EQ:
return eq return eq
@@ -323,33 +316,36 @@ func getOp(tokenType int) func(left, right interface{}) bool {
case parser.CAQLLexerT_IN: case parser.CAQLLexerT_IN:
return in return in
default: default:
panic("unkown token type") panic("unknown token type")
} }
} }
func all(slice []interface{}, op func(interface{}, interface{}) bool, expr interface{}) bool { func all(slice []any, op func(any, any) bool, expr any) bool {
for _, e := range slice { for _, e := range slice {
if !op(e, expr) { if !op(e, expr) {
return false return false
} }
} }
return true return true
} }
func any(slice []interface{}, op func(interface{}, interface{}) bool, expr interface{}) bool { func anyElement(slice []any, op func(any, any) bool, expr any) bool {
for _, e := range slice { for _, e := range slice {
if op(e, expr) { if op(e, expr) {
return true return true
} }
} }
return false return false
} }
func none(slice []interface{}, op func(interface{}, interface{}) bool, expr interface{}) bool { func none(slice []any, op func(any, any) bool, expr any) bool {
for _, e := range slice { for _, e := range slice {
if op(e, expr) { if op(e, expr) {
return false return false
} }
} }
return true return true
} }
+94 -71
View File
@@ -10,21 +10,23 @@ import (
// Logical operators https://www.arangodb.com/docs/3.7/aql/operators.html#logical-operators // Logical operators https://www.arangodb.com/docs/3.7/aql/operators.html#logical-operators
func or(left, right interface{}) interface{} { func or(left, right any) any {
if toBool(left) { if toBool(left) {
return left return left
} }
return right return right
} }
func and(left, right interface{}) interface{} { func and(left, right any) any {
if !toBool(left) { if !toBool(left) {
return left return left
} }
return right return right
} }
func toBool(i interface{}) bool { func toBool(i any) bool {
switch v := i.(type) { switch v := i.(type) {
case nil: case nil:
return false return false
@@ -36,9 +38,9 @@ func toBool(i interface{}) bool {
return v != 0 return v != 0
case string: case string:
return v != "" return v != ""
case []interface{}: case []any:
return true return true
case map[string]interface{}: case map[string]any:
return true return true
default: default:
panic("bool conversion failed") panic("bool conversion failed")
@@ -47,15 +49,15 @@ func toBool(i interface{}) bool {
// Arithmetic operators https://www.arangodb.com/docs/3.7/aql/operators.html#arithmetic-operators // Arithmetic operators https://www.arangodb.com/docs/3.7/aql/operators.html#arithmetic-operators
func plus(left, right interface{}) float64 { func plus(left, right any) float64 {
return toNumber(left) + toNumber(right) return toNumber(left) + toNumber(right)
} }
func minus(left, right interface{}) float64 { func minus(left, right any) float64 {
return toNumber(left) - toNumber(right) return toNumber(left) - toNumber(right)
} }
func times(left, right interface{}) float64 { func times(left, right any) float64 {
return round(toNumber(left) * toNumber(right)) return round(toNumber(left) * toNumber(right))
} }
@@ -63,19 +65,20 @@ func round(r float64) float64 {
return math.Round(r*100000) / 100000 return math.Round(r*100000) / 100000
} }
func div(left, right interface{}) float64 { func div(left, right any) float64 {
b := toNumber(right) b := toNumber(right)
if b == 0 { if b == 0 {
return 0 return 0
} }
return round(toNumber(left) / b) return round(toNumber(left) / b)
} }
func mod(left, right interface{}) float64 { func mod(left, right any) float64 {
return math.Mod(toNumber(left), toNumber(right)) return math.Mod(toNumber(left), toNumber(right))
} }
func toNumber(i interface{}) float64 { func toNumber(i any) float64 {
switch v := i.(type) { switch v := i.(type) {
case nil: case nil:
return 0 return 0
@@ -83,6 +86,7 @@ func toNumber(i interface{}) float64 {
if v { if v {
return 1 return 1
} }
return 0 return 0
case float64: case float64:
switch { switch {
@@ -91,22 +95,25 @@ func toNumber(i interface{}) float64 {
case math.IsInf(v, 0): case math.IsInf(v, 0):
return 0 return 0
} }
return v return v
case string: case string:
f, err := strconv.ParseFloat(strings.TrimSpace(v), 64) f, err := strconv.ParseFloat(strings.TrimSpace(v), 64)
if err != nil { if err != nil {
return 0 return 0
} }
return f return f
case []interface{}: case []any:
if len(v) == 0 { if len(v) == 0 {
return 0 return 0
} }
if len(v) == 1 { if len(v) == 1 {
return toNumber(v[0]) return toNumber(v[0])
} }
return 0 return 0
case map[string]interface{}: case map[string]any:
return 0 return 0
default: default:
panic("number conversion error") panic("number conversion error")
@@ -116,7 +123,7 @@ func toNumber(i interface{}) float64 {
// Logical operators https://www.arangodb.com/docs/3.7/aql/operators.html#logical-operators // Logical operators https://www.arangodb.com/docs/3.7/aql/operators.html#logical-operators
// Order https://www.arangodb.com/docs/3.7/aql/fundamentals-type-value-order.html // Order https://www.arangodb.com/docs/3.7/aql/fundamentals-type-value-order.html
func eq(left, right interface{}) bool { func eq(left, right any) bool {
leftV, rightV := typeValue(left), typeValue(right) leftV, rightV := typeValue(left), typeValue(right)
if leftV != rightV { if leftV != rightV {
return false return false
@@ -126,15 +133,15 @@ func eq(left, right interface{}) bool {
return true return true
case bool, float64, string: case bool, float64, string:
return left == right return left == right
case []interface{}: case []any:
ra := right.([]interface{}) ra := right.([]any)
max := len(l) max := len(l)
if len(ra) > max { if len(ra) > max {
max = len(ra) max = len(ra)
} }
for i := 0; i < max; i++ { for i := 0; i < max; i++ {
var li interface{} = nil var li any
var rai interface{} = nil var rai any
if len(l) > i { if len(l) > i {
li = l[i] li = l[i]
} }
@@ -146,13 +153,14 @@ func eq(left, right interface{}) bool {
return false return false
} }
} }
return true return true
case map[string]interface{}: case map[string]any:
ro := right.(map[string]interface{}) ro := right.(map[string]any)
for _, key := range keys(l, ro) { for _, key := range keys(l, ro) {
var li interface{} = nil var li any
var rai interface{} = nil var rai any
if lv, ok := l[key]; ok { if lv, ok := l[key]; ok {
li = lv li = lv
} }
@@ -164,17 +172,18 @@ func eq(left, right interface{}) bool {
return false return false
} }
} }
return true return true
default: default:
panic("unknown type") panic("unknown type")
} }
} }
func ne(left, right interface{}) bool { func ne(left, right any) bool {
return !eq(left, right) return !eq(left, right)
} }
func lt(left, right interface{}) bool { func lt(left, right any) bool {
leftV, rightV := typeValue(left), typeValue(right) leftV, rightV := typeValue(left), typeValue(right)
if leftV != rightV { if leftV != rightV {
return leftV < rightV return leftV < rightV
@@ -190,15 +199,15 @@ func lt(left, right interface{}) bool {
return l < right.(float64) return l < right.(float64)
case string: case string:
return l < right.(string) return l < right.(string)
case []interface{}: case []any:
ra := right.([]interface{}) ra := right.([]any)
max := len(l) max := len(l)
if len(ra) > max { if len(ra) > max {
max = len(ra) max = len(ra)
} }
for i := 0; i < max; i++ { for i := 0; i < max; i++ {
var li interface{} = nil var li any
var rai interface{} = nil var rai any
if len(l) > i { if len(l) > i {
li = l[i] li = l[i]
} }
@@ -210,13 +219,14 @@ func lt(left, right interface{}) bool {
return lt(li, rai) return lt(li, rai)
} }
} }
return false return false
case map[string]interface{}: case map[string]any:
ro := right.(map[string]interface{}) ro := right.(map[string]any)
for _, key := range keys(l, ro) { for _, key := range keys(l, ro) {
var li interface{} = nil var li any
var rai interface{} = nil var rai any
if lv, ok := l[key]; ok { if lv, ok := l[key]; ok {
li = lv li = lv
} }
@@ -228,16 +238,17 @@ func lt(left, right interface{}) bool {
return lt(li, rai) return lt(li, rai)
} }
} }
return false return false
default: default:
panic("unknown type") panic("unknown type")
} }
} }
func keys(l map[string]interface{}, ro map[string]interface{}) []string { func keys(l map[string]any, ro map[string]any) []string {
var keys []string var keys []string
seen := map[string]bool{} seen := map[string]bool{}
for _, a := range []map[string]interface{}{l, ro} { for _, a := range []map[string]any{l, ro} {
for k := range a { for k := range a {
if _, ok := seen[k]; !ok { if _, ok := seen[k]; !ok {
seen[k] = true seen[k] = true
@@ -246,10 +257,11 @@ func keys(l map[string]interface{}, ro map[string]interface{}) []string {
} }
} }
sort.Strings(keys) sort.Strings(keys)
return keys return keys
} }
func gt(left, right interface{}) bool { func gt(left, right any) bool {
leftV, rightV := typeValue(left), typeValue(right) leftV, rightV := typeValue(left), typeValue(right)
if leftV != rightV { if leftV != rightV {
return leftV > rightV return leftV > rightV
@@ -265,15 +277,15 @@ func gt(left, right interface{}) bool {
return l > right.(float64) return l > right.(float64)
case string: case string:
return l > right.(string) return l > right.(string)
case []interface{}: case []any:
ra := right.([]interface{}) ra := right.([]any)
max := len(l) max := len(l)
if len(ra) > max { if len(ra) > max {
max = len(ra) max = len(ra)
} }
for i := 0; i < max; i++ { for i := 0; i < max; i++ {
var li interface{} = nil var li any
var rai interface{} = nil var rai any
if len(l) > i { if len(l) > i {
li = l[i] li = l[i]
} }
@@ -285,13 +297,14 @@ func gt(left, right interface{}) bool {
return gt(li, rai) return gt(li, rai)
} }
} }
return false return false
case map[string]interface{}: case map[string]any:
ro := right.(map[string]interface{}) ro := right.(map[string]any)
for _, key := range keys(l, ro) { for _, key := range keys(l, ro) {
var li interface{} = nil var li any
var rai interface{} = nil var rai any
if lv, ok := l[key]; ok { if lv, ok := l[key]; ok {
li = lv li = lv
} }
@@ -303,13 +316,14 @@ func gt(left, right interface{}) bool {
return gt(li, rai) return gt(li, rai)
} }
} }
return false return false
default: default:
panic("unknown type") panic("unknown type")
} }
} }
func le(left, right interface{}) bool { func le(left, right any) bool {
leftV, rightV := typeValue(left), typeValue(right) leftV, rightV := typeValue(left), typeValue(right)
if leftV != rightV { if leftV != rightV {
return leftV <= rightV return leftV <= rightV
@@ -325,15 +339,15 @@ func le(left, right interface{}) bool {
return l <= right.(float64) return l <= right.(float64)
case string: case string:
return l <= right.(string) return l <= right.(string)
case []interface{}: case []any:
ra := right.([]interface{}) ra := right.([]any)
max := len(l) max := len(l)
if len(ra) > max { if len(ra) > max {
max = len(ra) max = len(ra)
} }
for i := 0; i < max; i++ { for i := 0; i < max; i++ {
var li interface{} = nil var li any
var rai interface{} = nil var rai any
if len(l) > i { if len(l) > i {
li = l[i] li = l[i]
} }
@@ -345,13 +359,14 @@ func le(left, right interface{}) bool {
return le(li, rai) return le(li, rai)
} }
} }
return true return true
case map[string]interface{}: case map[string]any:
ro := right.(map[string]interface{}) ro := right.(map[string]any)
for _, key := range keys(l, ro) { for _, key := range keys(l, ro) {
var li interface{} = nil var li any
var rai interface{} = nil var rai any
if lv, ok := l[key]; ok { if lv, ok := l[key]; ok {
li = lv li = lv
} }
@@ -363,13 +378,14 @@ func le(left, right interface{}) bool {
return lt(li, rai) return lt(li, rai)
} }
} }
return true return true
default: default:
panic("unknown type") panic("unknown type")
} }
} }
func ge(left, right interface{}) bool { func ge(left, right any) bool {
leftV, rightV := typeValue(left), typeValue(right) leftV, rightV := typeValue(left), typeValue(right)
if leftV != rightV { if leftV != rightV {
return leftV >= rightV return leftV >= rightV
@@ -385,15 +401,15 @@ func ge(left, right interface{}) bool {
return l >= right.(float64) return l >= right.(float64)
case string: case string:
return l >= right.(string) return l >= right.(string)
case []interface{}: case []any:
ra := right.([]interface{}) ra := right.([]any)
max := len(l) max := len(l)
if len(ra) > max { if len(ra) > max {
max = len(ra) max = len(ra)
} }
for i := 0; i < max; i++ { for i := 0; i < max; i++ {
var li interface{} = nil var li any
var rai interface{} = nil var rai any
if len(l) > i { if len(l) > i {
li = l[i] li = l[i]
} }
@@ -405,13 +421,14 @@ func ge(left, right interface{}) bool {
return ge(li, rai) return ge(li, rai)
} }
} }
return true return true
case map[string]interface{}: case map[string]any:
ro := right.(map[string]interface{}) ro := right.(map[string]any)
for _, key := range keys(l, ro) { for _, key := range keys(l, ro) {
var li interface{} = nil var li any
var rai interface{} = nil var rai any
if lv, ok := l[key]; ok { if lv, ok := l[key]; ok {
li = lv li = lv
} }
@@ -423,14 +440,15 @@ func ge(left, right interface{}) bool {
return gt(li, rai) return gt(li, rai)
} }
} }
return true return true
default: default:
panic("unknown type") panic("unknown type")
} }
} }
func in(left, right interface{}) bool { func in(left, right any) bool {
a, ok := right.([]interface{}) a, ok := right.([]any)
if !ok { if !ok {
return false return false
} }
@@ -439,23 +457,25 @@ func in(left, right interface{}) bool {
return true return true
} }
} }
return false return false
} }
func like(left, right interface{}) (bool, error) { func like(left, right any) (bool, error) {
return match(right.(string), left.(string)) return match(right.(string), left.(string))
} }
func regexMatch(left, right interface{}) (bool, error) { func regexMatch(left, right any) (bool, error) {
return regexp.Match(right.(string), []byte(left.(string))) return regexp.Match(right.(string), []byte(left.(string)))
} }
func regexNonMatch(left, right interface{}) (bool, error) { func regexNonMatch(left, right any) (bool, error) {
m, err := regexp.Match(right.(string), []byte(left.(string))) m, err := regexp.Match(right.(string), []byte(left.(string)))
return !m, err return !m, err
} }
func typeValue(v interface{}) int { func typeValue(v any) int {
switch v.(type) { switch v.(type) {
case nil: case nil:
return 0 return 0
@@ -465,9 +485,9 @@ func typeValue(v interface{}) int {
return 2 return 2
case string: case string:
return 3 return 3
case []interface{}: case []any:
return 4 return 4
case map[string]interface{}: case map[string]any:
return 5 return 5
default: default:
panic("unknown type") panic("unknown type")
@@ -476,22 +496,25 @@ func typeValue(v interface{}) int {
// Ternary operator https://www.arangodb.com/docs/3.7/aql/operators.html#ternary-operator // Ternary operator https://www.arangodb.com/docs/3.7/aql/operators.html#ternary-operator
func ternary(left, middle, right interface{}) interface{} { func ternary(left, middle, right any) any {
if toBool(left) { if toBool(left) {
if middle != nil { if middle != nil {
return middle return middle
} }
return left return left
} }
return right return right
} }
// Range operators https://www.arangodb.com/docs/3.7/aql/operators.html#range-operator // Range operators https://www.arangodb.com/docs/3.7/aql/operators.html#range-operator
func aqlrange(left, right interface{}) []float64 { func aqlrange(left, right any) []float64 {
var v []float64 var v []float64
for i := int(left.(float64)); i <= int(right.(float64)); i++ { for i := int(left.(float64)); i <= int(right.(float64)); i++ {
v = append(v, float64(i)) v = append(v, float64(i))
} }
return v return v
} }
+4 -3
View File
@@ -21,7 +21,7 @@ func (p *Parser) Parse(aql string) (t *Tree, err error) {
err = fmt.Errorf("%s", r) err = fmt.Errorf("%s", r)
} }
}() }()
// Setup the input // Set up the input
inputStream := antlr.NewInputStream(aql) inputStream := antlr.NewInputStream(aql)
errorListener := &errorListener{} errorListener := &errorListener{}
@@ -52,7 +52,7 @@ type Tree struct {
prefix string prefix string
} }
func (t *Tree) Eval(values map[string]interface{}) (i interface{}, err error) { func (t *Tree) Eval(values map[string]any) (i any, err error) {
defer func() { defer func() {
if r := recover(); r != nil { if r := recover(); r != nil {
err = fmt.Errorf("%s", r) err = fmt.Errorf("%s", r)
@@ -65,6 +65,7 @@ func (t *Tree) Eval(values map[string]interface{}) (i interface{}, err error) {
if interpreter.errs != nil { if interpreter.errs != nil {
return nil, interpreter.errs[0] return nil, interpreter.errs[0]
} }
return interpreter.stack[0], nil return interpreter.stack[0], nil
} }
@@ -103,7 +104,7 @@ type errorListener struct {
errs []error errs []error
} }
func (el *errorListener) SyntaxError(recognizer antlr.Recognizer, offendingSymbol interface{}, line, column int, msg string, e antlr.RecognitionException) { func (el *errorListener) SyntaxError(recognizer antlr.Recognizer, offendingSymbol any, line, column int, msg string, e antlr.RecognitionException) {
el.errs = append(el.errs, fmt.Errorf("line "+strconv.Itoa(line)+":"+strconv.Itoa(column)+" "+msg)) el.errs = append(el.errs, fmt.Errorf("line "+strconv.Itoa(line)+":"+strconv.Itoa(column)+" "+msg))
} }
+32 -22
View File
@@ -1,9 +1,11 @@
package caql package caql_test
import ( import (
"encoding/json" "encoding/json"
"reflect" "reflect"
"testing" "testing"
"github.com/SecurityBrewery/catalyst/caql"
) )
type MockSearcher struct{} type MockSearcher struct{}
@@ -13,11 +15,13 @@ func (m MockSearcher) Search(_ string) (ids []string, err error) {
} }
func TestParseSAQLEval(t *testing.T) { func TestParseSAQLEval(t *testing.T) {
t.Parallel()
tests := []struct { tests := []struct {
name string name string
saql string saql string
wantRebuild string wantRebuild string
wantValue interface{} wantValue any
wantParseErr bool wantParseErr bool
wantRebuildErr bool wantRebuildErr bool
wantEvalErr bool wantEvalErr bool
@@ -89,15 +93,15 @@ func TestParseSAQLEval(t *testing.T) {
// {name: "String 9", saql: `'this is a longer string.'`, wantRebuild: `"this is a longer string."`, wantValue: "this is a longer string."}, // {name: "String 9", saql: `'this is a longer string.'`, wantRebuild: `"this is a longer string."`, wantValue: "this is a longer string."},
// {name: "String 10", saql: `'the path separator on Windows is \\'`, wantRebuild: `"the path separator on Windows is \\"`, wantValue: `the path separator on Windows is \`}, // {name: "String 10", saql: `'the path separator on Windows is \\'`, wantRebuild: `"the path separator on Windows is \\"`, wantValue: `the path separator on Windows is \`},
{name: "Array 1", saql: "[]", wantRebuild: "[]", wantValue: []interface{}{}}, {name: "Array 1", saql: "[]", wantRebuild: "[]", wantValue: []any{}},
{name: "Array 2", saql: `[true]`, wantRebuild: `[true]`, wantValue: []interface{}{true}}, {name: "Array 2", saql: `[true]`, wantRebuild: `[true]`, wantValue: []any{true}},
{name: "Array 3", saql: `[1, 2, 3]`, wantRebuild: `[1, 2, 3]`, wantValue: []interface{}{float64(1), float64(2), float64(3)}}, {name: "Array 3", saql: `[1, 2, 3]`, wantRebuild: `[1, 2, 3]`, wantValue: []any{float64(1), float64(2), float64(3)}},
{ {
name: "Array 4", saql: `[-99, "yikes!", [false, ["no"], []], 1]`, wantRebuild: `[-99, "yikes!", [false, ["no"], []], 1]`, name: "Array 4", saql: `[-99, "yikes!", [false, ["no"], []], 1]`, wantRebuild: `[-99, "yikes!", [false, ["no"], []], 1]`,
wantValue: []interface{}{-99.0, "yikes!", []interface{}{false, []interface{}{"no"}, []interface{}{}}, float64(1)}, wantValue: []any{-99.0, "yikes!", []any{false, []any{"no"}, []any{}}, float64(1)},
}, },
{name: "Array 5", saql: `[["fox", "marshal"]]`, wantRebuild: `[["fox", "marshal"]]`, wantValue: []interface{}{[]interface{}{"fox", "marshal"}}}, {name: "Array 5", saql: `[["fox", "marshal"]]`, wantRebuild: `[["fox", "marshal"]]`, wantValue: []any{[]any{"fox", "marshal"}}},
{name: "Array 6", saql: `[1, 2, 3,]`, wantRebuild: `[1, 2, 3]`, wantValue: []interface{}{float64(1), float64(2), float64(3)}}, {name: "Array 6", saql: `[1, 2, 3,]`, wantRebuild: `[1, 2, 3]`, wantValue: []any{float64(1), float64(2), float64(3)}},
{name: "Array Error 1", saql: "(1,2,3)", wantParseErr: true}, {name: "Array Error 1", saql: "(1,2,3)", wantParseErr: true},
{name: "Array Access 1", saql: "u.friends[0]", wantRebuild: "u.friends[0]", wantValue: 7, values: `{"u": {"friends": [7,8,9]}}`}, {name: "Array Access 1", saql: "u.friends[0]", wantRebuild: "u.friends[0]", wantValue: 7, values: `{"u": {"friends": [7,8,9]}}`},
@@ -105,14 +109,14 @@ func TestParseSAQLEval(t *testing.T) {
{name: "Array Access 3", saql: "u.friends[-1]", wantRebuild: "u.friends[-1]", wantValue: 9, values: `{"u": {"friends": [7,8,9]}}`}, {name: "Array Access 3", saql: "u.friends[-1]", wantRebuild: "u.friends[-1]", wantValue: 9, values: `{"u": {"friends": [7,8,9]}}`},
{name: "Array Access 4", saql: "u.friends[-2]", wantRebuild: "u.friends[-2]", wantValue: 8, values: `{"u": {"friends": [7,8,9]}}`}, {name: "Array Access 4", saql: "u.friends[-2]", wantRebuild: "u.friends[-2]", wantValue: 8, values: `{"u": {"friends": [7,8,9]}}`},
{name: "Object 1", saql: "{}", wantRebuild: "{}", wantValue: map[string]interface{}{}}, {name: "Object 1", saql: "{}", wantRebuild: "{}", wantValue: map[string]any{}},
{name: "Object 2", saql: `{a: 1}`, wantRebuild: "{a: 1}", wantValue: map[string]interface{}{"a": float64(1)}}, {name: "Object 2", saql: `{a: 1}`, wantRebuild: "{a: 1}", wantValue: map[string]any{"a": float64(1)}},
{name: "Object 3", saql: `{'a': 1}`, wantRebuild: `{'a': 1}`, wantValue: map[string]interface{}{"a": float64(1)}}, {name: "Object 3", saql: `{'a': 1}`, wantRebuild: `{'a': 1}`, wantValue: map[string]any{"a": float64(1)}},
{name: "Object 4", saql: `{"a": 1}`, wantRebuild: `{"a": 1}`, wantValue: map[string]interface{}{"a": float64(1)}}, {name: "Object 4", saql: `{"a": 1}`, wantRebuild: `{"a": 1}`, wantValue: map[string]any{"a": float64(1)}},
{name: "Object 5", saql: `{'return': 1}`, wantRebuild: `{'return': 1}`, wantValue: map[string]interface{}{"return": float64(1)}}, {name: "Object 5", saql: `{'return': 1}`, wantRebuild: `{'return': 1}`, wantValue: map[string]any{"return": float64(1)}},
{name: "Object 6", saql: `{"return": 1}`, wantRebuild: `{"return": 1}`, wantValue: map[string]interface{}{"return": float64(1)}}, {name: "Object 6", saql: `{"return": 1}`, wantRebuild: `{"return": 1}`, wantValue: map[string]any{"return": float64(1)}},
{name: "Object 9", saql: `{a: 1,}`, wantRebuild: "{a: 1}", wantValue: map[string]interface{}{"a": float64(1)}}, {name: "Object 9", saql: `{a: 1,}`, wantRebuild: "{a: 1}", wantValue: map[string]any{"a": float64(1)}},
{name: "Object 10", saql: `{"a": 1,}`, wantRebuild: `{"a": 1}`, wantValue: map[string]interface{}{"a": float64(1)}}, {name: "Object 10", saql: `{"a": 1,}`, wantRebuild: `{"a": 1}`, wantValue: map[string]any{"a": float64(1)}},
// {"Object 8", "{`return`: 1}", `{"return": 1}`, true}, // {"Object 8", "{`return`: 1}", `{"return": 1}`, true},
// {"Object 7", "{´return´: 1}", `{"return": 1}`, true}, // {"Object 7", "{´return´: 1}", `{"return": 1}`, true},
{name: "Object Error 1: return is a keyword", saql: `{like: 1}`, wantParseErr: true}, {name: "Object Error 1: return is a keyword", saql: `{like: 1}`, wantParseErr: true},
@@ -272,7 +276,7 @@ func TestParseSAQLEval(t *testing.T) {
{name: "Arithmetic 17", saql: `23 * {}`, wantRebuild: `23 * {}`, wantValue: 0}, {name: "Arithmetic 17", saql: `23 * {}`, wantRebuild: `23 * {}`, wantValue: 0},
{name: "Arithmetic 18", saql: `5 * [7]`, wantRebuild: `5 * [7]`, wantValue: 35}, {name: "Arithmetic 18", saql: `5 * [7]`, wantRebuild: `5 * [7]`, wantValue: 35},
{name: "Arithmetic 19", saql: `24 / "12"`, wantRebuild: `24 / "12"`, wantValue: 2}, {name: "Arithmetic 19", saql: `24 / "12"`, wantRebuild: `24 / "12"`, wantValue: 2},
{name: "Arithmetic Error 1: Divison by zero", saql: `1 / 0`, wantRebuild: `1 / 0`, wantValue: 0}, {name: "Arithmetic Error 1: Division by zero", saql: `1 / 0`, wantRebuild: `1 / 0`, wantValue: 0},
// https://www.arangodb.com/docs/3.7/aql/operators.html#ternary-operator // https://www.arangodb.com/docs/3.7/aql/operators.html#ternary-operator
{name: "Ternary 1", saql: `u.age > 15 || u.active == true ? u.userId : null`, wantRebuild: `u.age > 15 OR u.active == true ? u.userId : null`, wantValue: 45, values: `{"u": {"active": true, "age": 2, "userId": 45}}`}, {name: "Ternary 1", saql: `u.age > 15 || u.active == true ? u.userId : null`, wantRebuild: `u.age > 15 OR u.active == true ? u.userId : null`, wantValue: 45, values: `{"u": {"active": true, "age": 2, "userId": 45}}`},
@@ -287,20 +291,24 @@ func TestParseSAQLEval(t *testing.T) {
{name: "Security 2", saql: `doc.value == 1 || true INSERT {foo: "bar"} IN collection //`, wantParseErr: true}, {name: "Security 2", saql: `doc.value == 1 || true INSERT {foo: "bar"} IN collection //`, wantParseErr: true},
// https://www.arangodb.com/docs/3.7/aql/operators.html#operator-precedence // https://www.arangodb.com/docs/3.7/aql/operators.html#operator-precedence
{name: "Precendence", saql: `2 > 15 && "a" != ""`, wantRebuild: `2 > 15 AND "a" != ""`, wantValue: false}, {name: "Precedence", saql: `2 > 15 && "a" != ""`, wantRebuild: `2 > 15 AND "a" != ""`, wantValue: false},
} }
for _, tt := range tests { for _, tt := range tests {
parser := &Parser{ tt := tt
parser := &caql.Parser{
Searcher: &MockSearcher{}, Searcher: &MockSearcher{},
} }
t.Run(tt.name, func(t *testing.T) { t.Run(tt.name, func(t *testing.T) {
t.Parallel()
expr, err := parser.Parse(tt.saql) expr, err := parser.Parse(tt.saql)
if (err != nil) != tt.wantParseErr { if (err != nil) != tt.wantParseErr {
t.Errorf("Parse() error = %v, wantErr %v", err, tt.wantParseErr) t.Errorf("Parse() error = %v, wantErr %v", err, tt.wantParseErr)
if expr != nil { if expr != nil {
t.Error(expr.String()) t.Error(expr.String())
} }
return return
} }
if err != nil { if err != nil {
@@ -311,6 +319,7 @@ func TestParseSAQLEval(t *testing.T) {
if (err != nil) != tt.wantRebuildErr { if (err != nil) != tt.wantRebuildErr {
t.Error(expr.String()) t.Error(expr.String())
t.Errorf("String() error = %v, wantErr %v", err, tt.wantParseErr) t.Errorf("String() error = %v, wantErr %v", err, tt.wantParseErr)
return return
} }
if err != nil { if err != nil {
@@ -320,18 +329,19 @@ func TestParseSAQLEval(t *testing.T) {
t.Errorf("String() got = %v, want %v", got, tt.wantRebuild) t.Errorf("String() got = %v, want %v", got, tt.wantRebuild)
} }
var myJson map[string]interface{} var myJSON map[string]any
if tt.values != "" { if tt.values != "" {
err = json.Unmarshal([]byte(tt.values), &myJson) err = json.Unmarshal([]byte(tt.values), &myJSON)
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
} }
} }
value, err := expr.Eval(myJson) value, err := expr.Eval(myJSON)
if (err != nil) != tt.wantEvalErr { if (err != nil) != tt.wantEvalErr {
t.Error(expr.String()) t.Error(expr.String())
t.Errorf("Parse() error = %v, wantErr %v", err, tt.wantParseErr) t.Errorf("Parse() error = %v, wantErr %v", err, tt.wantParseErr)
return return
} }
if err != nil { if err != nil {
+21 -36
View File
@@ -22,19 +22,18 @@
package caql package caql
import "sort" import (
"sort"
type (
Set struct {
hash map[interface{}]nothing
}
nothing struct{}
) )
// Create a new set type Set struct {
func New(initial ...interface{}) *Set { hash map[any]nothing
s := &Set{make(map[interface{}]nothing)} }
type nothing struct{}
func NewSet(initial ...any) *Set {
s := &Set{make(map[any]nothing)}
for _, v := range initial { for _, v := range initial {
s.Insert(v) s.Insert(v)
@@ -43,9 +42,8 @@ func New(initial ...interface{}) *Set {
return s return s
} }
// Find the difference between two sets
func (s *Set) Difference(set *Set) *Set { func (s *Set) Difference(set *Set) *Set {
n := make(map[interface{}]nothing) n := make(map[any]nothing)
for k := range s.hash { for k := range s.hash {
if _, exists := set.hash[k]; !exists { if _, exists := set.hash[k]; !exists {
@@ -56,27 +54,18 @@ func (s *Set) Difference(set *Set) *Set {
return &Set{n} return &Set{n}
} }
// Call f for each item in the set func (s *Set) Has(element any) bool {
func (s *Set) Do(f func(interface{})) {
for k := range s.hash {
f(k)
}
}
// Test to see whether or not the element is in the set
func (s *Set) Has(element interface{}) bool {
_, exists := s.hash[element] _, exists := s.hash[element]
return exists return exists
} }
// Add an element to the set func (s *Set) Insert(element any) {
func (s *Set) Insert(element interface{}) {
s.hash[element] = nothing{} s.hash[element] = nothing{}
} }
// Find the intersection of two sets
func (s *Set) Intersection(set *Set) *Set { func (s *Set) Intersection(set *Set) *Set {
n := make(map[interface{}]nothing) n := make(map[any]nothing)
for k := range s.hash { for k := range s.hash {
if _, exists := set.hash[k]; exists { if _, exists := set.hash[k]; exists {
@@ -87,23 +76,20 @@ func (s *Set) Intersection(set *Set) *Set {
return &Set{n} return &Set{n}
} }
// Return the number of items in the set
func (s *Set) Len() int { func (s *Set) Len() int {
return len(s.hash) return len(s.hash)
} }
// Test whether or not this set is a proper subset of "set"
func (s *Set) ProperSubsetOf(set *Set) bool { func (s *Set) ProperSubsetOf(set *Set) bool {
return s.SubsetOf(set) && s.Len() < set.Len() return s.SubsetOf(set) && s.Len() < set.Len()
} }
// Remove an element from the set func (s *Set) Remove(element any) {
func (s *Set) Remove(element interface{}) {
delete(s.hash, element) delete(s.hash, element)
} }
func (s *Set) Minus(set *Set) *Set { func (s *Set) Minus(set *Set) *Set {
n := make(map[interface{}]nothing) n := make(map[any]nothing)
for k := range s.hash { for k := range s.hash {
n[k] = nothing{} n[k] = nothing{}
} }
@@ -115,7 +101,6 @@ func (s *Set) Minus(set *Set) *Set {
return &Set{n} return &Set{n}
} }
// Test whether or not this set is a subset of "set"
func (s *Set) SubsetOf(set *Set) bool { func (s *Set) SubsetOf(set *Set) bool {
if s.Len() > set.Len() { if s.Len() > set.Len() {
return false return false
@@ -125,12 +110,12 @@ func (s *Set) SubsetOf(set *Set) bool {
return false return false
} }
} }
return true return true
} }
// Find the union of two sets
func (s *Set) Union(set *Set) *Set { func (s *Set) Union(set *Set) *Set {
n := make(map[interface{}]nothing) n := make(map[any]nothing)
for k := range s.hash { for k := range s.hash {
n[k] = nothing{} n[k] = nothing{}
@@ -142,8 +127,8 @@ func (s *Set) Union(set *Set) *Set {
return &Set{n} return &Set{n}
} }
func (s *Set) Values() []interface{} { func (s *Set) Values() []any {
values := []interface{}{} values := []any{}
for k := range s.hash { for k := range s.hash {
values = append(values, k) values = append(values, k)
+6 -5
View File
@@ -27,7 +27,9 @@ import (
) )
func Test(t *testing.T) { func Test(t *testing.T) {
s := New() t.Parallel()
s := NewSet()
s.Insert(5) s.Insert(5)
@@ -50,8 +52,8 @@ func Test(t *testing.T) {
} }
// Difference // Difference
s1 := New(1, 2, 3, 4, 5, 6) s1 := NewSet(1, 2, 3, 4, 5, 6)
s2 := New(4, 5, 6) s2 := NewSet(4, 5, 6)
s3 := s1.Difference(s2) s3 := s1.Difference(s2)
if s3.Len() != 3 { if s3.Len() != 3 {
@@ -73,7 +75,7 @@ func Test(t *testing.T) {
} }
// Union // Union
s4 := New(7, 8, 9) s4 := NewSet(7, 8, 9)
s3 = s2.Union(s4) s3 = s2.Union(s4)
if s3.Len() != 6 { if s3.Len() != 6 {
@@ -92,5 +94,4 @@ func Test(t *testing.T) {
if s1.ProperSubsetOf(s1) { if s1.ProperSubsetOf(s1) {
t.Errorf("set should not be a subset of itself") t.Errorf("set should not be a subset of itself")
} }
} }
+3
View File
@@ -39,8 +39,10 @@ func unquote(s string) (string, error) {
buf = append(buf, s[i]) buf = append(buf, s[i])
} }
} }
return string(buf), nil return string(buf), nil
} }
return s, nil return s, nil
} }
if quote != '"' && quote != '\'' { if quote != '"' && quote != '\'' {
@@ -75,5 +77,6 @@ func unquote(s string) (string, error) {
buf = append(buf, runeTmp[:n]...) buf = append(buf, runeTmp[:n]...)
} }
} }
return string(buf), nil return string(buf), nil
} }
+13 -12
View File
@@ -8,26 +8,25 @@
package caql package caql
import ( import (
"errors"
"strconv" "strconv"
"testing" "testing"
) )
type quoteTest struct { type quoteTest struct {
in string in string
out string out string
ascii string
graphic string
} }
var quotetests = []quoteTest{ var quotetests = []quoteTest{
{in: "\a\b\f\r\n\t\v", out: `"\a\b\f\r\n\t\v"`, ascii: `"\a\b\f\r\n\t\v"`, graphic: `"\a\b\f\r\n\t\v"`}, {in: "\a\b\f\r\n\t\v", out: `"\a\b\f\r\n\t\v"`},
{"\\", `"\\"`, `"\\"`, `"\\"`}, {"\\", `"\\"`},
{"abc\xffdef", `"abc\xffdef"`, `"abc\xffdef"`, `"abc\xffdef"`}, {"abc\xffdef", `"abc\xffdef"`},
{"\u263a", `"☺"`, `"\u263a"`, `"☺"`}, {"\u263a", `"☺"`},
{"\U0010ffff", `"\U0010ffff"`, `"\U0010ffff"`, `"\U0010ffff"`}, {"\U0010ffff", `"\U0010ffff"`},
{"\x04", `"\x04"`, `"\x04"`, `"\x04"`}, {"\x04", `"\x04"`},
// Some non-printable but graphic runes. Final column is double-quoted. // Some non-printable but graphic runes. Final column is double-quoted.
{"!\u00a0!\u2000!\u3000!", `"!\u00a0!\u2000!\u3000!"`, `"!\u00a0!\u2000!\u3000!"`, "\"!\u00a0!\u2000!\u3000!\""}, {"!\u00a0!\u2000!\u3000!", `"!\u00a0!\u2000!\u3000!"`},
} }
type unQuoteTest struct { type unQuoteTest struct {
@@ -104,6 +103,8 @@ var misquoted = []string{
} }
func TestUnquote(t *testing.T) { func TestUnquote(t *testing.T) {
t.Parallel()
for _, tt := range unquotetests { for _, tt := range unquotetests {
if out, err := unquote(tt.in); err != nil || out != tt.out { if out, err := unquote(tt.in); err != nil || out != tt.out {
t.Errorf("unquote(%#q) = %q, %v want %q, nil", tt.in, out, err, tt.out) t.Errorf("unquote(%#q) = %q, %v want %q, nil", tt.in, out, err, tt.out)
@@ -118,7 +119,7 @@ func TestUnquote(t *testing.T) {
} }
for _, s := range misquoted { for _, s := range misquoted {
if out, err := unquote(s); out != "" || err != strconv.ErrSyntax { if out, err := unquote(s); out != "" || !errors.Is(err, strconv.ErrSyntax) {
t.Errorf("unquote(%#q) = %q, %v want %q, %v", s, out, err, "", strconv.ErrSyntax) t.Errorf("unquote(%#q) = %q, %v want %q, %v", s, out, err, "", strconv.ErrSyntax)
} }
} }
+7 -3
View File
@@ -48,6 +48,7 @@ Pattern:
// using the star // using the star
if ok && (len(t) == 0 || len(pattern) > 0) { if ok && (len(t) == 0 || len(pattern) > 0) {
name = t name = t
continue continue
} }
if err != nil { if err != nil {
@@ -64,6 +65,7 @@ Pattern:
continue continue
} }
name = t name = t
continue Pattern continue Pattern
} }
if err != nil { if err != nil {
@@ -79,8 +81,10 @@ Pattern:
return false, err return false, err
} }
} }
return false, nil return false, nil
} }
return len(name) == 0, nil return len(name) == 0, nil
} }
@@ -104,6 +108,7 @@ Scan:
break Scan break Scan
} }
} }
return star, pattern[0:i], pattern[i:] return star, pattern[0:i], pattern[i:]
} }
@@ -120,7 +125,6 @@ func matchChunk(chunk, s string) (rest string, ok bool, err error) {
failed = true failed = true
} }
switch chunk[0] { switch chunk[0] {
case '_': case '_':
if !failed { if !failed {
if s[0] == '/' { if s[0] == '/' {
@@ -130,14 +134,13 @@ func matchChunk(chunk, s string) (rest string, ok bool, err error) {
s = s[n:] s = s[n:]
} }
chunk = chunk[1:] chunk = chunk[1:]
case '\\': case '\\':
chunk = chunk[1:] chunk = chunk[1:]
if len(chunk) == 0 { if len(chunk) == 0 {
return "", false, ErrBadPattern return "", false, ErrBadPattern
} }
fallthrough
fallthrough
default: default:
if !failed { if !failed {
if chunk[0] != s[0] { if chunk[0] != s[0] {
@@ -151,5 +154,6 @@ func matchChunk(chunk, s string) (rest string, ok bool, err error) {
if failed { if failed {
return "", false, nil return "", false, nil
} }
return s, true, nil return s, true, nil
} }
+7 -2
View File
@@ -7,7 +7,10 @@
package caql package caql
import "testing" import (
"errors"
"testing"
)
type MatchTest struct { type MatchTest struct {
pattern, s string pattern, s string
@@ -41,9 +44,11 @@ var matchTests = []MatchTest{
} }
func TestMatch(t *testing.T) { func TestMatch(t *testing.T) {
t.Parallel()
for _, tt := range matchTests { for _, tt := range matchTests {
ok, err := match(tt.pattern, tt.s) ok, err := match(tt.pattern, tt.s)
if ok != tt.match || err != tt.err { if ok != tt.match || !errors.Is(err, tt.err) {
t.Errorf("match(%#q, %#q) = %v, %v want %v, %v", tt.pattern, tt.s, ok, err, tt.match, tt.err) t.Errorf("match(%#q, %#q) = %v, %v want %v, %v", tt.pattern, tt.s, ok, err, tt.match, tt.err)
} }
} }
+3 -13
View File
@@ -4,6 +4,7 @@ import (
"github.com/alecthomas/kong" "github.com/alecthomas/kong"
kongyaml "github.com/alecthomas/kong-yaml" kongyaml "github.com/alecthomas/kong-yaml"
"github.com/coreos/go-oidc/v3/oidc" "github.com/coreos/go-oidc/v3/oidc"
"golang.org/x/exp/slices"
"golang.org/x/oauth2" "golang.org/x/oauth2"
"github.com/SecurityBrewery/catalyst" "github.com/SecurityBrewery/catalyst"
@@ -62,7 +63,7 @@ func MapConfig(cli CLI) (*catalyst.Config, error) {
roles = append(roles, role.Explodes(cli.AuthDefaultRoles)...) roles = append(roles, role.Explodes(cli.AuthDefaultRoles)...)
roles = role.Explodes(role.Strings(roles)) roles = role.Explodes(role.Strings(roles))
scopes := unique(append([]string{oidc.ScopeOpenID, "profile", "email"}, cli.OIDCScopes...)) scopes := slices.Compact(append([]string{oidc.ScopeOpenID, "profile", "email"}, cli.OIDCScopes...))
config := &catalyst.Config{ config := &catalyst.Config{
IndexPath: cli.IndexPath, IndexPath: cli.IndexPath,
Network: cli.Network, Network: cli.Network,
@@ -83,17 +84,6 @@ func MapConfig(cli CLI) (*catalyst.Config, error) {
Bus: &bus.Config{Host: cli.EmitterIOHost, Key: cli.EmitterIORKey, APIUrl: cli.CatalystAddress + "/api"}, Bus: &bus.Config{Host: cli.EmitterIOHost, Key: cli.EmitterIORKey, APIUrl: cli.CatalystAddress + "/api"},
InitialAPIKey: cli.InitialAPIKey, InitialAPIKey: cli.InitialAPIKey,
} }
return config, nil return config, nil
} }
func unique(l []string) []string {
keys := make(map[string]bool)
var list []string
for _, entry := range l {
if _, value := keys[entry]; !value {
keys[entry] = true
list = append(list, entry)
}
}
return list
}
+12 -4
View File
@@ -5,6 +5,7 @@ import (
"encoding/json" "encoding/json"
"errors" "errors"
"fmt" "fmt"
"log"
"net/http" "net/http"
) )
@@ -22,15 +23,21 @@ func stateCookie(r *http.Request) (string, error) {
if err != nil { if err != nil {
return "", err return "", err
} }
return stateCookie.Value, nil return stateCookie.Value, nil
} }
func setClaimsCookie(w http.ResponseWriter, claims map[string]interface{}) { func setClaimsCookie(w http.ResponseWriter, claims map[string]any) {
b, _ := json.Marshal(claims) b, err := json.Marshal(claims)
if err != nil {
log.Println(err)
return
}
http.SetCookie(w, &http.Cookie{Name: userSessionCookie, Value: base64.StdEncoding.EncodeToString(b)}) http.SetCookie(w, &http.Cookie{Name: userSessionCookie, Value: base64.StdEncoding.EncodeToString(b)})
} }
func claimsCookie(r *http.Request) (map[string]interface{}, bool, error) { func claimsCookie(r *http.Request) (map[string]any, bool, error) {
userCookie, err := r.Cookie(userSessionCookie) userCookie, err := r.Cookie(userSessionCookie)
if err != nil { if err != nil {
return nil, true, nil return nil, true, nil
@@ -41,9 +48,10 @@ func claimsCookie(r *http.Request) (map[string]interface{}, bool, error) {
return nil, false, fmt.Errorf("could not decode cookie: %w", err) return nil, false, fmt.Errorf("could not decode cookie: %w", err)
} }
var claims map[string]interface{} var claims map[string]any
if err := json.Unmarshal(b, &claims); err != nil { if err := json.Unmarshal(b, &claims); err != nil {
return nil, false, errors.New("claims not in session") return nil, false, errors.New("claims not in session")
} }
return claims, false, err return claims, false, err
} }
+10 -10
View File
@@ -25,6 +25,9 @@ package dag
import ( import (
"errors" "errors"
"sort" "sort"
"golang.org/x/exp/maps"
"golang.org/x/exp/slices"
) )
type Graph struct { type Graph struct {
@@ -52,6 +55,7 @@ func (g *Graph) AddNode(name string) error {
} }
g.outputs[name] = make(map[string]struct{}) g.outputs[name] = make(map[string]struct{})
g.inputs[name] = 0 g.inputs[name] = 0
return nil return nil
} }
@@ -61,6 +65,7 @@ func (g *Graph) AddNodes(names ...string) error {
return err return err
} }
} }
return nil return nil
} }
@@ -101,7 +106,9 @@ func (g *Graph) Toposort() ([]string, error) {
L = append(L, n) L = append(L, n)
ms := make([]string, len(outputs[n])) ms := make([]string, len(outputs[n]))
for _, k := range keys(outputs[n]) { keys := maps.Keys(outputs[n])
slices.Sort(keys)
for _, k := range keys {
m := k m := k
// i := outputs[n][m] // i := outputs[n][m]
// ms[i-1] = m // ms[i-1] = m
@@ -130,15 +137,6 @@ func (g *Graph) Toposort() ([]string, error) {
return L, nil return L, nil
} }
func keys(m map[string]struct{}) []string {
var keys []string
for k := range m {
keys = append(keys, k)
}
sort.Strings(keys)
return keys
}
func (g *Graph) GetParents(id string) []string { func (g *Graph) GetParents(id string) []string {
var parents []string var parents []string
for node, targets := range g.outputs { for node, targets := range g.outputs {
@@ -147,6 +145,7 @@ func (g *Graph) GetParents(id string) []string {
} }
} }
sort.Strings(parents) sort.Strings(parents)
return parents return parents
} }
@@ -160,5 +159,6 @@ func (g *Graph) GetRoot() (string, error) {
if len(roots) != 1 { if len(roots) != 1 {
return "", errors.New("more than one root") return "", errors.New("more than one root")
} }
return roots[0], nil return roots[0], nil
} }
+42 -24
View File
@@ -20,23 +20,17 @@
// IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN // IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN
// CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. // CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
package dag package dag_test
import ( import (
"reflect" "reflect"
"testing" "testing"
"github.com/stretchr/testify/assert" "github.com/stretchr/testify/assert"
) "golang.org/x/exp/slices"
func index(s []string, v string) int { "github.com/SecurityBrewery/catalyst/dag"
for i, s := range s { )
if s == v {
return i
}
}
return -1
}
type Edge struct { type Edge struct {
From string From string
@@ -44,13 +38,17 @@ type Edge struct {
} }
func TestDuplicatedNode(t *testing.T) { func TestDuplicatedNode(t *testing.T) {
graph := NewGraph() t.Parallel()
graph := dag.NewGraph()
assert.NoError(t, graph.AddNode("a")) assert.NoError(t, graph.AddNode("a"))
assert.Error(t, graph.AddNode("a")) assert.Error(t, graph.AddNode("a"))
} }
func TestWikipedia(t *testing.T) { func TestWikipedia(t *testing.T) {
graph := NewGraph() t.Parallel()
graph := dag.NewGraph()
assert.NoError(t, graph.AddNodes("2", "3", "5", "7", "8", "9", "10", "11")) assert.NoError(t, graph.AddNodes("2", "3", "5", "7", "8", "9", "10", "11"))
edges := []Edge{ edges := []Edge{
@@ -79,27 +77,30 @@ func TestWikipedia(t *testing.T) {
} }
for _, e := range edges { for _, e := range edges {
if i, j := index(result, e.From), index(result, e.To); i > j { if i, j := slices.Index(result, e.From), slices.Index(result, e.To); i > j {
t.Errorf("dependency failed: not satisfy %v(%v) > %v(%v)", e.From, i, e.To, j) t.Errorf("dependency failed: not satisfy %v(%v) > %v(%v)", e.From, i, e.To, j)
} }
} }
} }
func TestCycle(t *testing.T) { func TestCycle(t *testing.T) {
graph := NewGraph() t.Parallel()
graph := dag.NewGraph()
assert.NoError(t, graph.AddNodes("1", "2", "3")) assert.NoError(t, graph.AddNodes("1", "2", "3"))
assert.NoError(t, graph.AddEdge("1", "2")) assert.NoError(t, graph.AddEdge("1", "2"))
assert.NoError(t, graph.AddEdge("2", "3")) assert.NoError(t, graph.AddEdge("2", "3"))
assert.NoError(t, graph.AddEdge("3", "1")) assert.NoError(t, graph.AddEdge("3", "1"))
_, err := graph.Toposort() if _, err := graph.Toposort(); err == nil {
if err == nil {
t.Errorf("closed path not detected in closed pathed graph") t.Errorf("closed path not detected in closed pathed graph")
} }
} }
func TestGraph_GetParents(t *testing.T) { func TestGraph_GetParents(t *testing.T) {
t.Parallel()
type fields struct { type fields struct {
nodes []string nodes []string
edges map[string]string edges map[string]string
@@ -117,8 +118,11 @@ func TestGraph_GetParents(t *testing.T) {
{"parents 3", fields{nodes: []string{"1", "2", "3"}, edges: map[string]string{"1": "3", "2": "3"}}, args{id: "3"}, []string{"1", "2"}}, {"parents 3", fields{nodes: []string{"1", "2", "3"}, edges: map[string]string{"1": "3", "2": "3"}}, args{id: "3"}, []string{"1", "2"}},
} }
for _, tt := range tests { for _, tt := range tests {
tt := tt
t.Run(tt.name, func(t *testing.T) { t.Run(tt.name, func(t *testing.T) {
g := NewGraph() t.Parallel()
g := dag.NewGraph()
for _, node := range tt.fields.nodes { for _, node := range tt.fields.nodes {
assert.NoError(t, g.AddNode(node)) assert.NoError(t, g.AddNode(node))
} }
@@ -134,7 +138,9 @@ func TestGraph_GetParents(t *testing.T) {
} }
func TestDAG_AddNode(t *testing.T) { func TestDAG_AddNode(t *testing.T) {
dag := NewGraph() t.Parallel()
dag := dag.NewGraph()
v := "1" v := "1"
assert.NoError(t, dag.AddNode(v)) assert.NoError(t, dag.AddNode(v))
@@ -143,7 +149,9 @@ func TestDAG_AddNode(t *testing.T) {
} }
func TestDAG_AddEdge(t *testing.T) { func TestDAG_AddEdge(t *testing.T) {
dag := NewGraph() t.Parallel()
dag := dag.NewGraph()
assert.NoError(t, dag.AddNode("0")) assert.NoError(t, dag.AddNode("0"))
assert.NoError(t, dag.AddNode("1")) assert.NoError(t, dag.AddNode("1"))
assert.NoError(t, dag.AddNode("2")) assert.NoError(t, dag.AddNode("2"))
@@ -162,7 +170,9 @@ func TestDAG_AddEdge(t *testing.T) {
} }
func TestDAG_GetParents(t *testing.T) { func TestDAG_GetParents(t *testing.T) {
dag := NewGraph() t.Parallel()
dag := dag.NewGraph()
assert.NoError(t, dag.AddNode("1")) assert.NoError(t, dag.AddNode("1"))
assert.NoError(t, dag.AddNode("2")) assert.NoError(t, dag.AddNode("2"))
assert.NoError(t, dag.AddNode("3")) assert.NoError(t, dag.AddNode("3"))
@@ -176,7 +186,9 @@ func TestDAG_GetParents(t *testing.T) {
} }
func TestDAG_GetDescendants(t *testing.T) { func TestDAG_GetDescendants(t *testing.T) {
dag := NewGraph() t.Parallel()
dag := dag.NewGraph()
assert.NoError(t, dag.AddNode("1")) assert.NoError(t, dag.AddNode("1"))
assert.NoError(t, dag.AddNode("2")) assert.NoError(t, dag.AddNode("2"))
assert.NoError(t, dag.AddNode("3")) assert.NoError(t, dag.AddNode("3"))
@@ -188,7 +200,9 @@ func TestDAG_GetDescendants(t *testing.T) {
} }
func TestDAG_Topsort(t *testing.T) { func TestDAG_Topsort(t *testing.T) {
dag := NewGraph() t.Parallel()
dag := dag.NewGraph()
assert.NoError(t, dag.AddNode("1")) assert.NoError(t, dag.AddNode("1"))
assert.NoError(t, dag.AddNode("2")) assert.NoError(t, dag.AddNode("2"))
assert.NoError(t, dag.AddNode("3")) assert.NoError(t, dag.AddNode("3"))
@@ -203,7 +217,9 @@ func TestDAG_Topsort(t *testing.T) {
} }
func TestDAG_TopsortStable(t *testing.T) { func TestDAG_TopsortStable(t *testing.T) {
dag := NewGraph() t.Parallel()
dag := dag.NewGraph()
assert.NoError(t, dag.AddNode("1")) assert.NoError(t, dag.AddNode("1"))
assert.NoError(t, dag.AddNode("2")) assert.NoError(t, dag.AddNode("2"))
assert.NoError(t, dag.AddNode("3")) assert.NoError(t, dag.AddNode("3"))
@@ -216,7 +232,9 @@ func TestDAG_TopsortStable(t *testing.T) {
} }
func TestDAG_TopsortStable2(t *testing.T) { func TestDAG_TopsortStable2(t *testing.T) {
dag := NewGraph() t.Parallel()
dag := dag.NewGraph()
assert.NoError(t, dag.AddNodes("block-ioc", "block-iocs", "block-sender", "board", "fetch-iocs", "escalate", "extract-iocs", "mail-available", "search-email-gateway")) assert.NoError(t, dag.AddNodes("block-ioc", "block-iocs", "block-sender", "board", "fetch-iocs", "escalate", "extract-iocs", "mail-available", "search-email-gateway"))
assert.NoError(t, dag.AddEdge("block-iocs", "block-ioc")) assert.NoError(t, dag.AddEdge("block-iocs", "block-ioc"))
+6 -4
View File
@@ -23,7 +23,7 @@ func (db *Database) ArtifactGet(ctx context.Context, id int64, name string) (*mo
FOR a in NOT_NULL(d.artifacts, []) FOR a in NOT_NULL(d.artifacts, [])
FILTER a.name == @name FILTER a.name == @name
RETURN a` RETURN a`
cursor, _, err := db.Query(ctx, query, mergeMaps(ticketFilterVars, map[string]interface{}{ cursor, _, err := db.Query(ctx, query, mergeMaps(ticketFilterVars, map[string]any{
"@collection": TicketCollectionName, "@collection": TicketCollectionName,
"ID": fmt.Sprint(id), "ID": fmt.Sprint(id),
"name": name, "name": name,
@@ -55,7 +55,8 @@ func (db *Database) ArtifactUpdate(ctx context.Context, id int64, name string, a
LET newartifacts = APPEND(REMOVE_VALUE(d.artifacts, a), @artifact) LET newartifacts = APPEND(REMOVE_VALUE(d.artifacts, a), @artifact)
UPDATE d WITH { "artifacts": newartifacts } IN @@collection UPDATE d WITH { "artifacts": newartifacts } IN @@collection
RETURN NEW` RETURN NEW`
return db.ticketGetQuery(ctx, id, query, mergeMaps(map[string]interface{}{
return db.ticketGetQuery(ctx, id, query, mergeMaps(map[string]any{
"@collection": TicketCollectionName, "@collection": TicketCollectionName,
"ID": id, "ID": id,
"name": name, "name": name,
@@ -69,7 +70,7 @@ func (db *Database) ArtifactUpdate(ctx context.Context, id int64, name string, a
} }
func (db *Database) EnrichArtifact(ctx context.Context, id int64, name string, enrichmentForm *model.EnrichmentForm) (*model.TicketWithTickets, error) { func (db *Database) EnrichArtifact(ctx context.Context, id int64, name string, enrichmentForm *model.EnrichmentForm) (*model.TicketWithTickets, error) {
enrichment := model.Enrichment{time.Now().UTC(), enrichmentForm.Data, enrichmentForm.Name} enrichment := model.Enrichment{Created: time.Now().UTC(), Data: enrichmentForm.Data, Name: enrichmentForm.Name}
ticketFilterQuery, ticketFilterVars, err := db.Hooks.TicketWriteFilter(ctx) ticketFilterQuery, ticketFilterVars, err := db.Hooks.TicketWriteFilter(ctx)
if err != nil { if err != nil {
@@ -85,7 +86,8 @@ func (db *Database) EnrichArtifact(ctx context.Context, id int64, name string, e
LET newartifacts = APPEND(REMOVE_VALUE(d.artifacts, a), MERGE(a, { "enrichments": newenrichments })) LET newartifacts = APPEND(REMOVE_VALUE(d.artifacts, a), MERGE(a, { "enrichments": newenrichments }))
UPDATE d WITH { "artifacts": newartifacts } IN @@collection UPDATE d WITH { "artifacts": newartifacts } IN @@collection
RETURN NEW` RETURN NEW`
return db.ticketGetQuery(ctx, id, query, mergeMaps(map[string]interface{}{
return db.ticketGetQuery(ctx, id, query, mergeMaps(map[string]any{
"@collection": TicketCollectionName, "@collection": TicketCollectionName,
"ID": id, "ID": id,
"name": name, "name": name,
+3 -2
View File
@@ -10,7 +10,7 @@ import (
"github.com/SecurityBrewery/catalyst/generated/model" "github.com/SecurityBrewery/catalyst/generated/model"
) )
func toAutomation(doc *model.AutomationForm) interface{} { func toAutomation(doc *model.AutomationForm) *model.Automation {
return &model.Automation{ return &model.Automation{
Image: doc.Image, Image: doc.Image,
Script: doc.Script, Script: doc.Script,
@@ -72,12 +72,13 @@ func (db *Database) AutomationUpdate(ctx context.Context, id string, automation
func (db *Database) AutomationDelete(ctx context.Context, id string) error { func (db *Database) AutomationDelete(ctx context.Context, id string) error {
_, err := db.automationCollection.RemoveDocument(ctx, id) _, err := db.automationCollection.RemoveDocument(ctx, id)
return err return err
} }
func (db *Database) AutomationList(ctx context.Context) ([]*model.AutomationResponse, error) { func (db *Database) AutomationList(ctx context.Context) ([]*model.AutomationResponse, error) {
query := "FOR d IN @@collection SORT d._key ASC RETURN UNSET(d, 'script')" query := "FOR d IN @@collection SORT d._key ASC RETURN UNSET(d, 'script')"
cursor, _, err := db.Query(ctx, query, map[string]interface{}{"@collection": AutomationCollectionName}, busdb.ReadOperation) cursor, _, err := db.Query(ctx, query, map[string]any{"@collection": AutomationCollectionName}, busdb.ReadOperation)
if err != nil { if err != nil {
return nil, err return nil, err
} }
+24 -18
View File
@@ -40,10 +40,12 @@ type Operation struct {
Ids []driver.DocumentID Ids []driver.DocumentID
} }
var CreateOperation = &Operation{Type: bus.DatabaseEntryCreated} var (
var ReadOperation = &Operation{Type: bus.DatabaseEntryRead} CreateOperation = &Operation{Type: bus.DatabaseEntryCreated}
ReadOperation = &Operation{Type: bus.DatabaseEntryRead}
)
func (db BusDatabase) Query(ctx context.Context, query string, vars map[string]interface{}, operation *Operation) (cur driver.Cursor, logs *model.LogEntry, err error) { func (db *BusDatabase) Query(ctx context.Context, query string, vars map[string]any, operation *Operation) (cur driver.Cursor, logs *model.LogEntry, err error) {
defer func() { err = toHTTPErr(err) }() defer func() { err = toHTTPErr(err) }()
cur, err = db.internal.Query(ctx, query, vars) cur, err = db.internal.Query(ctx, query, vars)
@@ -61,31 +63,31 @@ func (db BusDatabase) Query(ctx context.Context, query string, vars map[string]i
return cur, logs, err return cur, logs, err
} }
func (db BusDatabase) Remove(ctx context.Context) (err error) { func (db *BusDatabase) Remove(ctx context.Context) (err error) {
defer func() { err = toHTTPErr(err) }() defer func() { err = toHTTPErr(err) }()
return db.internal.Remove(ctx) return db.internal.Remove(ctx)
} }
func (db BusDatabase) Collection(ctx context.Context, name string) (col driver.Collection, err error) { func (db *BusDatabase) Collection(ctx context.Context, name string) (col driver.Collection, err error) {
defer func() { err = toHTTPErr(err) }() defer func() { err = toHTTPErr(err) }()
return db.internal.Collection(ctx, name) return db.internal.Collection(ctx, name)
} }
type Collection struct { type Collection[T any] struct {
internal driver.Collection internal driver.Collection
db *BusDatabase db *BusDatabase
} }
func NewCollection(internal driver.Collection, db *BusDatabase) *Collection { func NewCollection[T any](internal driver.Collection, db *BusDatabase) *Collection[T] {
return &Collection{internal: internal, db: db} return &Collection[T]{internal: internal, db: db}
} }
func (c Collection) CreateDocument(ctx, newctx context.Context, key string, document interface{}) (meta driver.DocumentMeta, err error) { func (c *Collection[T]) CreateDocument(ctx, newctx context.Context, key string, document *T) (meta driver.DocumentMeta, err error) {
defer func() { err = toHTTPErr(err) }() defer func() { err = toHTTPErr(err) }()
meta, err = c.internal.CreateDocument(newctx, &Keyed{Key: key, Doc: document}) meta, err = c.internal.CreateDocument(newctx, &Keyed[T]{Key: key, Doc: document})
if err != nil { if err != nil {
return meta, err return meta, err
} }
@@ -94,10 +96,11 @@ func (c Collection) CreateDocument(ctx, newctx context.Context, key string, docu
if err != nil { if err != nil {
return meta, err return meta, err
} }
return meta, nil return meta, nil
} }
func (c Collection) CreateEdge(ctx, newctx context.Context, edge *driver.EdgeDocument) (meta driver.DocumentMeta, err error) { func (c *Collection[T]) CreateEdge(ctx, newctx context.Context, edge *driver.EdgeDocument) (meta driver.DocumentMeta, err error) {
defer func() { err = toHTTPErr(err) }() defer func() { err = toHTTPErr(err) }()
meta, err = c.internal.CreateDocument(newctx, edge) meta, err = c.internal.CreateDocument(newctx, edge)
@@ -109,10 +112,11 @@ func (c Collection) CreateEdge(ctx, newctx context.Context, edge *driver.EdgeDoc
if err != nil { if err != nil {
return meta, err return meta, err
} }
return meta, nil return meta, nil
} }
func (c Collection) CreateEdges(ctx context.Context, edges []*driver.EdgeDocument) (meta driver.DocumentMetaSlice, err error) { func (c *Collection[T]) CreateEdges(ctx context.Context, edges []*driver.EdgeDocument) (meta driver.DocumentMetaSlice, err error) {
defer func() { err = toHTTPErr(err) }() defer func() { err = toHTTPErr(err) }()
metas, errs, err := c.internal.CreateDocuments(ctx, edges) metas, errs, err := c.internal.CreateDocuments(ctx, edges)
@@ -136,13 +140,13 @@ func (c Collection) CreateEdges(ctx context.Context, edges []*driver.EdgeDocumen
return metas, nil return metas, nil
} }
func (c Collection) DocumentExists(ctx context.Context, id string) (exists bool, err error) { func (c *Collection[T]) DocumentExists(ctx context.Context, id string) (exists bool, err error) {
defer func() { err = toHTTPErr(err) }() defer func() { err = toHTTPErr(err) }()
return c.internal.DocumentExists(ctx, id) return c.internal.DocumentExists(ctx, id)
} }
func (c Collection) ReadDocument(ctx context.Context, key string, result interface{}) (meta driver.DocumentMeta, err error) { func (c *Collection[T]) ReadDocument(ctx context.Context, key string, result *T) (meta driver.DocumentMeta, err error) {
defer func() { err = toHTTPErr(err) }() defer func() { err = toHTTPErr(err) }()
meta, err = c.internal.ReadDocument(ctx, key, result) meta, err = c.internal.ReadDocument(ctx, key, result)
@@ -150,7 +154,7 @@ func (c Collection) ReadDocument(ctx context.Context, key string, result interfa
return return
} }
func (c Collection) UpdateDocument(ctx context.Context, key string, update interface{}) (meta driver.DocumentMeta, err error) { func (c *Collection[T]) UpdateDocument(ctx context.Context, key string, update any) (meta driver.DocumentMeta, err error) {
defer func() { err = toHTTPErr(err) }() defer func() { err = toHTTPErr(err) }()
meta, err = c.internal.UpdateDocument(ctx, key, update) meta, err = c.internal.UpdateDocument(ctx, key, update)
@@ -161,7 +165,7 @@ func (c Collection) UpdateDocument(ctx context.Context, key string, update inter
return meta, c.db.bus.PublishDatabaseUpdate([]driver.DocumentID{meta.ID}, bus.DatabaseEntryUpdated) return meta, c.db.bus.PublishDatabaseUpdate([]driver.DocumentID{meta.ID}, bus.DatabaseEntryUpdated)
} }
func (c Collection) ReplaceDocument(ctx context.Context, key string, document interface{}) (meta driver.DocumentMeta, err error) { func (c *Collection[T]) ReplaceDocument(ctx context.Context, key string, document *T) (meta driver.DocumentMeta, err error) {
defer func() { err = toHTTPErr(err) }() defer func() { err = toHTTPErr(err) }()
meta, err = c.internal.ReplaceDocument(ctx, key, document) meta, err = c.internal.ReplaceDocument(ctx, key, document)
@@ -172,13 +176,13 @@ func (c Collection) ReplaceDocument(ctx context.Context, key string, document in
return meta, c.db.bus.PublishDatabaseUpdate([]driver.DocumentID{meta.ID}, bus.DatabaseEntryUpdated) return meta, c.db.bus.PublishDatabaseUpdate([]driver.DocumentID{meta.ID}, bus.DatabaseEntryUpdated)
} }
func (c Collection) RemoveDocument(ctx context.Context, formatInt string) (meta driver.DocumentMeta, err error) { func (c *Collection[T]) RemoveDocument(ctx context.Context, formatInt string) (meta driver.DocumentMeta, err error) {
defer func() { err = toHTTPErr(err) }() defer func() { err = toHTTPErr(err) }()
return c.internal.RemoveDocument(ctx, formatInt) return c.internal.RemoveDocument(ctx, formatInt)
} }
func (c Collection) Truncate(ctx context.Context) (err error) { func (c *Collection[T]) Truncate(ctx context.Context) (err error) {
defer func() { err = toHTTPErr(err) }() defer func() { err = toHTTPErr(err) }()
return c.internal.Truncate(ctx) return c.internal.Truncate(ctx)
@@ -190,7 +194,9 @@ func toHTTPErr(err error) error {
if errors.As(err, &ae) { if errors.As(err, &ae) {
return &api.HTTPError{Status: ae.Code, Internal: err} return &api.HTTPError{Status: ae.Code, Internal: err}
} }
return err return err
} }
return nil return nil
} }
+6 -2
View File
@@ -8,9 +8,11 @@ import (
"github.com/SecurityBrewery/catalyst/role" "github.com/SecurityBrewery/catalyst/role"
) )
type contextKey string
const ( const (
userContextKey = "user" userContextKey contextKey = "user"
groupContextKey = "groups" groupContextKey contextKey = "groups"
) )
func SetContext(r *http.Request, user *model.UserResponse) *http.Request { func SetContext(r *http.Request, user *model.UserResponse) *http.Request {
@@ -25,10 +27,12 @@ func SetGroupContext(r *http.Request, groups []string) *http.Request {
func UserContext(ctx context.Context, user *model.UserResponse) context.Context { func UserContext(ctx context.Context, user *model.UserResponse) context.Context {
user.Roles = role.Strings(role.Explodes(user.Roles)) user.Roles = role.Strings(role.Explodes(user.Roles))
return context.WithValue(ctx, userContextKey, user) return context.WithValue(ctx, userContextKey, user)
} }
func UserFromContext(ctx context.Context) (*model.UserResponse, bool) { func UserFromContext(ctx context.Context) (*model.UserResponse, bool) {
u, ok := ctx.Value(userContextKey).(*model.UserResponse) u, ok := ctx.Value(userContextKey).(*model.UserResponse)
return u, ok return u, ok
} }
+4 -4
View File
@@ -2,18 +2,18 @@ package busdb
import "encoding/json" import "encoding/json"
type Keyed struct { type Keyed[T any] struct {
Key string Key string
Doc interface{} Doc *T
} }
func (p Keyed) MarshalJSON() ([]byte, error) { func (p *Keyed[T]) MarshalJSON() ([]byte, error) {
b, err := json.Marshal(p.Doc) b, err := json.Marshal(p.Doc)
if err != nil { if err != nil {
panic(err) panic(err)
} }
var m map[string]interface{} var m map[string]any
err = json.Unmarshal(b, &m) err = json.Unmarshal(b, &m)
if err != nil { if err != nil {
panic(err) panic(err)
+8 -2
View File
@@ -3,6 +3,7 @@ package busdb
import ( import (
"context" "context"
"errors" "errors"
"log"
"strings" "strings"
"github.com/arangodb/go-driver" "github.com/arangodb/go-driver"
@@ -45,7 +46,12 @@ func (db *BusDatabase) LogBatchCreate(ctx context.Context, logentries []*model.L
} }
} }
if ids != nil { if ids != nil {
go db.bus.PublishDatabaseUpdate(ids, bus.DatabaseEntryCreated) go func() {
err := db.bus.PublishDatabaseUpdate(ids, bus.DatabaseEntryCreated)
if err != nil {
log.Println(err)
}
}()
} }
_, errs, err := db.logCollection.CreateDocuments(ctx, logentries) _, errs, err := db.logCollection.CreateDocuments(ctx, logentries)
@@ -62,7 +68,7 @@ func (db *BusDatabase) LogBatchCreate(ctx context.Context, logentries []*model.L
func (db *BusDatabase) LogList(ctx context.Context, reference string) ([]*model.LogEntry, error) { func (db *BusDatabase) LogList(ctx context.Context, reference string) ([]*model.LogEntry, error) {
query := "FOR d IN @@collection FILTER d.reference == @reference SORT d.created DESC RETURN d" query := "FOR d IN @@collection FILTER d.reference == @reference SORT d.created DESC RETURN d"
cursor, err := db.internal.Query(ctx, query, map[string]interface{}{ cursor, err := db.internal.Query(ctx, query, map[string]any{
"@collection": LogCollectionName, "@collection": LogCollectionName,
"reference": reference, "reference": reference,
}) })
+5 -3
View File
@@ -72,12 +72,13 @@ func (db *Database) DashboardUpdate(ctx context.Context, id string, dashboard *m
func (db *Database) DashboardDelete(ctx context.Context, id string) error { func (db *Database) DashboardDelete(ctx context.Context, id string) error {
_, err := db.dashboardCollection.RemoveDocument(ctx, id) _, err := db.dashboardCollection.RemoveDocument(ctx, id)
return err return err
} }
func (db *Database) DashboardList(ctx context.Context) ([]*model.DashboardResponse, error) { func (db *Database) DashboardList(ctx context.Context) ([]*model.DashboardResponse, error) {
query := "FOR d IN @@collection RETURN d" query := "FOR d IN @@collection RETURN d"
cursor, _, err := db.Query(ctx, query, map[string]interface{}{"@collection": DashboardCollectionName}, busdb.ReadOperation) cursor, _, err := db.Query(ctx, query, map[string]any{"@collection": DashboardCollectionName}, busdb.ReadOperation)
if err != nil { if err != nil {
return nil, err return nil, err
} }
@@ -103,15 +104,16 @@ func (db *Database) parseWidgets(dashboard *model.Dashboard) error {
_, err := parser.Parse(widget.Aggregation) _, err := parser.Parse(widget.Aggregation)
if err != nil { if err != nil {
return fmt.Errorf("invalid aggregation query (%s): syntax error\n", widget.Aggregation) return fmt.Errorf("invalid aggregation query (%s): syntax error", widget.Aggregation)
} }
if widget.Filter != nil { if widget.Filter != nil {
_, err := parser.Parse(*widget.Filter) _, err := parser.Parse(*widget.Filter)
if err != nil { if err != nil {
return fmt.Errorf("invalid filter query (%s): syntax error\n", *widget.Filter) return fmt.Errorf("invalid filter query (%s): syntax error", *widget.Filter)
} }
} }
} }
return nil return nil
} }
+34 -33
View File
@@ -11,6 +11,7 @@ import (
"github.com/SecurityBrewery/catalyst/bus" "github.com/SecurityBrewery/catalyst/bus"
"github.com/SecurityBrewery/catalyst/database/busdb" "github.com/SecurityBrewery/catalyst/database/busdb"
"github.com/SecurityBrewery/catalyst/database/migrations" "github.com/SecurityBrewery/catalyst/database/migrations"
"github.com/SecurityBrewery/catalyst/generated/model"
"github.com/SecurityBrewery/catalyst/hooks" "github.com/SecurityBrewery/catalyst/hooks"
"github.com/SecurityBrewery/catalyst/index" "github.com/SecurityBrewery/catalyst/index"
) )
@@ -38,18 +39,18 @@ type Database struct {
bus *bus.Bus bus *bus.Bus
Hooks *hooks.Hooks Hooks *hooks.Hooks
templateCollection *busdb.Collection templateCollection *busdb.Collection[model.TicketTemplate]
ticketCollection *busdb.Collection ticketCollection *busdb.Collection[model.Ticket]
playbookCollection *busdb.Collection playbookCollection *busdb.Collection[model.PlaybookTemplate]
automationCollection *busdb.Collection automationCollection *busdb.Collection[model.Automation]
userdataCollection *busdb.Collection userdataCollection *busdb.Collection[model.UserData]
userCollection *busdb.Collection userCollection *busdb.Collection[model.User]
tickettypeCollection *busdb.Collection tickettypeCollection *busdb.Collection[model.TicketType]
jobCollection *busdb.Collection jobCollection *busdb.Collection[model.Job]
settingsCollection *busdb.Collection settingsCollection *busdb.Collection[model.Settings]
dashboardCollection *busdb.Collection dashboardCollection *busdb.Collection[model.Dashboard]
relatedCollection *busdb.Collection relatedCollection *busdb.Collection[driver.EdgeDocument]
// containsCollection *busdb.Collection // containsCollection *busdb.Collection
} }
@@ -145,17 +146,17 @@ func New(ctx context.Context, index *index.Index, bus *bus.Bus, hooks *hooks.Hoo
bus: bus, bus: bus,
Index: index, Index: index,
Hooks: hooks, Hooks: hooks,
templateCollection: busdb.NewCollection(templateCollection, hookedDB), templateCollection: busdb.NewCollection[model.TicketTemplate](templateCollection, hookedDB),
ticketCollection: busdb.NewCollection(ticketCollection, hookedDB), ticketCollection: busdb.NewCollection[model.Ticket](ticketCollection, hookedDB),
playbookCollection: busdb.NewCollection(playbookCollection, hookedDB), playbookCollection: busdb.NewCollection[model.PlaybookTemplate](playbookCollection, hookedDB),
automationCollection: busdb.NewCollection(automationCollection, hookedDB), automationCollection: busdb.NewCollection[model.Automation](automationCollection, hookedDB),
relatedCollection: busdb.NewCollection(relatedCollection, hookedDB), userdataCollection: busdb.NewCollection[model.UserData](userdataCollection, hookedDB),
userdataCollection: busdb.NewCollection(userdataCollection, hookedDB), userCollection: busdb.NewCollection[model.User](userCollection, hookedDB),
userCollection: busdb.NewCollection(userCollection, hookedDB), tickettypeCollection: busdb.NewCollection[model.TicketType](tickettypeCollection, hookedDB),
tickettypeCollection: busdb.NewCollection(tickettypeCollection, hookedDB), jobCollection: busdb.NewCollection[model.Job](jobCollection, hookedDB),
jobCollection: busdb.NewCollection(jobCollection, hookedDB), settingsCollection: busdb.NewCollection[model.Settings](settingsCollection, hookedDB),
settingsCollection: busdb.NewCollection(settingsCollection, hookedDB), dashboardCollection: busdb.NewCollection[model.Dashboard](dashboardCollection, hookedDB),
dashboardCollection: busdb.NewCollection(dashboardCollection, hookedDB), relatedCollection: busdb.NewCollection[driver.EdgeDocument](relatedCollection, hookedDB),
} }
return db, nil return db, nil
@@ -194,16 +195,16 @@ func SetupDB(ctx context.Context, client driver.Client, dbName string) (driver.D
} }
func (db *Database) Truncate(ctx context.Context) { func (db *Database) Truncate(ctx context.Context) {
db.templateCollection.Truncate(ctx) _ = db.templateCollection.Truncate(ctx)
db.ticketCollection.Truncate(ctx) _ = db.ticketCollection.Truncate(ctx)
db.playbookCollection.Truncate(ctx) _ = db.playbookCollection.Truncate(ctx)
db.automationCollection.Truncate(ctx) _ = db.automationCollection.Truncate(ctx)
db.userdataCollection.Truncate(ctx) _ = db.userdataCollection.Truncate(ctx)
db.userCollection.Truncate(ctx) _ = db.userCollection.Truncate(ctx)
db.tickettypeCollection.Truncate(ctx) _ = db.tickettypeCollection.Truncate(ctx)
db.jobCollection.Truncate(ctx) _ = db.jobCollection.Truncate(ctx)
db.relatedCollection.Truncate(ctx) _ = db.relatedCollection.Truncate(ctx)
db.settingsCollection.Truncate(ctx) _ = db.settingsCollection.Truncate(ctx)
db.dashboardCollection.Truncate(ctx) _ = db.dashboardCollection.Truncate(ctx)
// db.containsCollection.Truncate(ctx) // db.containsCollection.Truncate(ctx)
} }
+14 -12
View File
@@ -38,7 +38,7 @@ func (db *Database) toJobResponse(ctx context.Context, key string, doc *model.Jo
inspect, err := cli.ContainerInspect(ctx, key) inspect, err := cli.ContainerInspect(ctx, key)
if err != nil || inspect.State == nil { if err != nil || inspect.State == nil {
if update { if update {
db.JobUpdate(ctx, key, &model.JobUpdate{ _, _ = db.JobUpdate(ctx, key, &model.JobUpdate{
Status: doc.Status, Status: doc.Status,
Running: false, Running: false,
}) })
@@ -46,7 +46,7 @@ func (db *Database) toJobResponse(ctx context.Context, key string, doc *model.Jo
} else if doc.Status != inspect.State.Status { } else if doc.Status != inspect.State.Status {
status = inspect.State.Status status = inspect.State.Status
if update { if update {
db.JobUpdate(ctx, key, &model.JobUpdate{ _, _ = db.JobUpdate(ctx, key, &model.JobUpdate{
Status: status, Status: status,
Running: doc.Running, Running: doc.Running,
}) })
@@ -107,7 +107,7 @@ func (db *Database) JobUpdate(ctx context.Context, id string, job *model.JobUpda
func (db *Database) JobLogAppend(ctx context.Context, id string, logLine string) error { func (db *Database) JobLogAppend(ctx context.Context, id string, logLine string) error {
query := `LET d = DOCUMENT(@@collection, @ID) query := `LET d = DOCUMENT(@@collection, @ID)
UPDATE d WITH { "log": CONCAT(NOT_NULL(d.log, ""), @logline) } IN @@collection` UPDATE d WITH { "log": CONCAT(NOT_NULL(d.log, ""), @logline) } IN @@collection`
cur, _, err := db.Query(ctx, query, map[string]interface{}{ cur, _, err := db.Query(ctx, query, map[string]any{
"@collection": JobCollectionName, "@collection": JobCollectionName,
"ID": id, "ID": id,
"logline": logLine, "logline": logLine,
@@ -125,10 +125,10 @@ func (db *Database) JobLogAppend(ctx context.Context, id string, logLine string)
return nil return nil
} }
func (db *Database) JobComplete(ctx context.Context, id string, out interface{}) error { func (db *Database) JobComplete(ctx context.Context, id string, out any) error {
query := `LET d = DOCUMENT(@@collection, @ID) query := `LET d = DOCUMENT(@@collection, @ID)
UPDATE d WITH { "output": @out, "status": "completed", "running": false } IN @@collection` UPDATE d WITH { "output": @out, "status": "completed", "running": false } IN @@collection`
cur, _, err := db.Query(ctx, query, map[string]interface{}{ cur, _, err := db.Query(ctx, query, map[string]any{
"@collection": JobCollectionName, "@collection": JobCollectionName,
"ID": id, "ID": id,
"out": out, "out": out,
@@ -148,12 +148,13 @@ func (db *Database) JobComplete(ctx context.Context, id string, out interface{})
func (db *Database) JobDelete(ctx context.Context, id string) error { func (db *Database) JobDelete(ctx context.Context, id string) error {
_, err := db.jobCollection.RemoveDocument(ctx, id) _, err := db.jobCollection.RemoveDocument(ctx, id)
return err return err
} }
func (db *Database) JobList(ctx context.Context) ([]*model.JobResponse, error) { func (db *Database) JobList(ctx context.Context) ([]*model.JobResponse, error) {
query := "FOR d IN @@collection RETURN d" query := "FOR d IN @@collection RETURN d"
cursor, _, err := db.Query(ctx, query, map[string]interface{}{"@collection": JobCollectionName}, busdb.ReadOperation) cursor, _, err := db.Query(ctx, query, map[string]any{"@collection": JobCollectionName}, busdb.ReadOperation)
if err != nil { if err != nil {
return nil, err return nil, err
} }
@@ -188,24 +189,24 @@ func publishJobMapping(id, automation string, contextStructs *model.Context, ori
return publishJob(id, automation, contextStructs, origin, msg, db) return publishJob(id, automation, contextStructs, origin, msg, db)
} }
func publishJob(id, automation string, contextStructs *model.Context, origin *model.Origin, payload map[string]interface{}, db *Database) error { func publishJob(id, automation string, contextStructs *model.Context, origin *model.Origin, payload map[string]any, db *Database) error {
return db.bus.PublishJob(id, automation, payload, contextStructs, origin) return db.bus.PublishJob(id, automation, payload, contextStructs, origin)
} }
func generatePayload(msgMapping map[string]string, contextStructs *model.Context) (map[string]interface{}, error) { func generatePayload(msgMapping map[string]string, contextStructs *model.Context) (map[string]any, error) {
contextJson, err := json.Marshal(contextStructs) contextJSON, err := json.Marshal(contextStructs)
if err != nil { if err != nil {
return nil, err return nil, err
} }
automationContext := map[string]interface{}{} automationContext := map[string]any{}
err = json.Unmarshal(contextJson, &automationContext) err = json.Unmarshal(contextJSON, &automationContext)
if err != nil { if err != nil {
return nil, err return nil, err
} }
parser := caql.Parser{} parser := caql.Parser{}
msg := map[string]interface{}{} msg := map[string]any{}
for arg, expr := range msgMapping { for arg, expr := range msgMapping {
tree, err := parser.Parse(expr) tree, err := parser.Parse(expr)
if err != nil { if err != nil {
@@ -218,5 +219,6 @@ func generatePayload(msgMapping map[string]string, contextStructs *model.Context
} }
msg[arg] = v msg[arg] = v
} }
return msg, nil return msg, nil
} }
+30 -24
View File
@@ -32,34 +32,34 @@ func generateMigrations() ([]Migration, error) {
&createGraph{ID: "create-ticket-graph", Name: "Graph", EdgeDefinitions: []driver.EdgeDefinition{{Collection: "related", From: []string{"tickets"}, To: []string{"tickets"}}}}, &createGraph{ID: "create-ticket-graph", Name: "Graph", EdgeDefinitions: []driver.EdgeDefinition{{Collection: "related", From: []string{"tickets"}, To: []string{"tickets"}}}},
&createDocument{ID: "create-template-default", Collection: "templates", Document: &busdb.Keyed{Key: "default", Doc: model.TicketTemplate{Schema: DefaultTemplateSchema, Name: "Default"}}}, &createDocument[busdb.Keyed[model.TicketTemplate]]{ID: "create-template-default", Collection: "templates", Document: &busdb.Keyed[model.TicketTemplate]{Key: "default", Doc: &model.TicketTemplate{Schema: DefaultTemplateSchema, Name: "Default"}}},
&createDocument{ID: "create-automation-vt.hash", Collection: "automations", Document: &busdb.Keyed{Key: "vt.hash", Doc: model.Automation{Image: "docker.io/python:3", Script: VTHashAutomation}}}, &createDocument[busdb.Keyed[model.Automation]]{ID: "create-automation-vt.hash", Collection: "automations", Document: &busdb.Keyed[model.Automation]{Key: "vt.hash", Doc: &model.Automation{Image: "docker.io/python:3", Script: VTHashAutomation}}},
&createDocument{ID: "create-automation-comment", Collection: "automations", Document: &busdb.Keyed{Key: "comment", Doc: model.Automation{Image: "docker.io/python:3", Script: CommentAutomation}}}, &createDocument[busdb.Keyed[model.Automation]]{ID: "create-automation-comment", Collection: "automations", Document: &busdb.Keyed[model.Automation]{Key: "comment", Doc: &model.Automation{Image: "docker.io/python:3", Script: CommentAutomation}}},
&createDocument{ID: "create-automation-hash.sha1", Collection: "automations", Document: &busdb.Keyed{Key: "hash.sha1", Doc: model.Automation{Image: "docker.io/python:3", Script: SHA1HashAutomation}}}, &createDocument[busdb.Keyed[model.Automation]]{ID: "create-automation-hash.sha1", Collection: "automations", Document: &busdb.Keyed[model.Automation]{Key: "hash.sha1", Doc: &model.Automation{Image: "docker.io/python:3", Script: SHA1HashAutomation}}},
&createDocument{ID: "create-playbook-malware", Collection: "playbooks", Document: &busdb.Keyed{Key: "malware", Doc: model.PlaybookTemplate{Name: "Malware", Yaml: MalwarePlaybook}}}, &createDocument[busdb.Keyed[model.PlaybookTemplate]]{ID: "create-playbook-malware", Collection: "playbooks", Document: &busdb.Keyed[model.PlaybookTemplate]{Key: "malware", Doc: &model.PlaybookTemplate{Name: "Malware", Yaml: MalwarePlaybook}}},
&createDocument{ID: "create-playbook-phishing", Collection: "playbooks", Document: &busdb.Keyed{Key: "phishing", Doc: model.PlaybookTemplate{Name: "Phishing", Yaml: PhishingPlaybook}}}, &createDocument[busdb.Keyed[model.PlaybookTemplate]]{ID: "create-playbook-phishing", Collection: "playbooks", Document: &busdb.Keyed[model.PlaybookTemplate]{Key: "phishing", Doc: &model.PlaybookTemplate{Name: "Phishing", Yaml: PhishingPlaybook}}},
&createDocument{ID: "create-tickettype-alert", Collection: "tickettypes", Document: &busdb.Keyed{Key: "alert", Doc: model.TicketType{Name: "Alerts", Icon: "mdi-alert", DefaultTemplate: "default", DefaultPlaybooks: []string{}, DefaultGroups: nil}}}, &createDocument[busdb.Keyed[model.TicketType]]{ID: "create-tickettype-alert", Collection: "tickettypes", Document: &busdb.Keyed[model.TicketType]{Key: "alert", Doc: &model.TicketType{Name: "Alerts", Icon: "mdi-alert", DefaultTemplate: "default", DefaultPlaybooks: []string{}, DefaultGroups: nil}}},
&createDocument{ID: "create-tickettype-incident", Collection: "tickettypes", Document: &busdb.Keyed{Key: "incident", Doc: model.TicketType{Name: "Incidents", Icon: "mdi-radioactive", DefaultTemplate: "default", DefaultPlaybooks: []string{}, DefaultGroups: nil}}}, &createDocument[busdb.Keyed[model.TicketType]]{ID: "create-tickettype-incident", Collection: "tickettypes", Document: &busdb.Keyed[model.TicketType]{Key: "incident", Doc: &model.TicketType{Name: "Incidents", Icon: "mdi-radioactive", DefaultTemplate: "default", DefaultPlaybooks: []string{}, DefaultGroups: nil}}},
&createDocument{ID: "create-tickettype-investigation", Collection: "tickettypes", Document: &busdb.Keyed{Key: "investigation", Doc: model.TicketType{Name: "Forensic Investigations", Icon: "mdi-fingerprint", DefaultTemplate: "default", DefaultPlaybooks: []string{}, DefaultGroups: nil}}}, &createDocument[busdb.Keyed[model.TicketType]]{ID: "create-tickettype-investigation", Collection: "tickettypes", Document: &busdb.Keyed[model.TicketType]{Key: "investigation", Doc: &model.TicketType{Name: "Forensic Investigations", Icon: "mdi-fingerprint", DefaultTemplate: "default", DefaultPlaybooks: []string{}, DefaultGroups: nil}}},
&createDocument{ID: "create-tickettype-hunt", Collection: "tickettypes", Document: &busdb.Keyed{Key: "hunt", Doc: model.TicketType{Name: "Threat Hunting", Icon: "mdi-target", DefaultTemplate: "default", DefaultPlaybooks: []string{}, DefaultGroups: nil}}}, &createDocument[busdb.Keyed[model.TicketType]]{ID: "create-tickettype-hunt", Collection: "tickettypes", Document: &busdb.Keyed[model.TicketType]{Key: "hunt", Doc: &model.TicketType{Name: "Threat Hunting", Icon: "mdi-target", DefaultTemplate: "default", DefaultPlaybooks: []string{}, DefaultGroups: nil}}},
&updateSchema{ID: "update-automation-collection-1", Name: "automations", DataType: "automation", Schema: `{"properties":{"image":{"type":"string"},"script":{"type":"string"}},"required":["image","script"],"type":"object"}`}, &updateSchema{ID: "update-automation-collection-1", Name: "automations", DataType: "automation", Schema: `{"properties":{"image":{"type":"string"},"script":{"type":"string"}},"required":["image","script"],"type":"object"}`},
&updateDocument{ID: "update-automation-vt.hash-1", Collection: "automations", Key: "vt.hash", Document: model.Automation{Image: "docker.io/python:3", Script: VTHashAutomation, Schema: pointer.String(`{"title":"Input","type":"object","properties":{"default":{"type":"string","title":"Value"}},"required":["default"]}`), Type: []string{"global", "artifact", "playbook"}}}, &updateDocument[model.Automation]{ID: "update-automation-vt.hash-1", Collection: "automations", Key: "vt.hash", Document: &model.Automation{Image: "docker.io/python:3", Script: VTHashAutomation, Schema: pointer.String(`{"title":"Input","type":"object","properties":{"default":{"type":"string","title":"Value"}},"required":["default"]}`), Type: []string{"global", "artifact", "playbook"}}},
&updateDocument{ID: "update-automation-comment-1", Collection: "automations", Key: "comment", Document: model.Automation{Image: "docker.io/python:3", Script: CommentAutomation, Type: []string{"playbook"}}}, &updateDocument[model.Automation]{ID: "update-automation-comment-1", Collection: "automations", Key: "comment", Document: &model.Automation{Image: "docker.io/python:3", Script: CommentAutomation, Type: []string{"playbook"}}},
&updateDocument{ID: "update-automation-hash.sha1-1", Collection: "automations", Key: "hash.sha1", Document: model.Automation{Image: "docker.io/python:3", Script: SHA1HashAutomation, Schema: pointer.String(`{"title":"Input","type":"object","properties":{"default":{"type":"string","title":"Value"}},"required":["default"]}`), Type: []string{"global", "artifact", "playbook"}}}, &updateDocument[model.Automation]{ID: "update-automation-hash.sha1-1", Collection: "automations", Key: "hash.sha1", Document: &model.Automation{Image: "docker.io/python:3", Script: SHA1HashAutomation, Schema: pointer.String(`{"title":"Input","type":"object","properties":{"default":{"type":"string","title":"Value"}},"required":["default"]}`), Type: []string{"global", "artifact", "playbook"}}},
&createCollection{ID: "create-job-collection", Name: "jobs", DataType: "job", Schema: `{"properties":{"automation":{"type":"string"},"log":{"type":"string"},"payload":{},"origin":{"properties":{"artifact_origin":{"properties":{"artifact":{"type":"string"},"ticket_id":{"format":"int64","type":"integer"}},"required":["artifact","ticket_id"],"type":"object"},"task_origin":{"properties":{"playbook_id":{"type":"string"},"task_id":{"type":"string"},"ticket_id":{"format":"int64","type":"integer"}},"required":["playbook_id","task_id","ticket_id"],"type":"object"}},"type":"object"},"output":{"properties":{},"type":"object"},"running":{"type":"boolean"},"status":{"type":"string"}},"required":["automation","running","status"],"type":"object"}`}, &createCollection{ID: "create-job-collection", Name: "jobs", DataType: "job", Schema: `{"properties":{"automation":{"type":"string"},"log":{"type":"string"},"payload":{},"origin":{"properties":{"artifact_origin":{"properties":{"artifact":{"type":"string"},"ticket_id":{"format":"int64","type":"integer"}},"required":["artifact","ticket_id"],"type":"object"},"task_origin":{"properties":{"playbook_id":{"type":"string"},"task_id":{"type":"string"},"ticket_id":{"format":"int64","type":"integer"}},"required":["playbook_id","task_id","ticket_id"],"type":"object"}},"type":"object"},"output":{"properties":{},"type":"object"},"running":{"type":"boolean"},"status":{"type":"string"}},"required":["automation","running","status"],"type":"object"}`},
&createDocument{ID: "create-playbook-simple", Collection: "playbooks", Document: &busdb.Keyed{Key: "simple", Doc: model.PlaybookTemplate{Name: "Simple", Yaml: SimplePlaybook}}}, &createDocument[busdb.Keyed[model.PlaybookTemplate]]{ID: "create-playbook-simple", Collection: "playbooks", Document: &busdb.Keyed[model.PlaybookTemplate]{Key: "simple", Doc: &model.PlaybookTemplate{Name: "Simple", Yaml: SimplePlaybook}}},
&createCollection{ID: "create-settings-collection", Name: "settings", DataType: "settings", Schema: `{"type":"object","properties":{"artifactStates":{"title":"Artifact States","items":{"type":"object","properties":{"color":{"title":"Color","type":"string","enum":["error","info","success","warning"]},"icon":{"title":"Icon (https://materialdesignicons.com)","type":"string"},"id":{"title":"ID","type":"string"},"name":{"title":"Name","type":"string"}},"required":["id","name","icon"]},"type":"array"},"artifactKinds":{"title":"Artifact Kinds","items":{"type":"object","properties":{"color":{"title":"Color","type":"string","enum":["error","info","success","warning"]},"icon":{"title":"Icon (https://materialdesignicons.com)","type":"string"},"id":{"title":"ID","type":"string"},"name":{"title":"Name","type":"string"}},"required":["id","name","icon"]},"type":"array"},"timeformat":{"title":"Time Format","type":"string"}},"required":["timeformat","artifactKinds","artifactStates"]}`}, &createCollection{ID: "create-settings-collection", Name: "settings", DataType: "settings", Schema: `{"type":"object","properties":{"artifactStates":{"title":"Artifact States","items":{"type":"object","properties":{"color":{"title":"Color","type":"string","enum":["error","info","success","warning"]},"icon":{"title":"Icon (https://materialdesignicons.com)","type":"string"},"id":{"title":"ID","type":"string"},"name":{"title":"Name","type":"string"}},"required":["id","name","icon"]},"type":"array"},"artifactKinds":{"title":"Artifact Kinds","items":{"type":"object","properties":{"color":{"title":"Color","type":"string","enum":["error","info","success","warning"]},"icon":{"title":"Icon (https://materialdesignicons.com)","type":"string"},"id":{"title":"ID","type":"string"},"name":{"title":"Name","type":"string"}},"required":["id","name","icon"]},"type":"array"},"timeformat":{"title":"Time Format","type":"string"}},"required":["timeformat","artifactKinds","artifactStates"]}`},
&createDocument{ID: "create-settings-global", Collection: "settings", Document: &busdb.Keyed{Key: "global", Doc: model.Settings{ArtifactStates: []*model.Type{{Icon: "mdi-help-circle-outline", ID: "unknown", Name: "Unknown", Color: pointer.String(model.TypeColorInfo)}, {Icon: "mdi-skull", ID: "malicious", Name: "Malicious", Color: pointer.String(model.TypeColorError)}, {Icon: "mdi-check", ID: "clean", Name: "Clean", Color: pointer.String(model.TypeColorSuccess)}}, ArtifactKinds: []*model.Type{{Icon: "mdi-server", ID: "asset", Name: "Asset"}, {Icon: "mdi-bullseye", ID: "ioc", Name: "IOC"}}, Timeformat: "YYYY-MM-DDThh:mm:ss"}}}, &createDocument[busdb.Keyed[model.Settings]]{ID: "create-settings-global", Collection: "settings", Document: &busdb.Keyed[model.Settings]{Key: "global", Doc: &model.Settings{ArtifactStates: []*model.Type{{Icon: "mdi-help-circle-outline", ID: "unknown", Name: "Unknown", Color: pointer.String(model.TypeColorInfo)}, {Icon: "mdi-skull", ID: "malicious", Name: "Malicious", Color: pointer.String(model.TypeColorError)}, {Icon: "mdi-check", ID: "clean", Name: "Clean", Color: pointer.String(model.TypeColorSuccess)}}, ArtifactKinds: []*model.Type{{Icon: "mdi-server", ID: "asset", Name: "Asset"}, {Icon: "mdi-bullseye", ID: "ioc", Name: "IOC"}}, Timeformat: "YYYY-MM-DDThh:mm:ss"}}},
&updateSchema{ID: "update-ticket-collection", Name: "tickets", DataType: "ticket", Schema: `{"properties":{"artifacts":{"items":{"properties":{"enrichments":{"additionalProperties":{"properties":{"created":{"format":"date-time","type":"string"},"data":{"example":{"hash":"b7a067a742c20d07a7456646de89bc2d408a1153"},"properties":{},"type":"object"},"name":{"example":"hash.sha1","type":"string"}},"required":["created","data","name"],"type":"object"},"type":"object"},"name":{"example":"2.2.2.2","type":"string"},"status":{"example":"Unknown","type":"string"},"type":{"type":"string"},"kind":{"type":"string"}},"required":["name"],"type":"object"},"type":"array"},"comments":{"items":{"properties":{"created":{"format":"date-time","type":"string"},"creator":{"type":"string"},"message":{"type":"string"}},"required":["created","creator","message"],"type":"object"},"type":"array"},"created":{"format":"date-time","type":"string"},"details":{"example":{"description":"my little incident"},"properties":{},"type":"object"},"files":{"items":{"properties":{"key":{"example":"myfile","type":"string"},"name":{"example":"notes.docx","type":"string"}},"required":["key","name"],"type":"object"},"type":"array"},"modified":{"format":"date-time","type":"string"},"name":{"example":"WannyCry","type":"string"},"owner":{"example":"bob","type":"string"},"playbooks":{"additionalProperties":{"properties":{"name":{"example":"Phishing","type":"string"},"tasks":{"additionalProperties":{"properties":{"automation":{"type":"string"},"closed":{"format":"date-time","type":"string"},"created":{"format":"date-time","type":"string"},"data":{"properties":{},"type":"object"},"done":{"type":"boolean"},"join":{"example":false,"type":"boolean"},"payload":{"additionalProperties":{"type":"string"},"type":"object"},"name":{"example":"Inform user","type":"string"},"next":{"additionalProperties":{"type":"string"},"type":"object"},"owner":{"type":"string"},"schema":{"properties":{},"type":"object"},"type":{"enum":["task","input","automation"],"example":"task","type":"string"}},"required":["created","done","name","type"],"type":"object"},"type":"object"}},"required":["name","tasks"],"type":"object"},"type":"object"},"read":{"example":["bob"],"items":{"type":"string"},"type":"array"},"references":{"items":{"properties":{"href":{"example":"https://cve.mitre.org/cgi-bin/cvename.cgi?name=cve-2017-0144","type":"string"},"name":{"example":"CVE-2017-0144","type":"string"}},"required":["href","name"],"type":"object"},"type":"array"},"schema":{"example":"{}","type":"string"},"status":{"example":"open","type":"string"},"type":{"example":"incident","type":"string"},"write":{"example":["alice"],"items":{"type":"string"},"type":"array"}},"required":["created","modified","name","schema","status","type"],"type":"object"}`}, &updateSchema{ID: "update-ticket-collection", Name: "tickets", DataType: "ticket", Schema: `{"properties":{"artifacts":{"items":{"properties":{"enrichments":{"additionalProperties":{"properties":{"created":{"format":"date-time","type":"string"},"data":{"example":{"hash":"b7a067a742c20d07a7456646de89bc2d408a1153"},"properties":{},"type":"object"},"name":{"example":"hash.sha1","type":"string"}},"required":["created","data","name"],"type":"object"},"type":"object"},"name":{"example":"2.2.2.2","type":"string"},"status":{"example":"Unknown","type":"string"},"type":{"type":"string"},"kind":{"type":"string"}},"required":["name"],"type":"object"},"type":"array"},"comments":{"items":{"properties":{"created":{"format":"date-time","type":"string"},"creator":{"type":"string"},"message":{"type":"string"}},"required":["created","creator","message"],"type":"object"},"type":"array"},"created":{"format":"date-time","type":"string"},"details":{"example":{"description":"my little incident"},"properties":{},"type":"object"},"files":{"items":{"properties":{"key":{"example":"myfile","type":"string"},"name":{"example":"notes.docx","type":"string"}},"required":["key","name"],"type":"object"},"type":"array"},"modified":{"format":"date-time","type":"string"},"name":{"example":"WannyCry","type":"string"},"owner":{"example":"bob","type":"string"},"playbooks":{"additionalProperties":{"properties":{"name":{"example":"Phishing","type":"string"},"tasks":{"additionalProperties":{"properties":{"automation":{"type":"string"},"closed":{"format":"date-time","type":"string"},"created":{"format":"date-time","type":"string"},"data":{"properties":{},"type":"object"},"done":{"type":"boolean"},"join":{"example":false,"type":"boolean"},"payload":{"additionalProperties":{"type":"string"},"type":"object"},"name":{"example":"Inform user","type":"string"},"next":{"additionalProperties":{"type":"string"},"type":"object"},"owner":{"type":"string"},"schema":{"properties":{},"type":"object"},"type":{"enum":["task","input","automation"],"example":"task","type":"string"}},"required":["created","done","name","type"],"type":"object"},"type":"object"}},"required":["name","tasks"],"type":"object"},"type":"object"},"read":{"example":["bob"],"items":{"type":"string"},"type":"array"},"references":{"items":{"properties":{"href":{"example":"https://cve.mitre.org/cgi-bin/cvename.cgi?name=cve-2017-0144","type":"string"},"name":{"example":"CVE-2017-0144","type":"string"}},"required":["href","name"],"type":"object"},"type":"array"},"schema":{"example":"{}","type":"string"},"status":{"example":"open","type":"string"},"type":{"example":"incident","type":"string"},"write":{"example":["alice"],"items":{"type":"string"},"type":"array"}},"required":["created","modified","name","schema","status","type"],"type":"object"}`},
&createCollection{ID: "create-dashboard-collection", Name: "dashboards", DataType: "dashboards", Schema: `{"type":"object","properties":{"name":{"type":"string"},"widgets":{"items":{"type":"object","properties":{"aggregation":{"type":"string"},"filter":{"type":"string"},"name":{"type":"string"},"type":{"enum":[ "bar", "line", "pie" ]},"width": { "type": "integer", "minimum": 1, "maximum": 12 }},"required":["name","aggregation", "type", "width"]},"type":"array"}},"required":["name","widgets"]}`}, &createCollection{ID: "create-dashboard-collection", Name: "dashboards", DataType: "dashboards", Schema: `{"type":"object","properties":{"name":{"type":"string"},"widgets":{"items":{"type":"object","properties":{"aggregation":{"type":"string"},"filter":{"type":"string"},"name":{"type":"string"},"type":{"enum":[ "bar", "line", "pie" ]},"width": { "type": "integer", "minimum": 1, "maximum": 12 }},"required":["name","aggregation", "type", "width"]},"type":"array"}},"required":["name","widgets"]}`},
&updateDocument{ID: "update-settings-global-1", Collection: "settings", Key: "global", Document: &model.Settings{ArtifactStates: []*model.Type{{Icon: "mdi-help-circle-outline", ID: "unknown", Name: "Unknown", Color: pointer.String(model.TypeColorInfo)}, {Icon: "mdi-skull", ID: "malicious", Name: "Malicious", Color: pointer.String(model.TypeColorError)}, {Icon: "mdi-check", ID: "clean", Name: "Clean", Color: pointer.String(model.TypeColorSuccess)}}, ArtifactKinds: []*model.Type{{Icon: "mdi-server", ID: "asset", Name: "Asset"}, {Icon: "mdi-bullseye", ID: "ioc", Name: "IOC"}}, Timeformat: "yyyy-MM-dd hh:mm:ss"}}, &updateDocument[model.Settings]{ID: "update-settings-global-1", Collection: "settings", Key: "global", Document: &model.Settings{ArtifactStates: []*model.Type{{Icon: "mdi-help-circle-outline", ID: "unknown", Name: "Unknown", Color: pointer.String(model.TypeColorInfo)}, {Icon: "mdi-skull", ID: "malicious", Name: "Malicious", Color: pointer.String(model.TypeColorError)}, {Icon: "mdi-check", ID: "clean", Name: "Clean", Color: pointer.String(model.TypeColorSuccess)}}, ArtifactKinds: []*model.Type{{Icon: "mdi-server", ID: "asset", Name: "Asset"}, {Icon: "mdi-bullseye", ID: "ioc", Name: "IOC"}}, Timeformat: "yyyy-MM-dd hh:mm:ss"}},
}, nil }, nil
} }
@@ -67,6 +67,7 @@ func loadSchema(dataType, jsonschema string) (*driver.CollectionSchemaOptions, e
ticketCollectionSchema := &driver.CollectionSchemaOptions{Level: driver.CollectionSchemaLevelStrict, Message: fmt.Sprintf("Validation of %s failed", dataType)} ticketCollectionSchema := &driver.CollectionSchemaOptions{Level: driver.CollectionSchemaLevelStrict, Message: fmt.Sprintf("Validation of %s failed", dataType)}
err := ticketCollectionSchema.LoadRule([]byte(jsonschema)) err := ticketCollectionSchema.LoadRule([]byte(jsonschema))
return ticketCollectionSchema, err return ticketCollectionSchema, err
} }
@@ -101,6 +102,7 @@ func PerformMigrations(ctx context.Context, db driver.Database) error {
} }
} }
} }
return nil return nil
} }
@@ -171,41 +173,43 @@ func (m *createGraph) Migrate(ctx context.Context, db driver.Database) error {
_, err := db.CreateGraph(ctx, m.Name, &driver.CreateGraphOptions{ _, err := db.CreateGraph(ctx, m.Name, &driver.CreateGraphOptions{
EdgeDefinitions: m.EdgeDefinitions, EdgeDefinitions: m.EdgeDefinitions,
}) })
return err return err
} }
type createDocument struct { type createDocument[T any] struct {
ID string ID string
Collection string Collection string
Document interface{} Document *T
} }
func (m *createDocument) MID() string { func (m *createDocument[T]) MID() string {
return m.ID return m.ID
} }
func (m *createDocument) Migrate(ctx context.Context, driver driver.Database) error { func (m *createDocument[T]) Migrate(ctx context.Context, driver driver.Database) error {
collection, err := driver.Collection(ctx, m.Collection) collection, err := driver.Collection(ctx, m.Collection)
if err != nil { if err != nil {
return err return err
} }
_, err = collection.CreateDocument(ctx, m.Document) _, err = collection.CreateDocument(ctx, m.Document)
return err return err
} }
type updateDocument struct { type updateDocument[T any] struct {
ID string ID string
Collection string Collection string
Key string Key string
Document interface{} Document *T
} }
func (m *updateDocument) MID() string { func (m *updateDocument[T]) MID() string {
return m.ID return m.ID
} }
func (m *updateDocument) Migrate(ctx context.Context, driver driver.Database) error { func (m *updateDocument[T]) Migrate(ctx context.Context, driver driver.Database) error {
collection, err := driver.Collection(ctx, m.Collection) collection, err := driver.Collection(ctx, m.Collection)
if err != nil { if err != nil {
return err return err
@@ -218,9 +222,11 @@ func (m *updateDocument) Migrate(ctx context.Context, driver driver.Database) er
if !exists { if !exists {
_, err = collection.CreateDocument(ctx, m.Document) _, err = collection.CreateDocument(ctx, m.Document)
return err return err
} }
_, err = collection.ReplaceDocument(ctx, m.Key, m.Document) _, err = collection.ReplaceDocument(ctx, m.Key, m.Document)
return err return err
} }
+13 -5
View File
@@ -22,7 +22,7 @@ type PlaybookYAML struct {
type TaskYAML struct { type TaskYAML struct {
Name string `yaml:"name"` Name string `yaml:"name"`
Type string `yaml:"type"` Type string `yaml:"type"`
Schema interface{} `yaml:"schema"` Schema any `yaml:"schema"`
Automation string `yaml:"automation"` Automation string `yaml:"automation"`
Payload map[string]string `yaml:"payload"` Payload map[string]string `yaml:"payload"`
Next map[string]string `yaml:"next"` Next map[string]string `yaml:"next"`
@@ -42,6 +42,7 @@ func toPlaybooks(docs []*model.PlaybookTemplateForm) (map[string]*model.Playbook
playbooks[strcase.ToKebab(playbook.Name)] = playbook playbooks[strcase.ToKebab(playbook.Name)] = playbook
} }
} }
return playbooks, nil return playbooks, nil
} }
@@ -53,11 +54,17 @@ func toPlaybook(doc *model.PlaybookTemplateForm) (*model.Playbook, error) {
} }
for idx, task := range ticketPlaybook.Tasks { for idx, task := range ticketPlaybook.Tasks {
if task.Schema != nil { if task.Schema != nil {
task.Schema = dyno.ConvertMapI2MapS(task.Schema).(map[string]interface{}) schema, ok := dyno.ConvertMapI2MapS(task.Schema).(map[string]any)
if ok {
task.Schema = schema
} else {
return nil, errors.New("could not convert schema")
}
} }
task.Created = time.Now().UTC() task.Created = time.Now().UTC()
ticketPlaybook.Tasks[idx] = task ticketPlaybook.Tasks[idx] = task
} }
return ticketPlaybook, nil return ticketPlaybook, nil
} }
@@ -84,7 +91,7 @@ func (db *Database) PlaybookCreate(ctx context.Context, playbook *model.Playbook
var doc model.PlaybookTemplate var doc model.PlaybookTemplate
newctx := driver.WithReturnNew(ctx, &doc) newctx := driver.WithReturnNew(ctx, &doc)
meta, err := db.playbookCollection.CreateDocument(ctx, newctx, strcase.ToKebab(playbookYAML.Name), p) meta, err := db.playbookCollection.CreateDocument(ctx, newctx, strcase.ToKebab(playbookYAML.Name), &p)
if err != nil { if err != nil {
return nil, err return nil, err
} }
@@ -104,6 +111,7 @@ func (db *Database) PlaybookGet(ctx context.Context, id string) (*model.Playbook
func (db *Database) PlaybookDelete(ctx context.Context, id string) error { func (db *Database) PlaybookDelete(ctx context.Context, id string) error {
_, err := db.playbookCollection.RemoveDocument(ctx, id) _, err := db.playbookCollection.RemoveDocument(ctx, id)
return err return err
} }
@@ -121,7 +129,7 @@ func (db *Database) PlaybookUpdate(ctx context.Context, id string, playbook *mod
var doc model.PlaybookTemplate var doc model.PlaybookTemplate
ctx = driver.WithReturnNew(ctx, &doc) ctx = driver.WithReturnNew(ctx, &doc)
meta, err := db.playbookCollection.ReplaceDocument(ctx, id, model.PlaybookTemplate{Name: pb.Name, Yaml: playbook.Yaml}) meta, err := db.playbookCollection.ReplaceDocument(ctx, id, &model.PlaybookTemplate{Name: pb.Name, Yaml: playbook.Yaml})
if err != nil { if err != nil {
return nil, err return nil, err
} }
@@ -131,7 +139,7 @@ func (db *Database) PlaybookUpdate(ctx context.Context, id string, playbook *mod
func (db *Database) PlaybookList(ctx context.Context) ([]*model.PlaybookTemplateResponse, error) { func (db *Database) PlaybookList(ctx context.Context) ([]*model.PlaybookTemplateResponse, error) {
query := "FOR d IN @@collection RETURN d" query := "FOR d IN @@collection RETURN d"
cursor, _, err := db.Query(ctx, query, map[string]interface{}{"@collection": PlaybookCollectionName}, busdb.ReadOperation) cursor, _, err := db.Query(ctx, query, map[string]any{"@collection": PlaybookCollectionName}, busdb.ReadOperation)
if err != nil { if err != nil {
return nil, err return nil, err
} }
+7 -3
View File
@@ -33,6 +33,7 @@ func playbookGraph(playbook *model.Playbook) (*dag.Graph, error) {
} }
} }
} }
return d, nil return d, nil
} }
@@ -109,6 +110,7 @@ func active(playbook *model.Playbook, taskID string, d *dag.Graph, task *model.T
return false, nil return false, nil
} }
} }
return true, nil return true, nil
} }
@@ -129,10 +131,11 @@ func active(playbook *model.Playbook, taskID string, d *dag.Graph, task *model.T
return true, nil return true, nil
} }
} }
return false, nil return false, nil
} }
func evalRequirement(aql string, data interface{}) (bool, error) { func evalRequirement(aql string, data any) (bool, error) {
if aql == "" { if aql == "" {
return true, nil return true, nil
} }
@@ -143,9 +146,9 @@ func evalRequirement(aql string, data interface{}) (bool, error) {
return false, err return false, err
} }
var dataMap map[string]interface{} var dataMap map[string]any
if data != nil { if data != nil {
if dataMapX, ok := data.(map[string]interface{}); ok { if dataMapX, ok := data.(map[string]any); ok {
dataMap = dataMapX dataMap = dataMapX
} else { } else {
log.Println("wrong data type for task data") log.Println("wrong data type for task data")
@@ -160,6 +163,7 @@ func evalRequirement(aql string, data interface{}) (bool, error) {
if b, ok := v.(bool); ok { if b, ok := v.(bool); ok {
return b, nil return b, nil
} }
return false, err return false, err
} }
+21 -9
View File
@@ -12,11 +12,11 @@ var playbook2 = &model.Playbook{
Name: "Phishing", Name: "Phishing",
Tasks: map[string]*model.Task{ Tasks: map[string]*model.Task{
"board": {Next: map[string]string{ "board": {Next: map[string]string{
"escalate": "boardInvolved == true", "escalate": "boardInvolved == true",
"aquire-mail": "boardInvolved == false", "acquire-mail": "boardInvolved == false",
}}, }},
"escalate": {}, "escalate": {},
"aquire-mail": {Next: map[string]string{ "acquire-mail": {Next: map[string]string{
"extract-iocs": "schemaKey == 'yes'", "extract-iocs": "schemaKey == 'yes'",
"block-sender": "schemaKey == 'yes'", "block-sender": "schemaKey == 'yes'",
"search-email-gateway": "schemaKey == 'no'", "search-email-gateway": "schemaKey == 'no'",
@@ -34,11 +34,11 @@ var playbook3 = &model.Playbook{
Name: "Phishing", Name: "Phishing",
Tasks: map[string]*model.Task{ Tasks: map[string]*model.Task{
"board": {Next: map[string]string{ "board": {Next: map[string]string{
"escalate": "boardInvolved == true", "escalate": "boardInvolved == true",
"aquire-mail": "boardInvolved == false", "acquire-mail": "boardInvolved == false",
}, Data: map[string]interface{}{"boardInvolved": true}, Done: true}, }, Data: map[string]any{"boardInvolved": true}, Done: true},
"escalate": {}, "escalate": {},
"aquire-mail": {Next: map[string]string{ "acquire-mail": {Next: map[string]string{
"extract-iocs": "schemaKey == 'yes'", "extract-iocs": "schemaKey == 'yes'",
"block-sender": "schemaKey == 'yes'", "block-sender": "schemaKey == 'yes'",
"search-email-gateway": "schemaKey == 'no'", "search-email-gateway": "schemaKey == 'no'",
@@ -71,6 +71,8 @@ var playbook4 = &model.Playbook{
} }
func Test_canBeCompleted(t *testing.T) { func Test_canBeCompleted(t *testing.T) {
t.Parallel()
type args struct { type args struct {
playbook *model.Playbook playbook *model.Playbook
taskID string taskID string
@@ -83,18 +85,22 @@ func Test_canBeCompleted(t *testing.T) {
}{ }{
{"playbook2 board", args{playbook: playbook2, taskID: "board"}, true, false}, {"playbook2 board", args{playbook: playbook2, taskID: "board"}, true, false},
{"playbook2 escalate", args{playbook: playbook2, taskID: "escalate"}, false, false}, {"playbook2 escalate", args{playbook: playbook2, taskID: "escalate"}, false, false},
{"playbook2 aquire-mail", args{playbook: playbook2, taskID: "aquire-mail"}, false, false}, {"playbook2 acquire-mail", args{playbook: playbook2, taskID: "acquire-mail"}, false, false},
{"playbook2 block-ioc", args{playbook: playbook2, taskID: "block-ioc"}, false, false}, {"playbook2 block-ioc", args{playbook: playbook2, taskID: "block-ioc"}, false, false},
{"playbook3 board", args{playbook: playbook3, taskID: "board"}, false, false}, {"playbook3 board", args{playbook: playbook3, taskID: "board"}, false, false},
{"playbook3 escalate", args{playbook: playbook3, taskID: "escalate"}, true, false}, {"playbook3 escalate", args{playbook: playbook3, taskID: "escalate"}, true, false},
{"playbook3 aquire-mail", args{playbook: playbook3, taskID: "aquire-mail"}, false, false}, {"playbook3 acquire-mail", args{playbook: playbook3, taskID: "acquire-mail"}, false, false},
{"playbook3 block-ioc", args{playbook: playbook3, taskID: "block-ioc"}, false, false}, {"playbook3 block-ioc", args{playbook: playbook3, taskID: "block-ioc"}, false, false},
} }
for _, tt := range tests { for _, tt := range tests {
tt := tt
t.Run(tt.name, func(t *testing.T) { t.Run(tt.name, func(t *testing.T) {
t.Parallel()
got, err := activePlaybook(tt.args.playbook, tt.args.taskID) got, err := activePlaybook(tt.args.playbook, tt.args.taskID)
if (err != nil) != tt.wantErr { if (err != nil) != tt.wantErr {
t.Errorf("activePlaybook() error = %v, wantErr %v", err, tt.wantErr) t.Errorf("activePlaybook() error = %v, wantErr %v", err, tt.wantErr)
return return
} }
if got != tt.want { if got != tt.want {
@@ -105,6 +111,8 @@ func Test_canBeCompleted(t *testing.T) {
} }
func Test_playbookOrder(t *testing.T) { func Test_playbookOrder(t *testing.T) {
t.Parallel()
type args struct { type args struct {
playbook *model.Playbook playbook *model.Playbook
} }
@@ -117,10 +125,14 @@ func Test_playbookOrder(t *testing.T) {
{"playbook4", args{playbook: playbook4}, []string{"file-or-hash", "enter-hash", "upload", "hash", "virustotal"}, false}, {"playbook4", args{playbook: playbook4}, []string{"file-or-hash", "enter-hash", "upload", "hash", "virustotal"}, false},
} }
for _, tt := range tests { for _, tt := range tests {
tt := tt
t.Run(tt.name, func(t *testing.T) { t.Run(tt.name, func(t *testing.T) {
t.Parallel()
got, err := toPlaybookResponse(tt.args.playbook) got, err := toPlaybookResponse(tt.args.playbook)
if (err != nil) != tt.wantErr { if (err != nil) != tt.wantErr {
t.Errorf("activePlaybook() error = %v, wantErr %v", err, tt.wantErr) t.Errorf("activePlaybook() error = %v, wantErr %v", err, tt.wantErr)
return return
} }
+4 -1
View File
@@ -20,11 +20,13 @@ func (db *Database) RelatedCreate(ctx context.Context, id, id2 int64) error {
From: driver.DocumentID(TicketCollectionName + "/" + strconv.Itoa(int(id))), From: driver.DocumentID(TicketCollectionName + "/" + strconv.Itoa(int(id))),
To: driver.DocumentID(TicketCollectionName + "/" + strconv.Itoa(int(id2))), To: driver.DocumentID(TicketCollectionName + "/" + strconv.Itoa(int(id2))),
}) })
return err return err
} }
func (db *Database) RelatedBatchCreate(ctx context.Context, edges []*driver.EdgeDocument) error { func (db *Database) RelatedBatchCreate(ctx context.Context, edges []*driver.EdgeDocument) error {
_, err := db.relatedCollection.CreateEdges(ctx, edges) _, err := db.relatedCollection.CreateEdges(ctx, edges)
return err return err
} }
@@ -33,7 +35,7 @@ func (db *Database) RelatedRemove(ctx context.Context, id, id2 int64) error {
FOR d in @@collection FOR d in @@collection
FILTER (d._from == @id && d._to == @id2) || (d._to == @id && d._from == @id2) FILTER (d._from == @id && d._to == @id2) || (d._to == @id && d._from == @id2)
REMOVE d in @@collection` REMOVE d in @@collection`
_, _, err := db.Query(ctx, q, map[string]interface{}{ _, _, err := db.Query(ctx, q, map[string]any{
"@collection": RelatedTicketsCollectionName, "@collection": RelatedTicketsCollectionName,
"id": driver.DocumentID(TicketCollectionName + "/" + strconv.Itoa(int(id))), "id": driver.DocumentID(TicketCollectionName + "/" + strconv.Itoa(int(id))),
"id2": driver.DocumentID(TicketCollectionName + "/" + strconv.Itoa(int(id2))), "id2": driver.DocumentID(TicketCollectionName + "/" + strconv.Itoa(int(id2))),
@@ -44,5 +46,6 @@ func (db *Database) RelatedRemove(ctx context.Context, id, id2 int64) error {
driver.DocumentID(TicketCollectionName + "/" + strconv.Itoa(int(id2))), driver.DocumentID(TicketCollectionName + "/" + strconv.Itoa(int(id2))),
}, },
}) })
return err return err
} }
+4 -4
View File
@@ -44,12 +44,12 @@ func (db *Database) Statistics(ctx context.Context) (*model.Statistics, error) {
return &statistics, nil return &statistics, nil
} }
func (db *Database) WidgetData(ctx context.Context, aggregation string, filter *string) (map[string]interface{}, error) { func (db *Database) WidgetData(ctx context.Context, aggregation string, filter *string) (map[string]any, error) {
parser := &caql.Parser{Searcher: db.Index, Prefix: "d."} parser := &caql.Parser{Searcher: db.Index, Prefix: "d."}
queryTree, err := parser.Parse(aggregation) queryTree, err := parser.Parse(aggregation)
if err != nil { if err != nil {
return nil, fmt.Errorf("invalid aggregation query (%s): syntax error\n", aggregation) return nil, fmt.Errorf("invalid aggregation query (%s): syntax error", aggregation)
} }
aggregationString, err := queryTree.String() aggregationString, err := queryTree.String()
if err != nil { if err != nil {
@@ -61,7 +61,7 @@ func (db *Database) WidgetData(ctx context.Context, aggregation string, filter *
if filter != nil && *filter != "" { if filter != nil && *filter != "" {
queryTree, err := parser.Parse(*filter) queryTree, err := parser.Parse(*filter)
if err != nil { if err != nil {
return nil, fmt.Errorf("invalid filter query (%s): syntax error\n", *filter) return nil, fmt.Errorf("invalid filter query (%s): syntax error", *filter)
} }
filterString, err := queryTree.String() filterString, err := queryTree.String()
if err != nil { if err != nil {
@@ -82,7 +82,7 @@ func (db *Database) WidgetData(ctx context.Context, aggregation string, filter *
} }
defer cur.Close() defer cur.Close()
statistics := map[string]interface{}{} statistics := map[string]any{}
if _, err := cur.ReadDocument(ctx, &statistics); err != nil { if _, err := cur.ReadDocument(ctx, &statistics); err != nil {
return nil, err return nil, err
} }
+5 -5
View File
@@ -10,10 +10,10 @@ import (
) )
type playbookResponse struct { type playbookResponse struct {
PlaybookId string `json:"playbook_id"` PlaybookID string `json:"playbook_id"`
PlaybookName string `json:"playbook_name"` PlaybookName string `json:"playbook_name"`
Playbook model.Playbook `json:"playbook"` Playbook model.Playbook `json:"playbook"`
TicketId int64 `json:"ticket_id"` TicketID int64 `json:"ticket_id"`
TicketName string `json:"ticket_name"` TicketName string `json:"ticket_name"`
} }
@@ -28,7 +28,7 @@ func (db *Database) TaskList(ctx context.Context) ([]*model.TaskWithContext, err
FILTER d.status == 'open' FILTER d.status == 'open'
FOR playbook IN NOT_NULL(VALUES(d.playbooks), []) FOR playbook IN NOT_NULL(VALUES(d.playbooks), [])
RETURN { ticket_id: TO_NUMBER(d._key), ticket_name: d.name, playbook_id: POSITION(d.playbooks, playbook, true), playbook_name: playbook.name, playbook: playbook }` RETURN { ticket_id: TO_NUMBER(d._key), ticket_name: d.name, playbook_id: POSITION(d.playbooks, playbook, true), playbook_name: playbook.name, playbook: playbook }`
cursor, _, err := db.Query(ctx, query, mergeMaps(ticketFilterVars, map[string]interface{}{ cursor, _, err := db.Query(ctx, query, mergeMaps(ticketFilterVars, map[string]any{
"@collection": TicketCollectionName, "@collection": TicketCollectionName,
}), busdb.ReadOperation) }), busdb.ReadOperation)
if err != nil { if err != nil {
@@ -53,10 +53,10 @@ func (db *Database) TaskList(ctx context.Context) ([]*model.TaskWithContext, err
for _, task := range playbook.Tasks { for _, task := range playbook.Tasks {
if task.Active { if task.Active {
docs = append(docs, &model.TaskWithContext{ docs = append(docs, &model.TaskWithContext{
PlaybookId: doc.PlaybookId, PlaybookId: doc.PlaybookID,
PlaybookName: doc.PlaybookName, PlaybookName: doc.PlaybookName,
Task: task, Task: task,
TicketId: doc.TicketId, TicketId: doc.TicketID,
TicketName: doc.TicketName, TicketName: doc.TicketName,
}) })
} }
+2 -1
View File
@@ -62,12 +62,13 @@ func (db *Database) TemplateUpdate(ctx context.Context, id string, template *mod
func (db *Database) TemplateDelete(ctx context.Context, id string) error { func (db *Database) TemplateDelete(ctx context.Context, id string) error {
_, err := db.templateCollection.RemoveDocument(ctx, id) _, err := db.templateCollection.RemoveDocument(ctx, id)
return err return err
} }
func (db *Database) TemplateList(ctx context.Context) ([]*model.TicketTemplateResponse, error) { func (db *Database) TemplateList(ctx context.Context) ([]*model.TicketTemplateResponse, error) {
query := "FOR d IN @@collection RETURN d" query := "FOR d IN @@collection RETURN d"
cursor, _, err := db.Query(ctx, query, map[string]interface{}{"@collection": TemplateCollectionName}, busdb.ReadOperation) cursor, _, err := db.Query(ctx, query, map[string]any{"@collection": TemplateCollectionName}, busdb.ReadOperation)
if err != nil { if err != nil {
return nil, err return nil, err
} }
+37 -8
View File
@@ -10,16 +10,20 @@ import (
"github.com/SecurityBrewery/catalyst/test" "github.com/SecurityBrewery/catalyst/test"
) )
var template1 = &model.TicketTemplateForm{ var (
Schema: migrations.DefaultTemplateSchema, template1 = &model.TicketTemplateForm{
Name: "Template 1", Schema: migrations.DefaultTemplateSchema,
} Name: "Template 1",
var default1 = &model.TicketTemplateForm{ }
Schema: migrations.DefaultTemplateSchema, default1 = &model.TicketTemplateForm{
Name: "Default", Schema: migrations.DefaultTemplateSchema,
} Name: "Default",
}
)
func TestDatabase_TemplateCreate(t *testing.T) { func TestDatabase_TemplateCreate(t *testing.T) {
t.Parallel()
type args struct { type args struct {
template *model.TicketTemplateForm template *model.TicketTemplateForm
} }
@@ -35,7 +39,10 @@ func TestDatabase_TemplateCreate(t *testing.T) {
{name: "Only name", args: args{template: &model.TicketTemplateForm{Name: "name"}}, wantErr: false}, {name: "Only name", args: args{template: &model.TicketTemplateForm{Name: "name"}}, wantErr: false},
} }
for _, tt := range tests { for _, tt := range tests {
tt := tt
t.Run(tt.name, func(t *testing.T) { t.Run(tt.name, func(t *testing.T) {
t.Parallel()
_, _, _, _, _, db, cleanup, err := test.DB(t) _, _, _, _, _, db, cleanup, err := test.DB(t)
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
@@ -50,6 +57,8 @@ func TestDatabase_TemplateCreate(t *testing.T) {
} }
func TestDatabase_TemplateDelete(t *testing.T) { func TestDatabase_TemplateDelete(t *testing.T) {
t.Parallel()
type args struct { type args struct {
id string id string
} }
@@ -62,7 +71,10 @@ func TestDatabase_TemplateDelete(t *testing.T) {
{name: "Not existing", args: args{"foobar"}, wantErr: true}, {name: "Not existing", args: args{"foobar"}, wantErr: true},
} }
for _, tt := range tests { for _, tt := range tests {
tt := tt
t.Run(tt.name, func(t *testing.T) { t.Run(tt.name, func(t *testing.T) {
t.Parallel()
_, _, _, _, _, db, cleanup, err := test.DB(t) _, _, _, _, _, db, cleanup, err := test.DB(t)
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
@@ -81,6 +93,8 @@ func TestDatabase_TemplateDelete(t *testing.T) {
} }
func TestDatabase_TemplateGet(t *testing.T) { func TestDatabase_TemplateGet(t *testing.T) {
t.Parallel()
type args struct { type args struct {
id string id string
} }
@@ -94,7 +108,10 @@ func TestDatabase_TemplateGet(t *testing.T) {
{name: "Not existing", args: args{id: "foobar"}, wantErr: true}, {name: "Not existing", args: args{id: "foobar"}, wantErr: true},
} }
for _, tt := range tests { for _, tt := range tests {
tt := tt
t.Run(tt.name, func(t *testing.T) { t.Run(tt.name, func(t *testing.T) {
t.Parallel()
_, _, _, _, _, db, cleanup, err := test.DB(t) _, _, _, _, _, db, cleanup, err := test.DB(t)
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
@@ -108,6 +125,7 @@ func TestDatabase_TemplateGet(t *testing.T) {
got, err := db.TemplateGet(test.Context(), tt.args.id) got, err := db.TemplateGet(test.Context(), tt.args.id)
if (err != nil) != tt.wantErr { if (err != nil) != tt.wantErr {
t.Errorf("TemplateGet() error = %v, wantErr %v", err, tt.wantErr) t.Errorf("TemplateGet() error = %v, wantErr %v", err, tt.wantErr)
return return
} }
if err != nil { if err != nil {
@@ -120,6 +138,8 @@ func TestDatabase_TemplateGet(t *testing.T) {
} }
func TestDatabase_TemplateList(t *testing.T) { func TestDatabase_TemplateList(t *testing.T) {
t.Parallel()
tests := []struct { tests := []struct {
name string name string
want []*model.TicketTemplateResponse want []*model.TicketTemplateResponse
@@ -128,7 +148,10 @@ func TestDatabase_TemplateList(t *testing.T) {
{name: "Normal", want: []*model.TicketTemplateResponse{{ID: "default", Name: "Default", Schema: migrations.DefaultTemplateSchema}, {ID: "template-1", Name: template1.Name, Schema: template1.Schema}}}, {name: "Normal", want: []*model.TicketTemplateResponse{{ID: "default", Name: "Default", Schema: migrations.DefaultTemplateSchema}, {ID: "template-1", Name: template1.Name, Schema: template1.Schema}}},
} }
for _, tt := range tests { for _, tt := range tests {
tt := tt
t.Run(tt.name, func(t *testing.T) { t.Run(tt.name, func(t *testing.T) {
t.Parallel()
_, _, _, _, _, db, cleanup, err := test.DB(t) _, _, _, _, _, db, cleanup, err := test.DB(t)
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
@@ -142,6 +165,7 @@ func TestDatabase_TemplateList(t *testing.T) {
got, err := db.TemplateList(test.Context()) got, err := db.TemplateList(test.Context())
if (err != nil) != tt.wantErr { if (err != nil) != tt.wantErr {
t.Errorf("TemplateList() error = %v, wantErr %v", err, tt.wantErr) t.Errorf("TemplateList() error = %v, wantErr %v", err, tt.wantErr)
return return
} }
assert.Equal(t, got, tt.want) assert.Equal(t, got, tt.want)
@@ -150,6 +174,8 @@ func TestDatabase_TemplateList(t *testing.T) {
} }
func TestDatabase_TemplateUpdate(t *testing.T) { func TestDatabase_TemplateUpdate(t *testing.T) {
t.Parallel()
type args struct { type args struct {
id string id string
template *model.TicketTemplateForm template *model.TicketTemplateForm
@@ -163,7 +189,10 @@ func TestDatabase_TemplateUpdate(t *testing.T) {
{name: "Not existing", args: args{"foobar", template1}, wantErr: true}, {name: "Not existing", args: args{"foobar", template1}, wantErr: true},
} }
for _, tt := range tests { for _, tt := range tests {
tt := tt
t.Run(tt.name, func(t *testing.T) { t.Run(tt.name, func(t *testing.T) {
t.Parallel()
_, _, _, _, _, db, cleanup, err := test.DB(t) _, _, _, _, _, db, cleanup, err := test.DB(t)
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
+41 -23
View File
@@ -5,6 +5,7 @@ import (
"encoding/json" "encoding/json"
"errors" "errors"
"fmt" "fmt"
"log"
"sort" "sort"
"strconv" "strconv"
"strings" "strings"
@@ -21,7 +22,7 @@ import (
"github.com/SecurityBrewery/catalyst/index" "github.com/SecurityBrewery/catalyst/index"
) )
func toTicket(ticketForm *model.TicketForm) (interface{}, error) { func toTicket(ticketForm *model.TicketForm) (any, error) {
playbooks, err := toPlaybooks(ticketForm.Playbooks) playbooks, err := toPlaybooks(ticketForm.Playbooks)
if err != nil { if err != nil {
return nil, err return nil, err
@@ -65,8 +66,9 @@ func toTicket(ticketForm *model.TicketForm) (interface{}, error) {
ticket.Status = "open" ticket.Status = "open"
} }
if ticketForm.ID != nil { if ticketForm.ID != nil {
return &busdb.Keyed{Key: strconv.FormatInt(*ticketForm.ID, 10), Doc: ticket}, nil return &busdb.Keyed[model.Ticket]{Key: strconv.FormatInt(*ticketForm.ID, 10), Doc: ticket}, nil
} }
return ticket, nil return ticket, nil
} }
@@ -79,6 +81,7 @@ func toTicketResponses(tickets []*model.TicketSimpleResponse) ([]*model.TicketRe
} }
extendedTickets = append(extendedTickets, tr) extendedTickets = append(extendedTickets, tr)
} }
return extendedTickets, nil return extendedTickets, nil
} }
@@ -167,6 +170,7 @@ func toPlaybookResponses(playbooks map[string]*model.Playbook) (map[string]*mode
return nil, err return nil, err
} }
} }
return pr, nil return pr, nil
} }
@@ -195,6 +199,7 @@ func toPlaybookResponse(playbook *model.Playbook) (*model.PlaybookResponse, erro
re.Tasks[taskID] = rootTask re.Tasks[taskID] = rootTask
i++ i++
} }
return re, nil return re, nil
} }
@@ -204,7 +209,7 @@ func (db *Database) TicketBatchCreate(ctx context.Context, ticketForms []*model.
return nil, err return nil, err
} }
var dbTickets []interface{} var dbTickets []any
for _, ticketForm := range ticketForms { for _, ticketForm := range ticketForms {
ticket, err := toTicket(ticketForm) ticket, err := toTicket(ticketForm)
if err != nil { if err != nil {
@@ -231,7 +236,7 @@ func (db *Database) TicketBatchCreate(ctx context.Context, ticketForms []*model.
LET noiddoc = UNSET(keyeddoc, "id") LET noiddoc = UNSET(keyeddoc, "id")
INSERT noiddoc INTO @@collection INSERT noiddoc INTO @@collection
RETURN NEW` RETURN NEW`
apiTickets, _, err := db.ticketListQuery(ctx, query, mergeMaps(map[string]interface{}{ apiTickets, _, err := db.ticketListQuery(ctx, query, mergeMaps(map[string]any{
"tickets": dbTickets, "tickets": dbTickets,
}, ticketFilterVars), busdb.CreateOperation) }, ticketFilterVars), busdb.CreateOperation)
if err != nil { if err != nil {
@@ -247,7 +252,11 @@ func (db *Database) TicketBatchCreate(ctx context.Context, ticketForms []*model.
ids = append(ids, driver.NewDocumentID(TicketCollectionName, fmt.Sprint(apiTicket.ID))) ids = append(ids, driver.NewDocumentID(TicketCollectionName, fmt.Sprint(apiTicket.ID)))
} }
go db.bus.PublishDatabaseUpdate(ids, bus.DatabaseEntryUpdated) go func() {
if err := db.bus.PublishDatabaseUpdate(ids, bus.DatabaseEntryUpdated); err != nil {
log.Println(err)
}
}()
ticketResponses, err := toTicketResponses(apiTickets) ticketResponses, err := toTicketResponses(apiTickets)
if err != nil { if err != nil {
@@ -294,6 +303,7 @@ func batchIndex(index *index.Index, tickets []*model.TicketSimpleResponse) error
} }
} }
wg.Wait() wg.Wait()
return nil return nil
} }
@@ -306,9 +316,9 @@ func (db *Database) TicketGet(ctx context.Context, ticketID int64) (*model.Ticke
return db.ticketGetQuery(ctx, ticketID, `LET d = DOCUMENT(@@collection, @ID) `+ticketFilterQuery+` RETURN d`, ticketFilterVars, busdb.ReadOperation) return db.ticketGetQuery(ctx, ticketID, `LET d = DOCUMENT(@@collection, @ID) `+ticketFilterQuery+` RETURN d`, ticketFilterVars, busdb.ReadOperation)
} }
func (db *Database) ticketGetQuery(ctx context.Context, ticketID int64, query string, bindVars map[string]interface{}, operation *busdb.Operation) (*model.TicketWithTickets, error) { func (db *Database) ticketGetQuery(ctx context.Context, ticketID int64, query string, bindVars map[string]any, operation *busdb.Operation) (*model.TicketWithTickets, error) {
if bindVars == nil { if bindVars == nil {
bindVars = map[string]interface{}{} bindVars = map[string]any{}
} }
bindVars["@collection"] = TicketCollectionName bindVars["@collection"] = TicketCollectionName
if ticketID != 0 { if ticketID != 0 {
@@ -350,7 +360,7 @@ func (db *Database) ticketGetQuery(ctx context.Context, ticketID int64, query st
` + ticketFilterQuery + ` ` + ticketFilterQuery + `
RETURN d` RETURN d`
outTickets, _, err := db.ticketListQuery(ctx, ticketsQuery, mergeMaps(map[string]interface{}{ outTickets, _, err := db.ticketListQuery(ctx, ticketsQuery, mergeMaps(map[string]any{
"ID": fmt.Sprint(ticketID), "ID": fmt.Sprint(ticketID),
"graph": TicketArtifactsGraphName, "graph": TicketArtifactsGraphName,
"@tickets": TicketCollectionName, "@tickets": TicketCollectionName,
@@ -368,7 +378,7 @@ func (db *Database) ticketGetQuery(ctx context.Context, ticketID int64, query st
` + ticketFilterQuery + ` ` + ticketFilterQuery + `
RETURN d` RETURN d`
inTickets, _, err := db.ticketListQuery(ctx, ticketsQuery, mergeMaps(map[string]interface{}{ inTickets, _, err := db.ticketListQuery(ctx, ticketsQuery, mergeMaps(map[string]any{
"ID": fmt.Sprint(ticketID), "ID": fmt.Sprint(ticketID),
"graph": TicketArtifactsGraphName, "graph": TicketArtifactsGraphName,
"@tickets": TicketCollectionName, "@tickets": TicketCollectionName,
@@ -387,7 +397,7 @@ func (db *Database) ticketGetQuery(ctx context.Context, ticketID int64, query st
FOR a IN NOT_NULL(d.artifacts, []) FOR a IN NOT_NULL(d.artifacts, [])
FILTER POSITION(@artifacts, a.name) FILTER POSITION(@artifacts, a.name)
RETURN d` RETURN d`
sameArtifactTickets, _, err := db.ticketListQuery(ctx, ticketsQuery, mergeMaps(map[string]interface{}{ sameArtifactTickets, _, err := db.ticketListQuery(ctx, ticketsQuery, mergeMaps(map[string]any{
"ID": fmt.Sprint(ticketID), "ID": fmt.Sprint(ticketID),
"artifacts": artifactNames, "artifacts": artifactNames,
}, ticketFilterVars), busdb.ReadOperation) }, ticketFilterVars), busdb.ReadOperation)
@@ -395,7 +405,8 @@ func (db *Database) ticketGetQuery(ctx context.Context, ticketID int64, query st
return nil, err return nil, err
} }
tickets := append(outTickets, inTickets...) tickets := outTickets
tickets = append(tickets, inTickets...)
tickets = append(tickets, sameArtifactTickets...) tickets = append(tickets, sameArtifactTickets...)
sort.Slice(tickets, func(i, j int) bool { sort.Slice(tickets, func(i, j int) bool {
return tickets[i].ID < tickets[j].ID return tickets[i].ID < tickets[j].ID
@@ -425,7 +436,8 @@ func (db *Database) TicketUpdate(ctx context.Context, ticketID int64, ticket *mo
REPLACE d WITH @ticket IN @@collection REPLACE d WITH @ticket IN @@collection
RETURN NEW` RETURN NEW`
ticket.Modified = time.Now().UTC() // TODO make setable? ticket.Modified = time.Now().UTC() // TODO make setable?
return db.ticketGetQuery(ctx, ticketID, query, mergeMaps(map[string]interface{}{"ticket": ticket}, ticketFilterVars), &busdb.Operation{
return db.ticketGetQuery(ctx, ticketID, query, mergeMaps(map[string]any{"ticket": ticket}, ticketFilterVars), &busdb.Operation{
Type: bus.DatabaseEntryUpdated, Ids: []driver.DocumentID{ Type: bus.DatabaseEntryUpdated, Ids: []driver.DocumentID{
driver.NewDocumentID(TicketCollectionName, strconv.FormatInt(ticketID, 10)), driver.NewDocumentID(TicketCollectionName, strconv.FormatInt(ticketID, 10)),
}, },
@@ -447,15 +459,15 @@ func (db *Database) TicketDelete(ctx context.Context, ticketID int64) error {
} }
func (db *Database) TicketList(ctx context.Context, ticketType string, query string, sorts []string, desc []bool, offset, count int64) (*model.TicketList, error) { func (db *Database) TicketList(ctx context.Context, ticketType string, query string, sorts []string, desc []bool, offset, count int64) (*model.TicketList, error) {
binVars := map[string]interface{}{} binVars := map[string]any{}
var typeString = "" typeString := ""
if ticketType != "" { if ticketType != "" {
typeString = "FILTER d.type == @type " typeString = "FILTER d.type == @type "
binVars["type"] = ticketType binVars["type"] = ticketType
} }
var filterString = "" filterString := ""
if query != "" { if query != "" {
parser := &caql.Parser{Searcher: db.Index, Prefix: "d."} parser := &caql.Parser{Searcher: db.Index, Prefix: "d."}
queryTree, err := parser.Parse(query) queryTree, err := parser.Parse(query)
@@ -493,6 +505,7 @@ func (db *Database) TicketList(ctx context.Context, ticketType string, query str
RETURN d` RETURN d`
// RETURN KEEP(d, "_key", "id", "name", "type", "created")` // RETURN KEEP(d, "_key", "id", "name", "type", "created")`
ticketList, _, err := db.ticketListQuery(ctx, q, mergeMaps(binVars, ticketFilterVars), busdb.ReadOperation) ticketList, _, err := db.ticketListQuery(ctx, q, mergeMaps(binVars, ticketFilterVars), busdb.ReadOperation)
return &model.TicketList{ return &model.TicketList{
Count: documentCount, Count: documentCount,
Tickets: ticketList, Tickets: ticketList,
@@ -500,9 +513,9 @@ func (db *Database) TicketList(ctx context.Context, ticketType string, query str
// return map[string]interface{}{"tickets": ticketList, "count": documentCount}, err // return map[string]interface{}{"tickets": ticketList, "count": documentCount}, err
} }
func (db *Database) ticketListQuery(ctx context.Context, query string, bindVars map[string]interface{}, operation *busdb.Operation) ([]*model.TicketSimpleResponse, *model.LogEntry, error) { func (db *Database) ticketListQuery(ctx context.Context, query string, bindVars map[string]any, operation *busdb.Operation) ([]*model.TicketSimpleResponse, *model.LogEntry, error) {
if bindVars == nil { if bindVars == nil {
bindVars = map[string]interface{}{} bindVars = map[string]any{}
} }
bindVars["@collection"] = TicketCollectionName bindVars["@collection"] = TicketCollectionName
@@ -533,9 +546,9 @@ func (db *Database) ticketListQuery(ctx context.Context, query string, bindVars
return docs, logEntry, nil return docs, logEntry, nil
} }
func (db *Database) TicketCount(ctx context.Context, typequery, filterquery string, bindVars map[string]interface{}) (int, error) { func (db *Database) TicketCount(ctx context.Context, typequery, filterquery string, bindVars map[string]any) (int, error) {
if bindVars == nil { if bindVars == nil {
bindVars = map[string]interface{}{} bindVars = map[string]any{}
} }
bindVars["@collection"] = TicketCollectionName bindVars["@collection"] = TicketCollectionName
@@ -555,10 +568,11 @@ func (db *Database) TicketCount(ctx context.Context, typequery, filterquery stri
return 0, err return 0, err
} }
cursor.Close() cursor.Close()
return documentCount, nil return documentCount, nil
} }
func sortQuery(paramsSort []string, paramsDesc []bool, bindVars map[string]interface{}) string { func sortQuery(paramsSort []string, paramsDesc []bool, bindVars map[string]any) string {
sort := "" sort := ""
if len(paramsSort) > 0 { if len(paramsSort) > 0 {
var sorts []string var sorts []string
@@ -572,21 +586,23 @@ func sortQuery(paramsSort []string, paramsDesc []bool, bindVars map[string]inter
} }
sort = "SORT " + strings.Join(sorts, ", ") sort = "SORT " + strings.Join(sorts, ", ")
} }
return sort return sort
} }
func mergeMaps(a map[string]interface{}, b map[string]interface{}) map[string]interface{} { func mergeMaps(a map[string]any, b map[string]any) map[string]any {
merged := map[string]interface{}{} merged := map[string]any{}
for k, v := range a { for k, v := range a {
merged[k] = v merged[k] = v
} }
for k, v := range b { for k, v := range b {
merged[k] = v merged[k] = v
} }
return merged return merged
} }
func validate(e interface{}, schema *gojsonschema.Schema) error { func validate(e any, schema *gojsonschema.Schema) error {
b, err := json.Marshal(e) b, err := json.Marshal(e)
if err != nil { if err != nil {
return err return err
@@ -602,7 +618,9 @@ func validate(e interface{}, schema *gojsonschema.Schema) error {
for _, e := range res.Errors() { for _, e := range res.Errors() {
l = append(l, e.String()) l = append(l, e.String())
} }
return fmt.Errorf("validation failed: %v", strings.Join(l, ", ")) return fmt.Errorf("validation failed: %v", strings.Join(l, ", "))
} }
return nil return nil
} }
+19 -9
View File
@@ -34,7 +34,8 @@ func (db *Database) AddArtifact(ctx context.Context, id int64, artifact *model.A
` + ticketFilterQuery + ` ` + ticketFilterQuery + `
UPDATE d WITH { "modified": @now, "artifacts": PUSH(NOT_NULL(d.artifacts, []), @artifact) } IN @@collection UPDATE d WITH { "modified": @now, "artifacts": PUSH(NOT_NULL(d.artifacts, []), @artifact) } IN @@collection
RETURN NEW` RETURN NEW`
return db.ticketGetQuery(ctx, id, query, mergeMaps(map[string]interface{}{"artifact": artifact, "now": time.Now().UTC()}, ticketFilterVars), &busdb.Operation{
return db.ticketGetQuery(ctx, id, query, mergeMaps(map[string]any{"artifact": artifact, "now": time.Now().UTC()}, ticketFilterVars), &busdb.Operation{
Type: bus.DatabaseEntryUpdated, Type: bus.DatabaseEntryUpdated,
Ids: []driver.DocumentID{ Ids: []driver.DocumentID{
driver.DocumentID(fmt.Sprintf("%s/%d", TicketCollectionName, id)), driver.DocumentID(fmt.Sprintf("%s/%d", TicketCollectionName, id)),
@@ -57,6 +58,7 @@ func inferType(name string) string {
case commonregex.SHA256HexRegex.MatchString(name): case commonregex.SHA256HexRegex.MatchString(name):
return "sha256" return "sha256"
} }
return "unknown" return "unknown"
} }
@@ -73,7 +75,8 @@ func (db *Database) RemoveArtifact(ctx context.Context, id int64, name string) (
LET newartifacts = REMOVE_VALUE(d.artifacts, a) LET newartifacts = REMOVE_VALUE(d.artifacts, a)
UPDATE d WITH { "modified": @now, "artifacts": newartifacts } IN @@collection UPDATE d WITH { "modified": @now, "artifacts": newartifacts } IN @@collection
RETURN NEW` RETURN NEW`
return db.ticketGetQuery(ctx, id, query, mergeMaps(map[string]interface{}{"name": name, "now": time.Now().UTC()}, ticketFilterVars), &busdb.Operation{
return db.ticketGetQuery(ctx, id, query, mergeMaps(map[string]any{"name": name, "now": time.Now().UTC()}, ticketFilterVars), &busdb.Operation{
Type: bus.DatabaseEntryUpdated, Type: bus.DatabaseEntryUpdated,
Ids: []driver.DocumentID{ Ids: []driver.DocumentID{
driver.DocumentID(fmt.Sprintf("%s/%d", TicketCollectionName, id)), driver.DocumentID(fmt.Sprintf("%s/%d", TicketCollectionName, id)),
@@ -91,7 +94,8 @@ func (db *Database) SetTemplate(ctx context.Context, id int64, schema string) (*
` + ticketFilterQuery + ` ` + ticketFilterQuery + `
UPDATE d WITH { "schema": @schema } IN @@collection UPDATE d WITH { "schema": @schema } IN @@collection
RETURN NEW` RETURN NEW`
return db.ticketGetQuery(ctx, id, query, mergeMaps(map[string]interface{}{"schema": schema}, ticketFilterVars), &busdb.Operation{
return db.ticketGetQuery(ctx, id, query, mergeMaps(map[string]any{"schema": schema}, ticketFilterVars), &busdb.Operation{
Type: bus.DatabaseEntryUpdated, Type: bus.DatabaseEntryUpdated,
Ids: []driver.DocumentID{ Ids: []driver.DocumentID{
driver.DocumentID(fmt.Sprintf("%s/%d", TicketCollectionName, id)), driver.DocumentID(fmt.Sprintf("%s/%d", TicketCollectionName, id)),
@@ -122,7 +126,8 @@ func (db *Database) AddComment(ctx context.Context, id int64, comment *model.Com
` + ticketFilterQuery + ` ` + ticketFilterQuery + `
UPDATE d WITH { "modified": @now, "comments": PUSH(NOT_NULL(d.comments, []), @comment) } IN @@collection UPDATE d WITH { "modified": @now, "comments": PUSH(NOT_NULL(d.comments, []), @comment) } IN @@collection
RETURN NEW` RETURN NEW`
return db.ticketGetQuery(ctx, id, query, mergeMaps(map[string]interface{}{"comment": comment, "now": time.Now().UTC()}, ticketFilterVars), &busdb.Operation{
return db.ticketGetQuery(ctx, id, query, mergeMaps(map[string]any{"comment": comment, "now": time.Now().UTC()}, ticketFilterVars), &busdb.Operation{
Type: bus.DatabaseEntryUpdated, Type: bus.DatabaseEntryUpdated,
Ids: []driver.DocumentID{ Ids: []driver.DocumentID{
driver.DocumentID(fmt.Sprintf("%s/%d", TicketCollectionName, id)), driver.DocumentID(fmt.Sprintf("%s/%d", TicketCollectionName, id)),
@@ -140,7 +145,8 @@ func (db *Database) RemoveComment(ctx context.Context, id int64, commentID int64
` + ticketFilterQuery + ` ` + ticketFilterQuery + `
UPDATE d WITH { "modified": @now, "comments": REMOVE_NTH(d.comments, @commentID) } IN @@collection UPDATE d WITH { "modified": @now, "comments": REMOVE_NTH(d.comments, @commentID) } IN @@collection
RETURN NEW` RETURN NEW`
return db.ticketGetQuery(ctx, id, query, mergeMaps(map[string]interface{}{"commentID": commentID, "now": time.Now().UTC()}, ticketFilterVars), &busdb.Operation{
return db.ticketGetQuery(ctx, id, query, mergeMaps(map[string]any{"commentID": commentID, "now": time.Now().UTC()}, ticketFilterVars), &busdb.Operation{
Type: bus.DatabaseEntryUpdated, Type: bus.DatabaseEntryUpdated,
Ids: []driver.DocumentID{ Ids: []driver.DocumentID{
driver.DocumentID(fmt.Sprintf("%s/%d", TicketCollectionName, id)), driver.DocumentID(fmt.Sprintf("%s/%d", TicketCollectionName, id)),
@@ -158,7 +164,8 @@ func (db *Database) SetReferences(ctx context.Context, id int64, references []*m
` + ticketFilterQuery + ` ` + ticketFilterQuery + `
UPDATE d WITH { "modified": @now, "references": @references } IN @@collection UPDATE d WITH { "modified": @now, "references": @references } IN @@collection
RETURN NEW` RETURN NEW`
return db.ticketGetQuery(ctx, id, query, mergeMaps(map[string]interface{}{"references": references, "now": time.Now().UTC()}, ticketFilterVars), &busdb.Operation{
return db.ticketGetQuery(ctx, id, query, mergeMaps(map[string]any{"references": references, "now": time.Now().UTC()}, ticketFilterVars), &busdb.Operation{
Type: bus.DatabaseEntryUpdated, Type: bus.DatabaseEntryUpdated,
Ids: []driver.DocumentID{ Ids: []driver.DocumentID{
driver.DocumentID(fmt.Sprintf("%s/%d", TicketCollectionName, id)), driver.DocumentID(fmt.Sprintf("%s/%d", TicketCollectionName, id)),
@@ -176,7 +183,8 @@ func (db *Database) AddFile(ctx context.Context, id int64, file *model.File) (*m
` + ticketFilterQuery + ` ` + ticketFilterQuery + `
UPDATE d WITH { "modified": @now, "files": APPEND(NOT_NULL(d.files, []), [@file]) } IN @@collection UPDATE d WITH { "modified": @now, "files": APPEND(NOT_NULL(d.files, []), [@file]) } IN @@collection
RETURN NEW` RETURN NEW`
return db.ticketGetQuery(ctx, id, query, mergeMaps(map[string]interface{}{"file": file, "now": time.Now().UTC()}, ticketFilterVars), &busdb.Operation{
return db.ticketGetQuery(ctx, id, query, mergeMaps(map[string]any{"file": file, "now": time.Now().UTC()}, ticketFilterVars), &busdb.Operation{
Type: bus.DatabaseEntryUpdated, Type: bus.DatabaseEntryUpdated,
Ids: []driver.DocumentID{ Ids: []driver.DocumentID{
driver.DocumentID(fmt.Sprintf("%s/%d", TicketCollectionName, id)), driver.DocumentID(fmt.Sprintf("%s/%d", TicketCollectionName, id)),
@@ -213,7 +221,7 @@ func (db *Database) AddTicketPlaybook(ctx context.Context, id int64, playbookTem
LET newticket = MERGE(d, { "modified": @now, "playbooks": newplaybooks }) LET newticket = MERGE(d, { "modified": @now, "playbooks": newplaybooks })
REPLACE d WITH newticket IN @@collection REPLACE d WITH newticket IN @@collection
RETURN NEW` RETURN NEW`
ticket, err := db.ticketGetQuery(ctx, id, query, mergeMaps(map[string]interface{}{ ticket, err := db.ticketGetQuery(ctx, id, query, mergeMaps(map[string]any{
"playbook": pb, "playbook": pb,
"playbookID": findName(parentTicket.Playbooks, playbookID), "playbookID": findName(parentTicket.Playbooks, playbookID),
"now": time.Now().UTC(), "now": time.Now().UTC(),
@@ -258,6 +266,7 @@ func runRootTask(ticket *model.TicketResponse, playbookID string, db *Database)
} }
runNextTasks(ticket.ID, playbookID, root.Next, root.Data, ticket, db) runNextTasks(ticket.ID, playbookID, root.Next, root.Data, ticket, db)
return nil return nil
} }
@@ -273,7 +282,8 @@ func (db *Database) RemoveTicketPlaybook(ctx context.Context, id int64, playbook
LET newplaybooks = UNSET(d.playbooks, @playbookID) LET newplaybooks = UNSET(d.playbooks, @playbookID)
REPLACE d WITH MERGE(d, { "modified": @now, "playbooks": newplaybooks }) IN @@collection REPLACE d WITH MERGE(d, { "modified": @now, "playbooks": newplaybooks }) IN @@collection
RETURN NEW` RETURN NEW`
return db.ticketGetQuery(ctx, id, query, mergeMaps(map[string]interface{}{
return db.ticketGetQuery(ctx, id, query, mergeMaps(map[string]any{
"playbookID": playbookID, "playbookID": playbookID,
"now": time.Now().UTC(), "now": time.Now().UTC(),
}, ticketFilterVars), &busdb.Operation{ }, ticketFilterVars), &busdb.Operation{
+7 -6
View File
@@ -41,7 +41,7 @@ func (db *Database) TaskGet(ctx context.Context, id int64, playbookID string, ta
}, nil }, nil
} }
func (db *Database) TaskComplete(ctx context.Context, id int64, playbookID string, taskID string, data interface{}) (*model.TicketWithTickets, error) { func (db *Database) TaskComplete(ctx context.Context, id int64, playbookID string, taskID string, data any) (*model.TicketWithTickets, error) {
inc, err := db.TicketGet(ctx, id) inc, err := db.TicketGet(ctx, id)
if err != nil { if err != nil {
return nil, err return nil, err
@@ -68,7 +68,7 @@ func (db *Database) TaskComplete(ctx context.Context, id int64, playbookID strin
UPDATE d WITH { "modified": @now, "playbooks": newplaybooks } IN @@collection UPDATE d WITH { "modified": @now, "playbooks": newplaybooks } IN @@collection
RETURN NEW` RETURN NEW`
ticket, err := db.ticketGetQuery(ctx, id, query, mergeMaps(map[string]interface{}{ ticket, err := db.ticketGetQuery(ctx, id, query, mergeMaps(map[string]any{
"playbookID": playbookID, "playbookID": playbookID,
"taskID": taskID, "taskID": taskID,
"data": data, "data": data,
@@ -130,7 +130,7 @@ func (db *Database) TaskUpdateOwner(ctx context.Context, id int64, playbookID st
UPDATE d WITH { "modified": @now, "playbooks": newplaybooks } IN @@collection UPDATE d WITH { "modified": @now, "playbooks": newplaybooks } IN @@collection
RETURN NEW` RETURN NEW`
ticket, err := db.ticketGetQuery(ctx, id, query, mergeMaps(map[string]interface{}{ ticket, err := db.ticketGetQuery(ctx, id, query, mergeMaps(map[string]any{
"playbookID": playbookID, "playbookID": playbookID,
"taskID": taskID, "taskID": taskID,
"owner": owner, "owner": owner,
@@ -148,7 +148,7 @@ func (db *Database) TaskUpdateOwner(ctx context.Context, id int64, playbookID st
return ticket, nil return ticket, nil
} }
func (db *Database) TaskUpdateData(ctx context.Context, id int64, playbookID string, taskID string, data map[string]interface{}) (*model.TicketWithTickets, error) { func (db *Database) TaskUpdateData(ctx context.Context, id int64, playbookID string, taskID string, data map[string]any) (*model.TicketWithTickets, error) {
ticketFilterQuery, ticketFilterVars, err := db.Hooks.TicketWriteFilter(ctx) ticketFilterQuery, ticketFilterVars, err := db.Hooks.TicketWriteFilter(ctx)
if err != nil { if err != nil {
return nil, err return nil, err
@@ -165,7 +165,7 @@ func (db *Database) TaskUpdateData(ctx context.Context, id int64, playbookID str
UPDATE d WITH { "modified": @now, "playbooks": newplaybooks } IN @@collection UPDATE d WITH { "modified": @now, "playbooks": newplaybooks } IN @@collection
RETURN NEW` RETURN NEW`
ticket, err := db.ticketGetQuery(ctx, id, query, mergeMaps(map[string]interface{}{ ticket, err := db.ticketGetQuery(ctx, id, query, mergeMaps(map[string]any{
"playbookID": playbookID, "playbookID": playbookID,
"taskID": taskID, "taskID": taskID,
"data": data, "data": data,
@@ -198,7 +198,7 @@ func (db *Database) TaskRun(ctx context.Context, id int64, playbookID string, ta
return nil return nil
} }
func runNextTasks(id int64, playbookID string, next map[string]string, data interface{}, ticket *model.TicketResponse, db *Database) { func runNextTasks(id int64, playbookID string, next map[string]string, data any, ticket *model.TicketResponse, db *Database) {
for nextTaskID, requirement := range next { for nextTaskID, requirement := range next {
nextTask := ticket.Playbooks[playbookID].Tasks[nextTaskID] nextTask := ticket.Playbooks[playbookID].Tasks[nextTaskID]
if nextTask.Type == model.TaskTypeAutomation { if nextTask.Type == model.TaskTypeAutomation {
@@ -220,5 +220,6 @@ func runTask(ticketID int64, playbookID string, taskID string, task *model.TaskR
msgContext := &model.Context{Playbook: playbook, Task: task, Ticket: ticket} msgContext := &model.Context{Playbook: playbook, Task: task, Ticket: ticket}
origin := &model.Origin{TaskOrigin: &model.TaskOrigin{TaskId: taskID, PlaybookId: playbookID, TicketId: ticketID}} origin := &model.Origin{TaskOrigin: &model.TaskOrigin{TaskId: taskID, PlaybookId: playbookID, TicketId: ticketID}}
jobID := uuid.NewString() jobID := uuid.NewString()
return publishJobMapping(jobID, *task.Automation, msgContext, origin, task.Payload, db) return publishJobMapping(jobID, *task.Automation, msgContext, origin, task.Payload, db)
} }
+2 -1
View File
@@ -75,12 +75,13 @@ func (db *Database) TicketTypeUpdate(ctx context.Context, id string, tickettype
func (db *Database) TicketTypeDelete(ctx context.Context, id string) error { func (db *Database) TicketTypeDelete(ctx context.Context, id string) error {
_, err := db.tickettypeCollection.RemoveDocument(ctx, id) _, err := db.tickettypeCollection.RemoveDocument(ctx, id)
return err return err
} }
func (db *Database) TicketTypeList(ctx context.Context) ([]*model.TicketTypeResponse, error) { func (db *Database) TicketTypeList(ctx context.Context) ([]*model.TicketTypeResponse, error) {
query := "FOR d IN @@collection RETURN d" query := "FOR d IN @@collection RETURN d"
cursor, _, err := db.Query(ctx, query, map[string]interface{}{"@collection": TicketTypeCollectionName}, busdb.ReadOperation) cursor, _, err := db.Query(ctx, query, map[string]any{"@collection": TicketTypeCollectionName}, busdb.ReadOperation)
if err != nil { if err != nil {
return nil, err return nil, err
} }
+6 -2
View File
@@ -28,6 +28,7 @@ func generateKey() string {
for i := range b { for i := range b {
b[i] = letters[rand.Intn(len(letters))] b[i] = letters[rand.Intn(len(letters))]
} }
return string(b) return string(b)
} }
@@ -78,8 +79,10 @@ func (db *Database) UserGetOrCreate(ctx context.Context, newUser *model.UserForm
if err != nil { if err != nil {
return nil, err return nil, err
} }
return &model.UserResponse{ID: newUser.ID, Roles: newUser.Roles, Blocked: newUser.Blocked}, nil return &model.UserResponse{ID: newUser.ID, Roles: newUser.Roles, Blocked: newUser.Blocked}, nil
} }
return user, nil return user, nil
} }
@@ -132,12 +135,13 @@ func (db *Database) UserGet(ctx context.Context, id string) (*model.UserResponse
func (db *Database) UserDelete(ctx context.Context, id string) error { func (db *Database) UserDelete(ctx context.Context, id string) error {
_, err := db.userCollection.RemoveDocument(ctx, id) _, err := db.userCollection.RemoveDocument(ctx, id)
return err return err
} }
func (db *Database) UserList(ctx context.Context) ([]*model.UserResponse, error) { func (db *Database) UserList(ctx context.Context) ([]*model.UserResponse, error) {
query := "FOR d IN @@collection RETURN d" query := "FOR d IN @@collection RETURN d"
cursor, _, err := db.Query(ctx, query, map[string]interface{}{"@collection": UserCollectionName}, busdb.ReadOperation) cursor, _, err := db.Query(ctx, query, map[string]any{"@collection": UserCollectionName}, busdb.ReadOperation)
if err != nil { if err != nil {
return nil, err return nil, err
} }
@@ -163,7 +167,7 @@ func (db *Database) UserByHash(ctx context.Context, sha256 string) (*model.UserR
FILTER d.sha256 == @sha256 FILTER d.sha256 == @sha256
RETURN d` RETURN d`
cursor, _, err := db.Query(ctx, query, map[string]interface{}{"@collection": UserCollectionName, "sha256": sha256}, busdb.ReadOperation) cursor, _, err := db.Query(ctx, query, map[string]any{"@collection": UserCollectionName, "sha256": sha256}, busdb.ReadOperation)
if err != nil { if err != nil {
return nil, err return nil, err
} }
+3 -1
View File
@@ -29,6 +29,7 @@ func (db *Database) UserDataCreate(ctx context.Context, id string, userdata *mod
} }
_, err := db.userdataCollection.CreateDocument(ctx, ctx, id, userdata) _, err := db.userdataCollection.CreateDocument(ctx, ctx, id, userdata)
return err return err
} }
@@ -37,6 +38,7 @@ func (db *Database) UserDataGetOrCreate(ctx context.Context, id string, newUserD
if err != nil { if err != nil {
return toUserDataResponse(id, newUserData), db.UserDataCreate(ctx, id, newUserData) return toUserDataResponse(id, newUserData), db.UserDataCreate(ctx, id, newUserData)
} }
return setting, nil return setting, nil
} }
@@ -52,7 +54,7 @@ func (db *Database) UserDataGet(ctx context.Context, id string) (*model.UserData
func (db *Database) UserDataList(ctx context.Context) ([]*model.UserDataResponse, error) { func (db *Database) UserDataList(ctx context.Context) ([]*model.UserDataResponse, error) {
query := "FOR d IN @@collection SORT d.username ASC RETURN d" query := "FOR d IN @@collection SORT d.username ASC RETURN d"
cursor, _, err := db.Query(ctx, query, map[string]interface{}{"@collection": UserDataCollectionName}, busdb.ReadOperation) cursor, _, err := db.Query(ctx, query, map[string]any{"@collection": UserDataCollectionName}, busdb.ReadOperation)
if err != nil { if err != nil {
return nil, err return nil, err
} }
+22
View File
@@ -22,6 +22,8 @@ var bobResponse = &model.UserDataResponse{
} }
func TestDatabase_UserDataCreate(t *testing.T) { func TestDatabase_UserDataCreate(t *testing.T) {
t.Parallel()
type args struct { type args struct {
id string id string
setting *model.UserData setting *model.UserData
@@ -37,7 +39,10 @@ func TestDatabase_UserDataCreate(t *testing.T) {
{name: "Only settingname", args: args{id: "bob"}, wantErr: true}, {name: "Only settingname", args: args{id: "bob"}, wantErr: true},
} }
for _, tt := range tests { for _, tt := range tests {
tt := tt
t.Run(tt.name, func(t *testing.T) { t.Run(tt.name, func(t *testing.T) {
t.Parallel()
_, _, _, _, _, db, cleanup, err := test.DB(t) _, _, _, _, _, db, cleanup, err := test.DB(t)
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
@@ -52,6 +57,8 @@ func TestDatabase_UserDataCreate(t *testing.T) {
} }
func TestDatabase_UserDataGet(t *testing.T) { func TestDatabase_UserDataGet(t *testing.T) {
t.Parallel()
type args struct { type args struct {
id string id string
} }
@@ -65,7 +72,10 @@ func TestDatabase_UserDataGet(t *testing.T) {
{name: "Not existing", args: args{id: "foo"}, wantErr: true}, {name: "Not existing", args: args{id: "foo"}, wantErr: true},
} }
for _, tt := range tests { for _, tt := range tests {
tt := tt
t.Run(tt.name, func(t *testing.T) { t.Run(tt.name, func(t *testing.T) {
t.Parallel()
_, _, _, _, _, db, cleanup, err := test.DB(t) _, _, _, _, _, db, cleanup, err := test.DB(t)
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
@@ -79,6 +89,7 @@ func TestDatabase_UserDataGet(t *testing.T) {
got, err := db.UserDataGet(test.Context(), tt.args.id) got, err := db.UserDataGet(test.Context(), tt.args.id)
if (err != nil) != tt.wantErr { if (err != nil) != tt.wantErr {
t.Errorf("UserDataGet() error = %v, wantErr %v", err, tt.wantErr) t.Errorf("UserDataGet() error = %v, wantErr %v", err, tt.wantErr)
return return
} }
if err != nil { if err != nil {
@@ -91,6 +102,8 @@ func TestDatabase_UserDataGet(t *testing.T) {
} }
func TestDatabase_UserDataList(t *testing.T) { func TestDatabase_UserDataList(t *testing.T) {
t.Parallel()
tests := []struct { tests := []struct {
name string name string
want []*model.UserDataResponse want []*model.UserDataResponse
@@ -99,7 +112,10 @@ func TestDatabase_UserDataList(t *testing.T) {
{name: "Normal list", want: []*model.UserDataResponse{bobResponse}}, {name: "Normal list", want: []*model.UserDataResponse{bobResponse}},
} }
for _, tt := range tests { for _, tt := range tests {
tt := tt
t.Run(tt.name, func(t *testing.T) { t.Run(tt.name, func(t *testing.T) {
t.Parallel()
_, _, _, _, _, db, cleanup, err := test.DB(t) _, _, _, _, _, db, cleanup, err := test.DB(t)
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
@@ -113,6 +129,7 @@ func TestDatabase_UserDataList(t *testing.T) {
got, err := db.UserDataList(test.Context()) got, err := db.UserDataList(test.Context())
if (err != nil) != tt.wantErr { if (err != nil) != tt.wantErr {
t.Errorf("UserDataList() error = %v, wantErr %v", err, tt.wantErr) t.Errorf("UserDataList() error = %v, wantErr %v", err, tt.wantErr)
return return
} }
@@ -122,6 +139,8 @@ func TestDatabase_UserDataList(t *testing.T) {
} }
func TestDatabase_UserDataUpdate(t *testing.T) { func TestDatabase_UserDataUpdate(t *testing.T) {
t.Parallel()
type args struct { type args struct {
id string id string
setting *model.UserData setting *model.UserData
@@ -135,7 +154,10 @@ func TestDatabase_UserDataUpdate(t *testing.T) {
{name: "Not existing", args: args{id: "foo"}, wantErr: true}, {name: "Not existing", args: args{id: "foo"}, wantErr: true},
} }
for _, tt := range tests { for _, tt := range tests {
tt := tt
t.Run(tt.name, func(t *testing.T) { t.Run(tt.name, func(t *testing.T) {
t.Parallel()
_, _, _, _, _, db, cleanup, err := test.DB(t) _, _, _, _, _, db, cleanup, err := test.DB(t)
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
+12 -1
View File
@@ -30,11 +30,13 @@ func tusdUpload(db *database.Database, bus *bus.Bus, client *s3.S3, external str
ticketID := chi.URLParam(r, "ticketID") ticketID := chi.URLParam(r, "ticketID")
if ticketID == "" { if ticketID == "" {
api.JSONErrorStatus(w, http.StatusBadRequest, errors.New("ticketID not given")) api.JSONErrorStatus(w, http.StatusBadRequest, errors.New("ticketID not given"))
return return
} }
if err := storage.CreateBucket(client, ticketID); err != nil { if err := storage.CreateBucket(client, ticketID); err != nil {
api.JSONErrorStatus(w, http.StatusBadRequest, fmt.Errorf("could not create bucket: %w", err)) api.JSONErrorStatus(w, http.StatusBadRequest, fmt.Errorf("could not create bucket: %w", err))
return return
} }
@@ -50,6 +52,7 @@ func tusdUpload(db *database.Database, bus *bus.Bus, client *s3.S3, external str
}) })
if err != nil { if err != nil {
api.JSONErrorStatus(w, http.StatusBadRequest, fmt.Errorf("could not create tusd handler: %w", err)) api.JSONErrorStatus(w, http.StatusBadRequest, fmt.Errorf("could not create tusd handler: %w", err))
return return
} }
@@ -73,6 +76,7 @@ func tusdUpload(db *database.Database, bus *bus.Bus, client *s3.S3, external str
doc, err := db.AddFile(ctx, id, file) doc, err := db.AddFile(ctx, id, file)
if err != nil { if err != nil {
log.Println(err) log.Println(err)
return return
} }
@@ -92,7 +96,6 @@ func tusdUpload(db *database.Database, bus *bus.Bus, client *s3.S3, external str
default: default:
api.JSONErrorStatus(w, http.StatusInternalServerError, errors.New("unknown method")) api.JSONErrorStatus(w, http.StatusInternalServerError, errors.New("unknown method"))
} }
} }
} }
@@ -101,18 +104,21 @@ func upload(db *database.Database, client *s3.S3, uploader *s3manager.Uploader)
ticketID := chi.URLParam(r, "ticketID") ticketID := chi.URLParam(r, "ticketID")
if ticketID == "" { if ticketID == "" {
api.JSONErrorStatus(w, http.StatusBadRequest, errors.New("ticketID not given")) api.JSONErrorStatus(w, http.StatusBadRequest, errors.New("ticketID not given"))
return return
} }
file, header, err := r.FormFile("file") file, header, err := r.FormFile("file")
if err != nil { if err != nil {
api.JSONErrorStatus(w, http.StatusBadRequest, err) api.JSONErrorStatus(w, http.StatusBadRequest, err)
return return
} }
defer file.Close() defer file.Close()
if err := storage.CreateBucket(client, ticketID); err != nil { if err := storage.CreateBucket(client, ticketID); err != nil {
api.JSONErrorStatus(w, http.StatusBadRequest, fmt.Errorf("could not create bucket: %w", err)) api.JSONErrorStatus(w, http.StatusBadRequest, fmt.Errorf("could not create bucket: %w", err))
return return
} }
@@ -123,12 +129,14 @@ func upload(db *database.Database, client *s3.S3, uploader *s3manager.Uploader)
}) })
if err != nil { if err != nil {
api.JSONErrorStatus(w, http.StatusBadRequest, err) api.JSONErrorStatus(w, http.StatusBadRequest, err)
return return
} }
id, err := strconv.ParseInt(ticketID, 10, 64) id, err := strconv.ParseInt(ticketID, 10, 64)
if err != nil { if err != nil {
api.JSONErrorStatus(w, http.StatusBadRequest, err) api.JSONErrorStatus(w, http.StatusBadRequest, err)
return return
} }
@@ -138,6 +146,7 @@ func upload(db *database.Database, client *s3.S3, uploader *s3manager.Uploader)
}) })
if err != nil { if err != nil {
api.JSONErrorStatus(w, http.StatusBadRequest, err) api.JSONErrorStatus(w, http.StatusBadRequest, err)
return return
} }
} }
@@ -148,12 +157,14 @@ func download(downloader *s3manager.Downloader) http.HandlerFunc {
ticketID := chi.URLParam(r, "ticketID") ticketID := chi.URLParam(r, "ticketID")
if ticketID == "" { if ticketID == "" {
api.JSONErrorStatus(w, http.StatusBadRequest, errors.New("ticketID not given")) api.JSONErrorStatus(w, http.StatusBadRequest, errors.New("ticketID not given"))
return return
} }
key := chi.URLParam(r, "key") key := chi.URLParam(r, "key")
if key == "" { if key == "" {
api.JSONErrorStatus(w, http.StatusBadRequest, errors.New("key not given")) api.JSONErrorStatus(w, http.StatusBadRequest, errors.New("key not given"))
return return
} }
+3 -3
View File
@@ -118,7 +118,7 @@ func parseQueryOptionalBoolArray(r *http.Request, key string) ([]bool, error) {
return parseQueryBoolArray(r, key) return parseQueryBoolArray(r, key)
} }
func parseBody(b []byte, i interface{}) error { func parseBody(b []byte, i any) error {
dec := json.NewDecoder(bytes.NewBuffer(b)) dec := json.NewDecoder(bytes.NewBuffer(b))
err := dec.Decode(i) err := dec.Decode(i)
if err != nil { if err != nil {
@@ -137,7 +137,7 @@ func JSONErrorStatus(w http.ResponseWriter, status int, err error) {
w.Write(b) w.Write(b)
} }
func response(w http.ResponseWriter, v interface{}, err error) { func response(w http.ResponseWriter, v any, err error) {
if err != nil { if err != nil {
var httpError *HTTPError var httpError *HTTPError
if errors.As(err, &httpError) { if errors.As(err, &httpError) {
@@ -172,7 +172,7 @@ func validateSchema(body []byte, schema *gojsonschema.Schema, w http.ResponseWri
validationErrors = append(validationErrors, valdiationError.String()) validationErrors = append(validationErrors, valdiationError.String())
} }
b, _ := json.Marshal(map[string]interface{}{"error": "wrong input", "errors": validationErrors}) b, _ := json.Marshal(map[string]any{"error": "wrong input", "errors": validationErrors})
w.Write(b) w.Write(b)
return true return true
} }
+5 -5
View File
@@ -19,7 +19,7 @@ type Service interface {
CurrentUser(context.Context) (*model.UserResponse, error) CurrentUser(context.Context) (*model.UserResponse, error)
CurrentUserData(context.Context) (*model.UserDataResponse, error) CurrentUserData(context.Context) (*model.UserDataResponse, error)
UpdateCurrentUserData(context.Context, *model.UserData) (*model.UserDataResponse, error) UpdateCurrentUserData(context.Context, *model.UserData) (*model.UserDataResponse, error)
DashboardData(context.Context, string, *string) (map[string]interface{}, error) DashboardData(context.Context, string, *string) (map[string]any, error)
ListDashboards(context.Context) ([]*model.DashboardResponse, error) ListDashboards(context.Context) ([]*model.DashboardResponse, error)
CreateDashboard(context.Context, *model.Dashboard) (*model.DashboardResponse, error) CreateDashboard(context.Context, *model.Dashboard) (*model.DashboardResponse, error)
GetDashboard(context.Context, string) (*model.DashboardResponse, error) GetDashboard(context.Context, string) (*model.DashboardResponse, error)
@@ -60,8 +60,8 @@ type Service interface {
RemoveComment(context.Context, int64, int) (*model.TicketWithTickets, error) RemoveComment(context.Context, int64, int) (*model.TicketWithTickets, error)
AddTicketPlaybook(context.Context, int64, *model.PlaybookTemplateForm) (*model.TicketWithTickets, error) AddTicketPlaybook(context.Context, int64, *model.PlaybookTemplateForm) (*model.TicketWithTickets, error)
RemoveTicketPlaybook(context.Context, int64, string) (*model.TicketWithTickets, error) RemoveTicketPlaybook(context.Context, int64, string) (*model.TicketWithTickets, error)
SetTaskData(context.Context, int64, string, string, map[string]interface{}) (*model.TicketWithTickets, error) SetTaskData(context.Context, int64, string, string, map[string]any) (*model.TicketWithTickets, error)
CompleteTask(context.Context, int64, string, string, map[string]interface{}) (*model.TicketWithTickets, error) CompleteTask(context.Context, int64, string, string, map[string]any) (*model.TicketWithTickets, error)
SetTaskOwner(context.Context, int64, string, string, string) (*model.TicketWithTickets, error) SetTaskOwner(context.Context, int64, string, string, string) (*model.TicketWithTickets, error)
RunTask(context.Context, int64, string, string) error RunTask(context.Context, int64, string, string) error
SetReferences(context.Context, int64, *model.ReferenceArray) (*model.TicketWithTickets, error) SetReferences(context.Context, int64, *model.ReferenceArray) (*model.TicketWithTickets, error)
@@ -901,7 +901,7 @@ func (s *server) setTaskDataHandler(w http.ResponseWriter, r *http.Request) {
return return
} }
var dataP map[string]interface{} var dataP map[string]any
if err := parseBody(body, &dataP); err != nil { if err := parseBody(body, &dataP); err != nil {
JSONError(w, err) JSONError(w, err)
return return
@@ -928,7 +928,7 @@ func (s *server) completeTaskHandler(w http.ResponseWriter, r *http.Request) {
return return
} }
var dataP map[string]interface{} var dataP map[string]any
if err := parseBody(body, &dataP); err != nil { if err := parseBody(body, &dataP); err != nil {
JSONError(w, err) JSONError(w, err)
return return
File diff suppressed because one or more lines are too long
+90 -90
View File
@@ -251,14 +251,14 @@ type DashboardResponse struct {
} }
type Enrichment struct { type Enrichment struct {
Created time.Time `json:"created"` Created time.Time `json:"created"`
Data map[string]interface{} `json:"data"` Data map[string]any `json:"data"`
Name string `json:"name"` Name string `json:"name"`
} }
type EnrichmentForm struct { type EnrichmentForm struct {
Data map[string]interface{} `json:"data"` Data map[string]any `json:"data"`
Name string `json:"name"` Name string `json:"name"`
} }
type File struct { type File struct {
@@ -267,39 +267,39 @@ type File struct {
} }
type Job struct { type Job struct {
Automation string `json:"automation"` Automation string `json:"automation"`
Container *string `json:"container,omitempty"` Container *string `json:"container,omitempty"`
Log *string `json:"log,omitempty"` Log *string `json:"log,omitempty"`
Origin *Origin `json:"origin,omitempty"` Origin *Origin `json:"origin,omitempty"`
Output map[string]interface{} `json:"output,omitempty"` Output map[string]any `json:"output,omitempty"`
Payload interface{} `json:"payload,omitempty"` Payload any `json:"payload,omitempty"`
Running bool `json:"running"` Running bool `json:"running"`
Status string `json:"status"` Status string `json:"status"`
} }
type JobForm struct { type JobForm struct {
Automation string `json:"automation"` Automation string `json:"automation"`
Origin *Origin `json:"origin,omitempty"` Origin *Origin `json:"origin,omitempty"`
Payload interface{} `json:"payload,omitempty"` Payload any `json:"payload,omitempty"`
} }
type JobResponse struct { type JobResponse struct {
Automation string `json:"automation"` Automation string `json:"automation"`
Container *string `json:"container,omitempty"` Container *string `json:"container,omitempty"`
ID string `json:"id"` ID string `json:"id"`
Log *string `json:"log,omitempty"` Log *string `json:"log,omitempty"`
Origin *Origin `json:"origin,omitempty"` Origin *Origin `json:"origin,omitempty"`
Output map[string]interface{} `json:"output,omitempty"` Output map[string]any `json:"output,omitempty"`
Payload interface{} `json:"payload,omitempty"` Payload any `json:"payload,omitempty"`
Status string `json:"status"` Status string `json:"status"`
} }
type JobUpdate struct { type JobUpdate struct {
Container *string `json:"container,omitempty"` Container *string `json:"container,omitempty"`
Log *string `json:"log,omitempty"` Log *string `json:"log,omitempty"`
Output map[string]interface{} `json:"output,omitempty"` Output map[string]any `json:"output,omitempty"`
Running bool `json:"running"` Running bool `json:"running"`
Status string `json:"status"` Status string `json:"status"`
} }
type LogEntry struct { type LogEntry struct {
@@ -312,7 +312,7 @@ type LogEntry struct {
type Message struct { type Message struct {
Context *Context `json:"context,omitempty"` Context *Context `json:"context,omitempty"`
Payload interface{} `json:"payload,omitempty"` Payload any `json:"payload,omitempty"`
Secrets map[string]string `json:"secrets,omitempty"` Secrets map[string]string `json:"secrets,omitempty"`
} }
@@ -385,18 +385,18 @@ type Statistics struct {
} }
type Task struct { type Task struct {
Automation *string `json:"automation,omitempty"` Automation *string `json:"automation,omitempty"`
Closed *time.Time `json:"closed,omitempty"` Closed *time.Time `json:"closed,omitempty"`
Created time.Time `json:"created"` Created time.Time `json:"created"`
Data map[string]interface{} `json:"data,omitempty"` Data map[string]any `json:"data,omitempty"`
Done bool `json:"done"` Done bool `json:"done"`
Join *bool `json:"join,omitempty"` Join *bool `json:"join,omitempty"`
Name string `json:"name"` Name string `json:"name"`
Next map[string]string `json:"next,omitempty"` Next map[string]string `json:"next,omitempty"`
Owner *string `json:"owner,omitempty"` Owner *string `json:"owner,omitempty"`
Payload map[string]string `json:"payload,omitempty"` Payload map[string]string `json:"payload,omitempty"`
Schema map[string]interface{} `json:"schema,omitempty"` Schema map[string]any `json:"schema,omitempty"`
Type string `json:"type"` Type string `json:"type"`
} }
type TaskOrigin struct { type TaskOrigin struct {
@@ -406,20 +406,20 @@ type TaskOrigin struct {
} }
type TaskResponse struct { type TaskResponse struct {
Active bool `json:"active"` Active bool `json:"active"`
Automation *string `json:"automation,omitempty"` Automation *string `json:"automation,omitempty"`
Closed *time.Time `json:"closed,omitempty"` Closed *time.Time `json:"closed,omitempty"`
Created time.Time `json:"created"` Created time.Time `json:"created"`
Data map[string]interface{} `json:"data,omitempty"` Data map[string]any `json:"data,omitempty"`
Done bool `json:"done"` Done bool `json:"done"`
Join *bool `json:"join,omitempty"` Join *bool `json:"join,omitempty"`
Name string `json:"name"` Name string `json:"name"`
Next map[string]string `json:"next,omitempty"` Next map[string]string `json:"next,omitempty"`
Order int64 `json:"order"` Order int64 `json:"order"`
Owner *string `json:"owner,omitempty"` Owner *string `json:"owner,omitempty"`
Payload map[string]string `json:"payload,omitempty"` Payload map[string]string `json:"payload,omitempty"`
Schema map[string]interface{} `json:"schema,omitempty"` Schema map[string]any `json:"schema,omitempty"`
Type string `json:"type"` Type string `json:"type"`
} }
type TaskWithContext struct { type TaskWithContext struct {
@@ -432,28 +432,28 @@ type TaskWithContext struct {
} }
type Ticket struct { type Ticket struct {
Artifacts []*Artifact `json:"artifacts,omitempty"` Artifacts []*Artifact `json:"artifacts,omitempty"`
Comments []*Comment `json:"comments,omitempty"` Comments []*Comment `json:"comments,omitempty"`
Created time.Time `json:"created"` Created time.Time `json:"created"`
Details map[string]interface{} `json:"details,omitempty"` Details map[string]any `json:"details,omitempty"`
Files []*File `json:"files,omitempty"` Files []*File `json:"files,omitempty"`
Modified time.Time `json:"modified"` Modified time.Time `json:"modified"`
Name string `json:"name"` Name string `json:"name"`
Owner *string `json:"owner,omitempty"` Owner *string `json:"owner,omitempty"`
Playbooks map[string]*Playbook `json:"playbooks,omitempty"` Playbooks map[string]*Playbook `json:"playbooks,omitempty"`
Read []string `json:"read,omitempty"` Read []string `json:"read,omitempty"`
References []*Reference `json:"references,omitempty"` References []*Reference `json:"references,omitempty"`
Schema string `json:"schema"` Schema string `json:"schema"`
Status string `json:"status"` Status string `json:"status"`
Type string `json:"type"` Type string `json:"type"`
Write []string `json:"write,omitempty"` Write []string `json:"write,omitempty"`
} }
type TicketForm struct { type TicketForm struct {
Artifacts []*Artifact `json:"artifacts,omitempty"` Artifacts []*Artifact `json:"artifacts,omitempty"`
Comments []*Comment `json:"comments,omitempty"` Comments []*Comment `json:"comments,omitempty"`
Created *time.Time `json:"created,omitempty"` Created *time.Time `json:"created,omitempty"`
Details map[string]interface{} `json:"details,omitempty"` Details map[string]any `json:"details,omitempty"`
Files []*File `json:"files,omitempty"` Files []*File `json:"files,omitempty"`
ID *int64 `json:"id,omitempty"` ID *int64 `json:"id,omitempty"`
Modified *time.Time `json:"modified,omitempty"` Modified *time.Time `json:"modified,omitempty"`
@@ -479,7 +479,7 @@ type TicketResponse struct {
Artifacts []*Artifact `json:"artifacts,omitempty"` Artifacts []*Artifact `json:"artifacts,omitempty"`
Comments []*Comment `json:"comments,omitempty"` Comments []*Comment `json:"comments,omitempty"`
Created time.Time `json:"created"` Created time.Time `json:"created"`
Details map[string]interface{} `json:"details,omitempty"` Details map[string]any `json:"details,omitempty"`
Files []*File `json:"files,omitempty"` Files []*File `json:"files,omitempty"`
ID int64 `json:"id"` ID int64 `json:"id"`
Modified time.Time `json:"modified"` Modified time.Time `json:"modified"`
@@ -495,22 +495,22 @@ type TicketResponse struct {
} }
type TicketSimpleResponse struct { type TicketSimpleResponse struct {
Artifacts []*Artifact `json:"artifacts,omitempty"` Artifacts []*Artifact `json:"artifacts,omitempty"`
Comments []*Comment `json:"comments,omitempty"` Comments []*Comment `json:"comments,omitempty"`
Created time.Time `json:"created"` Created time.Time `json:"created"`
Details map[string]interface{} `json:"details,omitempty"` Details map[string]any `json:"details,omitempty"`
Files []*File `json:"files,omitempty"` Files []*File `json:"files,omitempty"`
ID int64 `json:"id"` ID int64 `json:"id"`
Modified time.Time `json:"modified"` Modified time.Time `json:"modified"`
Name string `json:"name"` Name string `json:"name"`
Owner *string `json:"owner,omitempty"` Owner *string `json:"owner,omitempty"`
Playbooks map[string]*Playbook `json:"playbooks,omitempty"` Playbooks map[string]*Playbook `json:"playbooks,omitempty"`
Read []string `json:"read,omitempty"` Read []string `json:"read,omitempty"`
References []*Reference `json:"references,omitempty"` References []*Reference `json:"references,omitempty"`
Schema string `json:"schema"` Schema string `json:"schema"`
Status string `json:"status"` Status string `json:"status"`
Type string `json:"type"` Type string `json:"type"`
Write []string `json:"write,omitempty"` Write []string `json:"write,omitempty"`
} }
type TicketTemplate struct { type TicketTemplate struct {
@@ -560,7 +560,7 @@ type TicketWithTickets struct {
Artifacts []*Artifact `json:"artifacts,omitempty"` Artifacts []*Artifact `json:"artifacts,omitempty"`
Comments []*Comment `json:"comments,omitempty"` Comments []*Comment `json:"comments,omitempty"`
Created time.Time `json:"created"` Created time.Time `json:"created"`
Details map[string]interface{} `json:"details,omitempty"` Details map[string]any `json:"details,omitempty"`
Files []*File `json:"files,omitempty"` Files []*File `json:"files,omitempty"`
ID int64 `json:"id"` ID int64 `json:"id"`
Logs []*LogEntry `json:"logs,omitempty"` Logs []*LogEntry `json:"logs,omitempty"`
+61 -21
View File
@@ -1,53 +1,93 @@
module github.com/SecurityBrewery/catalyst module github.com/SecurityBrewery/catalyst
go 1.16 go 1.18
require ( require (
github.com/Azure/go-ansiterm v0.0.0-20210617225240-d185dfc1b5a1 // indirect
github.com/Microsoft/go-winio v0.5.1 // indirect
github.com/alecthomas/kong v0.2.17 github.com/alecthomas/kong v0.2.17
github.com/alecthomas/kong-yaml v0.1.1 github.com/alecthomas/kong-yaml v0.1.1
github.com/antlr/antlr4/runtime/Go/antlr v0.0.0-20211101200231-0802afb9c160 github.com/antlr/antlr4/runtime/Go/antlr v0.0.0-20211101200231-0802afb9c160
github.com/arangodb/go-driver v1.2.1 github.com/arangodb/go-driver v1.2.1
github.com/aws/aws-sdk-go v1.41.19 github.com/aws/aws-sdk-go v1.41.19
github.com/bits-and-blooms/bitset v1.2.1 // indirect
github.com/blevesearch/bleve/v2 v2.2.2 github.com/blevesearch/bleve/v2 v2.2.2
github.com/bmizerany/pat v0.0.0-20210406213842-e4b6760bdd6f // indirect
github.com/containerd/containerd v1.5.8 // indirect
github.com/coreos/go-oidc/v3 v3.1.0 github.com/coreos/go-oidc/v3 v3.1.0
github.com/docker/docker v17.12.0-ce-rc1.0.20201201034508-7d75c1d40d88+incompatible github.com/docker/docker v17.12.0-ce-rc1.0.20201201034508-7d75c1d40d88+incompatible
github.com/docker/go-connections v0.4.0 // indirect
github.com/eclipse/paho.mqtt.golang v1.3.5 // indirect
github.com/emitter-io/go/v2 v2.0.9 github.com/emitter-io/go/v2 v2.0.9
github.com/go-chi/chi v1.5.4 github.com/go-chi/chi v1.5.4
github.com/go-chi/cors v1.2.0 github.com/go-chi/cors v1.2.0
github.com/gobwas/ws v1.1.0 github.com/gobwas/ws v1.1.0
github.com/golang/snappy v0.0.4 // indirect
github.com/google/uuid v1.3.0 github.com/google/uuid v1.3.0
github.com/gorilla/mux v1.8.0 // indirect
github.com/iancoleman/strcase v0.2.0 github.com/iancoleman/strcase v0.2.0
github.com/icza/dyno v0.0.0-20210726202311-f1bafe5d9996 github.com/icza/dyno v0.0.0-20210726202311-f1bafe5d9996
github.com/imdario/mergo v0.3.12 github.com/imdario/mergo v0.3.12
github.com/kr/pretty v0.3.0 // indirect
github.com/mingrammer/commonregex v1.0.1 github.com/mingrammer/commonregex v1.0.1
github.com/morikuni/aec v1.0.0 // indirect
github.com/opencontainers/image-spec v1.0.2 // indirect
github.com/rogpeppe/go-internal v1.8.0 // indirect
github.com/stretchr/testify v1.7.0 github.com/stretchr/testify v1.7.0
github.com/tidwall/gjson v1.12.1 github.com/tidwall/gjson v1.12.1
github.com/tidwall/sjson v1.2.4 github.com/tidwall/sjson v1.2.4
github.com/tus/tusd v1.8.0 github.com/tus/tusd v1.8.0
github.com/xeipuuv/gojsonpointer v0.0.0-20190905194746-02993c407bfb // indirect
github.com/xeipuuv/gojsonschema v1.2.0 github.com/xeipuuv/gojsonschema v1.2.0
go.etcd.io/bbolt v1.3.6 // indirect golang.org/x/exp v0.0.0-20220318154914-8dddf5d87bd8
golang.org/x/crypto v0.0.0-20210921155107-089bfa567519 // indirect
golang.org/x/net v0.0.0-20211105192438-b53810dc28af // indirect
golang.org/x/oauth2 v0.0.0-20211104180415-d3ed0bb246c8 golang.org/x/oauth2 v0.0.0-20211104180415-d3ed0bb246c8
golang.org/x/sys v0.0.0-20211105183446-c75c47738b0c // indirect
golang.org/x/text v0.3.7 // indirect
gopkg.in/square/go-jose.v2 v2.6.0 // indirect
gopkg.in/yaml.v3 v3.0.0-20210107192922-496545a6307b gopkg.in/yaml.v3 v3.0.0-20210107192922-496545a6307b
)
require (
github.com/Azure/go-ansiterm v0.0.0-20210617225240-d185dfc1b5a1 // indirect
github.com/Microsoft/go-winio v0.5.1 // indirect
github.com/RoaringBitmap/roaring v0.9.4 // indirect
github.com/arangodb/go-velocypack v0.0.0-20200318135517-5af53c29c67e // indirect
github.com/bits-and-blooms/bitset v1.2.1 // indirect
github.com/blevesearch/bleve_index_api v1.0.1 // indirect
github.com/blevesearch/go-porterstemmer v1.0.3 // indirect
github.com/blevesearch/mmap-go v1.0.3 // indirect
github.com/blevesearch/scorch_segment_api/v2 v2.1.0 // indirect
github.com/blevesearch/segment v0.9.0 // indirect
github.com/blevesearch/snowballstem v0.9.0 // indirect
github.com/blevesearch/upsidedown_store_api v1.0.1 // indirect
github.com/blevesearch/vellum v1.0.7 // indirect
github.com/blevesearch/zapx/v11 v11.3.1 // indirect
github.com/blevesearch/zapx/v12 v12.3.1 // indirect
github.com/blevesearch/zapx/v13 v13.3.1 // indirect
github.com/blevesearch/zapx/v14 v14.3.1 // indirect
github.com/blevesearch/zapx/v15 v15.3.1 // indirect
github.com/bmizerany/pat v0.0.0-20210406213842-e4b6760bdd6f // indirect
github.com/containerd/containerd v1.5.8 // indirect
github.com/davecgh/go-spew v1.1.1 // indirect
github.com/docker/distribution v2.7.1+incompatible // indirect
github.com/docker/go-connections v0.4.0 // indirect
github.com/docker/go-units v0.4.0 // indirect
github.com/eclipse/paho.mqtt.golang v1.3.5 // indirect
github.com/gobwas/httphead v0.1.0 // indirect
github.com/gobwas/pool v0.2.1 // indirect
github.com/gogo/protobuf v1.3.2 // indirect
github.com/golang/protobuf v1.5.2 // indirect
github.com/golang/snappy v0.0.4 // indirect
github.com/gorilla/mux v1.8.0 // indirect
github.com/gorilla/websocket v1.4.2 // indirect
github.com/jmespath/go-jmespath v0.4.0 // indirect
github.com/kr/pretty v0.3.0 // indirect
github.com/morikuni/aec v1.0.0 // indirect
github.com/mschoch/smat v0.2.0 // indirect
github.com/opencontainers/go-digest v1.0.0 // indirect
github.com/opencontainers/image-spec v1.0.2 // indirect
github.com/pkg/errors v0.9.1 // indirect
github.com/pmezard/go-difflib v1.0.0 // indirect
github.com/rogpeppe/go-internal v1.8.0 // indirect
github.com/sirupsen/logrus v1.8.1 // indirect
github.com/steveyen/gtreap v0.1.0 // indirect
github.com/tidwall/match v1.1.1 // indirect
github.com/tidwall/pretty v1.2.0 // indirect
github.com/xeipuuv/gojsonpointer v0.0.0-20190905194746-02993c407bfb // indirect
github.com/xeipuuv/gojsonreference v0.0.0-20180127040603-bd5ef7bd5415 // indirect
go.etcd.io/bbolt v1.3.6 // indirect
golang.org/x/crypto v0.0.0-20210322153248-0c34fe9e7dc2 // indirect
golang.org/x/net v0.0.0-20211105192438-b53810dc28af // indirect
golang.org/x/sys v0.0.0-20211105183446-c75c47738b0c // indirect
google.golang.org/appengine v1.6.7 // indirect
google.golang.org/genproto v0.0.0-20211021150943-2b146023228c // indirect
google.golang.org/grpc v1.41.0 // indirect
google.golang.org/protobuf v1.27.1 // indirect
gopkg.in/square/go-jose.v2 v2.6.0 // indirect
gopkg.in/yaml.v2 v2.4.0 // indirect
) )
replace github.com/xeipuuv/gojsonschema => github.com/warjiang/gojsonschema v1.2.1-0.20210329105853-aa9f9a8cfec7 replace github.com/xeipuuv/gojsonschema => github.com/warjiang/gojsonschema v1.2.1-0.20210329105853-aa9f9a8cfec7
+4 -6
View File
@@ -276,7 +276,6 @@ github.com/coreos/go-etcd v2.0.0+incompatible/go.mod h1:Jez6KQU2B/sWsbdaef3ED8Nz
github.com/coreos/go-iptables v0.4.3/go.mod h1:/mVI274lEDI2ns62jHCDnCyBF9Iwsmekav8Dbxlm1MU= github.com/coreos/go-iptables v0.4.3/go.mod h1:/mVI274lEDI2ns62jHCDnCyBF9Iwsmekav8Dbxlm1MU=
github.com/coreos/go-iptables v0.4.5/go.mod h1:/mVI274lEDI2ns62jHCDnCyBF9Iwsmekav8Dbxlm1MU= github.com/coreos/go-iptables v0.4.5/go.mod h1:/mVI274lEDI2ns62jHCDnCyBF9Iwsmekav8Dbxlm1MU=
github.com/coreos/go-iptables v0.5.0/go.mod h1:/mVI274lEDI2ns62jHCDnCyBF9Iwsmekav8Dbxlm1MU= github.com/coreos/go-iptables v0.5.0/go.mod h1:/mVI274lEDI2ns62jHCDnCyBF9Iwsmekav8Dbxlm1MU=
github.com/coreos/go-oidc v2.1.0+incompatible h1:sdJrfw8akMnCuUlaZU3tE/uYXFgfqom8DBE9so9EBsM=
github.com/coreos/go-oidc v2.1.0+incompatible/go.mod h1:CgnwVTmzoESiwO9qyAFEMiHoZ1nMCKZlZ9V6mm3/LKc= github.com/coreos/go-oidc v2.1.0+incompatible/go.mod h1:CgnwVTmzoESiwO9qyAFEMiHoZ1nMCKZlZ9V6mm3/LKc=
github.com/coreos/go-oidc/v3 v3.1.0 h1:6avEvcdvTa1qYsOZ6I5PRkSYHzpTNWgKYmaJfaYbrRw= github.com/coreos/go-oidc/v3 v3.1.0 h1:6avEvcdvTa1qYsOZ6I5PRkSYHzpTNWgKYmaJfaYbrRw=
github.com/coreos/go-oidc/v3 v3.1.0/go.mod h1:rEJ/idjfUyfkBit1eI1fvyr+64/g9dcKpAm8MJMesvo= github.com/coreos/go-oidc/v3 v3.1.0/go.mod h1:rEJ/idjfUyfkBit1eI1fvyr+64/g9dcKpAm8MJMesvo=
@@ -833,9 +832,8 @@ golang.org/x/crypto v0.0.0-20191011191535-87dc89f01550/go.mod h1:yigFU9vqHzYiE8U
golang.org/x/crypto v0.0.0-20200622213623-75b288015ac9/go.mod h1:LzIPMQfyMNhhGPhUkYOs5KpL4U8rLKemX1yGLhDgUto= golang.org/x/crypto v0.0.0-20200622213623-75b288015ac9/go.mod h1:LzIPMQfyMNhhGPhUkYOs5KpL4U8rLKemX1yGLhDgUto=
golang.org/x/crypto v0.0.0-20200728195943-123391ffb6de/go.mod h1:LzIPMQfyMNhhGPhUkYOs5KpL4U8rLKemX1yGLhDgUto= golang.org/x/crypto v0.0.0-20200728195943-123391ffb6de/go.mod h1:LzIPMQfyMNhhGPhUkYOs5KpL4U8rLKemX1yGLhDgUto=
golang.org/x/crypto v0.0.0-20201002170205-7f63de1d35b0/go.mod h1:LzIPMQfyMNhhGPhUkYOs5KpL4U8rLKemX1yGLhDgUto= golang.org/x/crypto v0.0.0-20201002170205-7f63de1d35b0/go.mod h1:LzIPMQfyMNhhGPhUkYOs5KpL4U8rLKemX1yGLhDgUto=
golang.org/x/crypto v0.0.0-20210322153248-0c34fe9e7dc2 h1:It14KIkyBFYkHkwZ7k45minvA9aorojkyjGk9KJ5B/w=
golang.org/x/crypto v0.0.0-20210322153248-0c34fe9e7dc2/go.mod h1:T9bdIzuCu7OtxOm1hfPfRQxPLYneinmdGuTeoZ9dtd4= golang.org/x/crypto v0.0.0-20210322153248-0c34fe9e7dc2/go.mod h1:T9bdIzuCu7OtxOm1hfPfRQxPLYneinmdGuTeoZ9dtd4=
golang.org/x/crypto v0.0.0-20210921155107-089bfa567519 h1:7I4JAnoQBe7ZtJcBaYHi5UtiO8tQHbUSXxL+pnGRANg=
golang.org/x/crypto v0.0.0-20210921155107-089bfa567519/go.mod h1:GvvjBRRGRdwPK5ydBHafDWAxML/pGHZbMvKqRZ5+Abc=
golang.org/x/exp v0.0.0-20190121172915-509febef88a4/go.mod h1:CJ0aWSM057203Lf6IL+f9T1iT9GByDxfZKAQTCR3kQA= golang.org/x/exp v0.0.0-20190121172915-509febef88a4/go.mod h1:CJ0aWSM057203Lf6IL+f9T1iT9GByDxfZKAQTCR3kQA=
golang.org/x/exp v0.0.0-20190306152737-a1d7652674e8/go.mod h1:CJ0aWSM057203Lf6IL+f9T1iT9GByDxfZKAQTCR3kQA= golang.org/x/exp v0.0.0-20190306152737-a1d7652674e8/go.mod h1:CJ0aWSM057203Lf6IL+f9T1iT9GByDxfZKAQTCR3kQA=
golang.org/x/exp v0.0.0-20190510132918-efd6b22b2522/go.mod h1:ZjyILWgesfNpC6sMxTJOJm9Kp84zZh5NQWvqDGG3Qr8= golang.org/x/exp v0.0.0-20190510132918-efd6b22b2522/go.mod h1:ZjyILWgesfNpC6sMxTJOJm9Kp84zZh5NQWvqDGG3Qr8=
@@ -846,6 +844,8 @@ golang.org/x/exp v0.0.0-20191227195350-da58074b4299/go.mod h1:2RIsYlXP63K8oxa1u0
golang.org/x/exp v0.0.0-20200119233911-0405dc783f0a/go.mod h1:2RIsYlXP63K8oxa1u096TMicItID8zy7Y6sNkU49FU4= golang.org/x/exp v0.0.0-20200119233911-0405dc783f0a/go.mod h1:2RIsYlXP63K8oxa1u096TMicItID8zy7Y6sNkU49FU4=
golang.org/x/exp v0.0.0-20200207192155-f17229e696bd/go.mod h1:J/WKrq2StrnmMY6+EHIKF9dgMWnmCNThgcyBT1FY9mM= golang.org/x/exp v0.0.0-20200207192155-f17229e696bd/go.mod h1:J/WKrq2StrnmMY6+EHIKF9dgMWnmCNThgcyBT1FY9mM=
golang.org/x/exp v0.0.0-20200224162631-6cc2880d07d6/go.mod h1:3jZMyOhIsHpP37uCMkUooju7aAi5cS1Q23tOzKc+0MU= golang.org/x/exp v0.0.0-20200224162631-6cc2880d07d6/go.mod h1:3jZMyOhIsHpP37uCMkUooju7aAi5cS1Q23tOzKc+0MU=
golang.org/x/exp v0.0.0-20220318154914-8dddf5d87bd8 h1:s/+U+w0teGzcoH2mdIlFQ6KfVKGaYpgyGdUefZrn9TU=
golang.org/x/exp v0.0.0-20220318154914-8dddf5d87bd8/go.mod h1:lgLbSvA5ygNOMpwM/9anMpWVlVJ7Z+cHWq/eFuinpGE=
golang.org/x/image v0.0.0-20190227222117-0694c2d4d067/go.mod h1:kZ7UVZpmo3dzQBMxlp+ypCbDeSB+sBbTgSJuh5dn5js= golang.org/x/image v0.0.0-20190227222117-0694c2d4d067/go.mod h1:kZ7UVZpmo3dzQBMxlp+ypCbDeSB+sBbTgSJuh5dn5js=
golang.org/x/image v0.0.0-20190802002840-cff245a6509b/go.mod h1:FeLwcggjj3mMvU+oOTbSwawSJRM1uh48EjtB4UJZlP0= golang.org/x/image v0.0.0-20190802002840-cff245a6509b/go.mod h1:FeLwcggjj3mMvU+oOTbSwawSJRM1uh48EjtB4UJZlP0=
golang.org/x/lint v0.0.0-20181026193005-c67002cb31c3/go.mod h1:UVdnD1Gm6xHRNCYTkRU2/jEulfH38KcIWyp/GAMgvoE= golang.org/x/lint v0.0.0-20181026193005-c67002cb31c3/go.mod h1:UVdnD1Gm6xHRNCYTkRU2/jEulfH38KcIWyp/GAMgvoE=
@@ -1040,7 +1040,6 @@ golang.org/x/sys v0.0.0-20210510120138-977fb7262007/go.mod h1:oPkhp1MJrh7nUepCBc
golang.org/x/sys v0.0.0-20210514084401-e8d321eab015/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.0.0-20210514084401-e8d321eab015/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/sys v0.0.0-20210603081109-ebe580a85c40/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.0.0-20210603081109-ebe580a85c40/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/sys v0.0.0-20210603125802-9665404d3644/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.0.0-20210603125802-9665404d3644/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/sys v0.0.0-20210615035016-665e8c7367d1/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/sys v0.0.0-20210616094352-59db8d763f22/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.0.0-20210616094352-59db8d763f22/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/sys v0.0.0-20210630005230-0f9fa26af87c/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.0.0-20210630005230-0f9fa26af87c/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/sys v0.0.0-20210806184541-e5e7981a1069/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.0.0-20210806184541-e5e7981a1069/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
@@ -1058,9 +1057,8 @@ golang.org/x/text v0.3.2/go.mod h1:bEr9sfX3Q8Zfm5fL9x+3itogRgK3+ptLWKqgva+5dAk=
golang.org/x/text v0.3.3/go.mod h1:5Zoc/QRtKVWzQhOtBMvqHzDpF6irO9z98xDceosuGiQ= golang.org/x/text v0.3.3/go.mod h1:5Zoc/QRtKVWzQhOtBMvqHzDpF6irO9z98xDceosuGiQ=
golang.org/x/text v0.3.4/go.mod h1:5Zoc/QRtKVWzQhOtBMvqHzDpF6irO9z98xDceosuGiQ= golang.org/x/text v0.3.4/go.mod h1:5Zoc/QRtKVWzQhOtBMvqHzDpF6irO9z98xDceosuGiQ=
golang.org/x/text v0.3.5/go.mod h1:5Zoc/QRtKVWzQhOtBMvqHzDpF6irO9z98xDceosuGiQ= golang.org/x/text v0.3.5/go.mod h1:5Zoc/QRtKVWzQhOtBMvqHzDpF6irO9z98xDceosuGiQ=
golang.org/x/text v0.3.6 h1:aRYxNxv6iGQlyVaZmk6ZgYEDa+Jg18DxebPSrd6bg1M=
golang.org/x/text v0.3.6/go.mod h1:5Zoc/QRtKVWzQhOtBMvqHzDpF6irO9z98xDceosuGiQ= golang.org/x/text v0.3.6/go.mod h1:5Zoc/QRtKVWzQhOtBMvqHzDpF6irO9z98xDceosuGiQ=
golang.org/x/text v0.3.7 h1:olpwvP2KacW1ZWvsR7uQhoyTYvKAupfQrRGBFM352Gk=
golang.org/x/text v0.3.7/go.mod h1:u+2+/6zg+i71rQMx5EYifcz6MCKuco9NR6JIITiCfzQ=
golang.org/x/time v0.0.0-20180412165947-fbb02b2291d2/go.mod h1:tRJNPiyCQ0inRvYxbN9jk5I+vvW/OXSQhTDSoE431IQ= golang.org/x/time v0.0.0-20180412165947-fbb02b2291d2/go.mod h1:tRJNPiyCQ0inRvYxbN9jk5I+vvW/OXSQhTDSoE431IQ=
golang.org/x/time v0.0.0-20181108054448-85acf8d2951c/go.mod h1:tRJNPiyCQ0inRvYxbN9jk5I+vvW/OXSQhTDSoE431IQ= golang.org/x/time v0.0.0-20181108054448-85acf8d2951c/go.mod h1:tRJNPiyCQ0inRvYxbN9jk5I+vvW/OXSQhTDSoE431IQ=
golang.org/x/time v0.0.0-20190308202827-9d24e82272b4/go.mod h1:tRJNPiyCQ0inRvYxbN9jk5I+vvW/OXSQhTDSoE431IQ= golang.org/x/time v0.0.0-20190308202827-9d24e82272b4/go.mod h1:tRJNPiyCQ0inRvYxbN9jk5I+vvW/OXSQhTDSoE431IQ=
+8 -4
View File
@@ -11,8 +11,8 @@ import (
type Hooks struct { type Hooks struct {
DatabaseAfterConnectFuncs []func(ctx context.Context, client driver.Client, name string) DatabaseAfterConnectFuncs []func(ctx context.Context, client driver.Client, name string)
IngestionFilterFunc func(ctx context.Context, index *index.Index) (string, error) IngestionFilterFunc func(ctx context.Context, index *index.Index) (string, error)
TicketReadFilterFunc func(ctx context.Context) (string, map[string]interface{}, error) TicketReadFilterFunc func(ctx context.Context) (string, map[string]any, error)
TicketWriteFilterFunc func(ctx context.Context) (string, map[string]interface{}, error) TicketWriteFilterFunc func(ctx context.Context) (string, map[string]any, error)
GetGroupsFunc func(ctx context.Context, username string) ([]string, error) GetGroupsFunc func(ctx context.Context, username string) ([]string, error)
} }
@@ -26,20 +26,23 @@ func (h *Hooks) IngestionFilter(ctx context.Context, index *index.Index) (string
if h.IngestionFilterFunc != nil { if h.IngestionFilterFunc != nil {
return h.IngestionFilterFunc(ctx, index) return h.IngestionFilterFunc(ctx, index)
} }
return "[]", nil return "[]", nil
} }
func (h *Hooks) TicketReadFilter(ctx context.Context) (string, map[string]interface{}, error) { func (h *Hooks) TicketReadFilter(ctx context.Context) (string, map[string]any, error) {
if h.TicketReadFilterFunc != nil { if h.TicketReadFilterFunc != nil {
return h.TicketReadFilterFunc(ctx) return h.TicketReadFilterFunc(ctx)
} }
return "", nil, nil return "", nil, nil
} }
func (h *Hooks) TicketWriteFilter(ctx context.Context) (string, map[string]interface{}, error) { func (h *Hooks) TicketWriteFilter(ctx context.Context) (string, map[string]any, error) {
if h.TicketWriteFilterFunc != nil { if h.TicketWriteFilterFunc != nil {
return h.TicketWriteFilterFunc(ctx) return h.TicketWriteFilterFunc(ctx)
} }
return "", nil, nil return "", nil, nil
} }
@@ -47,5 +50,6 @@ func (h *Hooks) GetGroups(ctx context.Context, username string) ([]string, error
if h.GetGroupsFunc != nil { if h.GetGroupsFunc != nil {
return h.GetGroupsFunc(ctx, username) return h.GetGroupsFunc(ctx, username)
} }
return nil, nil return nil, nil
} }
+5 -2
View File
@@ -36,6 +36,7 @@ func (i *Index) Index(incidents []*model.TicketSimpleResponse) {
for _, incident := range incidents { for _, incident := range incidents {
if incident.ID == 0 { if incident.ID == 0 {
log.Println(errors.New("no ID"), incident) log.Println(errors.New("no ID"), incident)
continue continue
} }
@@ -44,8 +45,8 @@ func (i *Index) Index(incidents []*model.TicketSimpleResponse) {
log.Println(err) log.Println(err)
} }
} }
err := i.internal.Batch(b)
if err != nil { if err := i.internal.Batch(b); err != nil {
log.Println(err) log.Println(err)
} }
} }
@@ -59,6 +60,7 @@ func (i *Index) Search(term string) (ids []string, err error) {
for _, match := range result.Hits { for _, match := range result.Hits {
ids = append(ids, match.ID) ids = append(ids, match.ID)
} }
return ids, nil return ids, nil
} }
@@ -76,6 +78,7 @@ func (i *Index) Truncate() error {
return err return err
} }
i.internal = index i.internal = index
return nil return nil
} }
+11
View File
@@ -9,6 +9,8 @@ import (
) )
func TestIndex(t *testing.T) { func TestIndex(t *testing.T) {
t.Parallel()
type args struct { type args struct {
term string term string
} }
@@ -22,7 +24,10 @@ func TestIndex(t *testing.T) {
{name: "Not exists", args: args{"bar"}}, {name: "Not exists", args: args{"bar"}},
} }
for _, tt := range tests { for _, tt := range tests {
tt := tt
t.Run(tt.name, func(t *testing.T) { t.Run(tt.name, func(t *testing.T) {
t.Parallel()
i, cleanup, err := test.Index(t) i, cleanup, err := test.Index(t)
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
@@ -37,6 +42,7 @@ func TestIndex(t *testing.T) {
gotIds, err := i.Search(tt.args.term) gotIds, err := i.Search(tt.args.term)
if (err != nil) != tt.wantErr { if (err != nil) != tt.wantErr {
t.Errorf("Search() error = %v, wantErr %v", err, tt.wantErr) t.Errorf("Search() error = %v, wantErr %v", err, tt.wantErr)
return return
} }
if !reflect.DeepEqual(gotIds, tt.wantIds) { if !reflect.DeepEqual(gotIds, tt.wantIds) {
@@ -47,6 +53,8 @@ func TestIndex(t *testing.T) {
} }
func TestIndex_Truncate(t *testing.T) { func TestIndex_Truncate(t *testing.T) {
t.Parallel()
tests := []struct { tests := []struct {
name string name string
wantErr bool wantErr bool
@@ -54,7 +62,10 @@ func TestIndex_Truncate(t *testing.T) {
{name: "Truncate"}, {name: "Truncate"},
} }
for _, tt := range tests { for _, tt := range tests {
tt := tt
t.Run(tt.name, func(t *testing.T) { t.Run(tt.name, func(t *testing.T) {
t.Parallel()
i, cleanup, err := test.Index(t) i, cleanup, err := test.Index(t)
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
+13 -5
View File
@@ -29,11 +29,13 @@ func restoreHandler(catalystStorage *storage.Storage, db *database.Database, c *
uf, header, err := r.FormFile("backup") uf, header, err := r.FormFile("backup")
if err != nil { if err != nil {
api.JSONError(w, err) api.JSONError(w, err)
return return
} }
if err = Restore(r.Context(), catalystStorage, db, c, uf, header.Size); err != nil { if err = Restore(r.Context(), catalystStorage, db, c, uf, header.Size); err != nil {
api.JSONError(w, err) api.JSONError(w, err)
return return
} }
} }
@@ -52,7 +54,7 @@ func Restore(ctx context.Context, catalystStorage *storage.Storage, db *database
} }
if fsys.Comment != GetVersion() { if fsys.Comment != GetVersion() {
return errors.New(fmt.Sprintf("wrong version, got: %s, want: %s", fsys.Comment, GetVersion())) return fmt.Errorf("wrong version, got: %s, want: %s", fsys.Comment, GetVersion())
} }
dir, err := os.MkdirTemp("", "catalyst-restore") dir, err := os.MkdirTemp("", "catalyst-restore")
@@ -89,17 +91,19 @@ func restoreS3(catalystStorage *storage.Storage, p string) error {
return err return err
} }
} }
return nil return nil
} }
func restoreBucket(catalystStorage *storage.Storage, entry fs.DirEntry, minioDir fs.FS) error { func restoreBucket(catalystStorage *storage.Storage, entry fs.DirEntry, minioDir fs.FS) error {
_, err := catalystStorage.S3().CreateBucket(&s3.CreateBucketInput{Bucket: pointer.String(entry.Name())}) _, err := catalystStorage.S3().CreateBucket(&s3.CreateBucketInput{Bucket: pointer.String(entry.Name())})
if err != nil { if err != nil {
awsError, ok := err.(awserr.Error) var awsError awserr.Error
if !ok || (awsError.Code() != s3.ErrCodeBucketAlreadyExists && awsError.Code() != s3.ErrCodeBucketAlreadyOwnedByYou) { if errors.As(err, &awsError) && (awsError.Code() == s3.ErrCodeBucketAlreadyExists || awsError.Code() == s3.ErrCodeBucketAlreadyOwnedByYou) {
return err return nil
} }
return nil
return err
} }
uploader := catalystStorage.Uploader() uploader := catalystStorage.Uploader()
@@ -115,11 +119,13 @@ func restoreBucket(catalystStorage *storage.Storage, entry fs.DirEntry, minioDir
return nil return nil
} }
_, err = uploader.Upload(&s3manager.UploadInput{Body: f, Bucket: pointer.String(entry.Name()), Key: pointer.String(path)}) _, err = uploader.Upload(&s3manager.UploadInput{Body: f, Bucket: pointer.String(entry.Name()), Key: pointer.String(path)})
return err return err
}) })
if err != nil { if err != nil {
return err return err
} }
return nil return nil
} }
@@ -131,6 +137,7 @@ func unzip(archive *zip.Reader, dir string) error {
if d.IsDir() { if d.IsDir() {
_ = os.MkdirAll(path.Join(dir, p), os.ModePerm) _ = os.MkdirAll(path.Join(dir, p), os.ModePerm)
return nil return nil
} }
@@ -163,5 +170,6 @@ func arangorestore(dir string, config *database.Config) error {
"--server.database", name, "--server.database", name,
} }
cmd := exec.Command("arangorestore", args...) cmd := exec.Command("arangorestore", args...)
return cmd.Run() return cmd.Run()
} }
+11 -23
View File
@@ -5,6 +5,8 @@ import (
"sort" "sort"
"strings" "strings"
"golang.org/x/exp/slices"
"github.com/SecurityBrewery/catalyst/generated/model" "github.com/SecurityBrewery/catalyst/generated/model"
) )
@@ -60,23 +62,16 @@ func UserHasRoles(user *model.UserResponse, roles []Role) bool {
for _, role := range roles { for _, role := range roles {
if !UserHasRole(user, role) { if !UserHasRole(user, role) {
hasRoles = false hasRoles = false
break break
} }
} }
return hasRoles return hasRoles
} }
func UserHasRole(user *model.UserResponse, role Role) bool { func UserHasRole(user *model.UserResponse, role Role) bool {
return ContainsRole(FromStrings(user.Roles), role) return slices.Contains(FromStrings(user.Roles), role)
}
func ContainsRole(roles []Role, role Role) bool {
for _, r := range roles {
if r.String() == role.String() { // || strings.HasPrefix(role.String(), r.String()+":")
return true
}
}
return false
} }
func Explodes(s []string) []Role { func Explodes(s []string) []Role {
@@ -84,10 +79,10 @@ func Explodes(s []string) []Role {
for _, e := range s { for _, e := range s {
roles = append(roles, Explode(e)...) roles = append(roles, Explode(e)...)
} }
roles = unique(roles)
sort.Slice(roles, func(i, j int) bool { sort.Slice(roles, func(i, j int) bool {
return roles[i].String() < roles[j].String() return roles[i].String() < roles[j].String()
}) })
roles = slices.Compact(roles)
return roles return roles
} }
@@ -98,12 +93,15 @@ func Explode(s string) []Role {
switch s { switch s {
case Admin: case Admin:
roles = append(roles, listPrefix(Admin)...) roles = append(roles, listPrefix(Admin)...)
fallthrough fallthrough
case Engineer: case Engineer:
roles = append(roles, listPrefix(Engineer)...) roles = append(roles, listPrefix(Engineer)...)
fallthrough fallthrough
case Analyst: case Analyst:
roles = append(roles, listPrefix(Analyst)...) roles = append(roles, listPrefix(Analyst)...)
return roles return roles
} }
@@ -128,18 +126,6 @@ func listPrefix(s string) []Role {
return roles return roles
} }
func unique(l []Role) []Role {
keys := make(map[Role]bool)
var list []Role
for _, entry := range l {
if _, value := keys[entry]; !value {
keys[entry] = true
list = append(list, entry)
}
}
return list
}
func List() []Role { func List() []Role {
return []Role{ return []Role{
AutomationRead, CurrentuserdataRead, CurrentuserdataWrite, AutomationRead, CurrentuserdataRead, CurrentuserdataWrite,
@@ -167,6 +153,7 @@ func Strings(roles []Role) []string {
for _, role := range roles { for _, role := range roles {
s = append(s, role.String()) s = append(s, role.String())
} }
return s return s
} }
@@ -179,5 +166,6 @@ func FromStrings(s []string) []Role {
} }
roles = append(roles, role) roles = append(roles, role)
} }
return roles return roles
} }
+4
View File
@@ -14,6 +14,7 @@ func automationResponseID(automation *model.AutomationResponse) []driver.Documen
if automation == nil { if automation == nil {
return nil return nil
} }
return automationID(automation.ID) return automationID(automation.ID)
} }
@@ -27,6 +28,7 @@ func (s *Service) ListAutomations(ctx context.Context) ([]*model.AutomationRespo
func (s *Service) CreateAutomation(ctx context.Context, form *model.AutomationForm) (doc *model.AutomationResponse, err error) { func (s *Service) CreateAutomation(ctx context.Context, form *model.AutomationForm) (doc *model.AutomationResponse, err error) {
defer s.publishRequest(ctx, err, "CreateAutomation", automationResponseID(doc)) defer s.publishRequest(ctx, err, "CreateAutomation", automationResponseID(doc))
return s.database.AutomationCreate(ctx, form) return s.database.AutomationCreate(ctx, form)
} }
@@ -36,10 +38,12 @@ func (s *Service) GetAutomation(ctx context.Context, id string) (*model.Automati
func (s *Service) UpdateAutomation(ctx context.Context, id string, form *model.AutomationForm) (doc *model.AutomationResponse, err error) { func (s *Service) UpdateAutomation(ctx context.Context, id string, form *model.AutomationForm) (doc *model.AutomationResponse, err error) {
defer s.publishRequest(ctx, err, "UpdateAutomation", automationResponseID(doc)) defer s.publishRequest(ctx, err, "UpdateAutomation", automationResponseID(doc))
return s.database.AutomationUpdate(ctx, id, form) return s.database.AutomationUpdate(ctx, id, form)
} }
func (s *Service) DeleteAutomation(ctx context.Context, id string) (err error) { func (s *Service) DeleteAutomation(ctx context.Context, id string) (err error) {
defer s.publishRequest(ctx, err, "DeleteAutomation", automationID(id)) defer s.publishRequest(ctx, err, "DeleteAutomation", automationID(id))
return s.database.AutomationDelete(ctx, id) return s.database.AutomationDelete(ctx, id)
} }
+5 -1
View File
@@ -14,6 +14,7 @@ func dashboardResponseID(doc *model.DashboardResponse) []driver.DocumentID {
if doc == nil { if doc == nil {
return nil return nil
} }
return templateID(doc.ID) return templateID(doc.ID)
} }
@@ -27,6 +28,7 @@ func (s *Service) ListDashboards(ctx context.Context) ([]*model.DashboardRespons
func (s *Service) CreateDashboard(ctx context.Context, dashboard *model.Dashboard) (doc *model.DashboardResponse, err error) { func (s *Service) CreateDashboard(ctx context.Context, dashboard *model.Dashboard) (doc *model.DashboardResponse, err error) {
defer s.publishRequest(ctx, err, "CreateDashboard", dashboardResponseID(doc)) defer s.publishRequest(ctx, err, "CreateDashboard", dashboardResponseID(doc))
return s.database.DashboardCreate(ctx, dashboard) return s.database.DashboardCreate(ctx, dashboard)
} }
@@ -36,14 +38,16 @@ func (s *Service) GetDashboard(ctx context.Context, id string) (*model.Dashboard
func (s *Service) UpdateDashboard(ctx context.Context, id string, form *model.Dashboard) (doc *model.DashboardResponse, err error) { func (s *Service) UpdateDashboard(ctx context.Context, id string, form *model.Dashboard) (doc *model.DashboardResponse, err error) {
defer s.publishRequest(ctx, err, "UpdateDashboard", dashboardResponseID(doc)) defer s.publishRequest(ctx, err, "UpdateDashboard", dashboardResponseID(doc))
return s.database.DashboardUpdate(ctx, id, form) return s.database.DashboardUpdate(ctx, id, form)
} }
func (s *Service) DeleteDashboard(ctx context.Context, id string) (err error) { func (s *Service) DeleteDashboard(ctx context.Context, id string) (err error) {
defer s.publishRequest(ctx, err, "DeleteDashboard", dashboardID(id)) defer s.publishRequest(ctx, err, "DeleteDashboard", dashboardID(id))
return s.database.DashboardDelete(ctx, id) return s.database.DashboardDelete(ctx, id)
} }
func (s *Service) DashboardData(ctx context.Context, aggregation string, filter *string) (map[string]interface{}, error) { func (s *Service) DashboardData(ctx context.Context, aggregation string, filter *string) (map[string]any, error) {
return s.database.WidgetData(ctx, aggregation, filter) return s.database.WidgetData(ctx, aggregation, filter)
} }
+2
View File
@@ -15,6 +15,7 @@ func jobResponseID(job *model.JobResponse) []driver.DocumentID {
if job == nil { if job == nil {
return nil return nil
} }
return jobID(job.ID) return jobID(job.ID)
} }
@@ -48,5 +49,6 @@ func (s *Service) GetJob(ctx context.Context, id string) (*model.JobResponse, er
func (s *Service) UpdateJob(ctx context.Context, id string, job *model.JobUpdate) (doc *model.JobResponse, err error) { func (s *Service) UpdateJob(ctx context.Context, id string, job *model.JobUpdate) (doc *model.JobResponse, err error) {
defer s.publishRequest(ctx, err, "UpdateJob", jobResponseID(doc)) defer s.publishRequest(ctx, err, "UpdateJob", jobResponseID(doc))
return s.database.JobUpdate(ctx, id, job) return s.database.JobUpdate(ctx, id, job)
} }
+1
View File
@@ -9,5 +9,6 @@ import (
func (s *Service) GetLogs(ctx context.Context, reference string) ([]*model.LogEntry, error) { func (s *Service) GetLogs(ctx context.Context, reference string) ([]*model.LogEntry, error) {
id, _ := url.QueryUnescape(reference) id, _ := url.QueryUnescape(reference)
return s.database.LogList(ctx, id) return s.database.LogList(ctx, id)
} }
+4
View File
@@ -14,6 +14,7 @@ func playbookResponseID(playbook *model.PlaybookTemplateResponse) []driver.Docum
if playbook == nil { if playbook == nil {
return nil return nil
} }
return playbookID(playbook.ID) return playbookID(playbook.ID)
} }
@@ -27,6 +28,7 @@ func (s *Service) ListPlaybooks(ctx context.Context) ([]*model.PlaybookTemplateR
func (s *Service) CreatePlaybook(ctx context.Context, form *model.PlaybookTemplateForm) (doc *model.PlaybookTemplateResponse, err error) { func (s *Service) CreatePlaybook(ctx context.Context, form *model.PlaybookTemplateForm) (doc *model.PlaybookTemplateResponse, err error) {
defer s.publishRequest(ctx, err, "CreatePlaybook", playbookResponseID(doc)) defer s.publishRequest(ctx, err, "CreatePlaybook", playbookResponseID(doc))
return s.database.PlaybookCreate(ctx, form) return s.database.PlaybookCreate(ctx, form)
} }
@@ -36,10 +38,12 @@ func (s *Service) GetPlaybook(ctx context.Context, id string) (*model.PlaybookTe
func (s *Service) UpdatePlaybook(ctx context.Context, id string, form *model.PlaybookTemplateForm) (doc *model.PlaybookTemplateResponse, err error) { func (s *Service) UpdatePlaybook(ctx context.Context, id string, form *model.PlaybookTemplateForm) (doc *model.PlaybookTemplateResponse, err error) {
defer s.publishRequest(ctx, err, "UpdatePlaybook", playbookResponseID(doc)) defer s.publishRequest(ctx, err, "UpdatePlaybook", playbookResponseID(doc))
return s.database.PlaybookUpdate(ctx, id, form) return s.database.PlaybookUpdate(ctx, id, form)
} }
func (s *Service) DeletePlaybook(ctx context.Context, id string) (err error) { func (s *Service) DeletePlaybook(ctx context.Context, id string) (err error) {
defer s.publishRequest(ctx, err, "DeletePlaybook", playbookID(id)) defer s.publishRequest(ctx, err, "DeletePlaybook", playbookID(id))
return s.database.PlaybookDelete(ctx, id) return s.database.PlaybookDelete(ctx, id)
} }
+6 -1
View File
@@ -2,6 +2,7 @@ package service
import ( import (
"context" "context"
"log"
"github.com/arangodb/go-driver" "github.com/arangodb/go-driver"
@@ -33,6 +34,10 @@ func (s *Service) publishRequest(ctx context.Context, err error, function string
userID = user.ID userID = user.ID
} }
go s.bus.PublishRequest(userID, function, ids) go func() {
if err := s.bus.PublishRequest(userID, function, ids); err != nil {
log.Println(err)
}
}()
} }
} }
+4
View File
@@ -14,6 +14,7 @@ func templateResponseID(template *model.TicketTemplateResponse) []driver.Documen
if template == nil { if template == nil {
return nil return nil
} }
return templateID(template.ID) return templateID(template.ID)
} }
@@ -27,6 +28,7 @@ func (s *Service) ListTemplates(ctx context.Context) ([]*model.TicketTemplateRes
func (s *Service) CreateTemplate(ctx context.Context, form *model.TicketTemplateForm) (doc *model.TicketTemplateResponse, err error) { func (s *Service) CreateTemplate(ctx context.Context, form *model.TicketTemplateForm) (doc *model.TicketTemplateResponse, err error) {
defer s.publishRequest(ctx, err, "CreateTemplate", templateResponseID(doc)) defer s.publishRequest(ctx, err, "CreateTemplate", templateResponseID(doc))
return s.database.TemplateCreate(ctx, form) return s.database.TemplateCreate(ctx, form)
} }
@@ -36,10 +38,12 @@ func (s *Service) GetTemplate(ctx context.Context, id string) (*model.TicketTemp
func (s *Service) UpdateTemplate(ctx context.Context, id string, form *model.TicketTemplateForm) (doc *model.TicketTemplateResponse, err error) { func (s *Service) UpdateTemplate(ctx context.Context, id string, form *model.TicketTemplateForm) (doc *model.TicketTemplateResponse, err error) {
defer s.publishRequest(ctx, err, "UpdateTemplate", templateResponseID(doc)) defer s.publishRequest(ctx, err, "UpdateTemplate", templateResponseID(doc))
return s.database.TemplateUpdate(ctx, id, form) return s.database.TemplateUpdate(ctx, id, form)
} }
func (s *Service) DeleteTemplate(ctx context.Context, id string) (err error) { func (s *Service) DeleteTemplate(ctx context.Context, id string) (err error) {
defer s.publishRequest(ctx, err, "DeleteTemplate", templateID(id)) defer s.publishRequest(ctx, err, "DeleteTemplate", templateID(id))
return s.database.TemplateDelete(ctx, id) return s.database.TemplateDelete(ctx, id)
} }
+24 -2
View File
@@ -18,11 +18,13 @@ func ticketWithTicketsID(ticketResponse *model.TicketWithTickets) []driver.Docum
if ticketResponse == nil { if ticketResponse == nil {
return nil return nil
} }
return ticketID(ticketResponse.ID) return ticketID(ticketResponse.ID)
} }
func ticketID(ticketID int64) []driver.DocumentID { func ticketID(ticketID int64) []driver.DocumentID {
id := fmt.Sprintf("%s/%d", database.TicketCollectionName, ticketID) id := fmt.Sprintf("%s/%d", database.TicketCollectionName, ticketID)
return []driver.DocumentID{driver.DocumentID(id)} return []driver.DocumentID{driver.DocumentID(id)}
} }
@@ -31,6 +33,7 @@ func ticketIDs(ticketResponses []*model.TicketResponse) []driver.DocumentID {
for _, ticketResponse := range ticketResponses { for _, ticketResponse := range ticketResponses {
ids = append(ids, ticketID(ticketResponse.ID)...) ids = append(ids, ticketID(ticketResponse.ID)...)
} }
return ids return ids
} }
@@ -63,6 +66,7 @@ func (s *Service) CreateTicket(ctx context.Context, form *model.TicketForm) (doc
if len(createdTickets) > 0 { if len(createdTickets) > 0 {
return createdTickets[0], err return createdTickets[0], err
} }
return nil, err return nil, err
} }
@@ -72,6 +76,7 @@ func (s *Service) CreateTicketBatch(ctx context.Context, ticketFormArray *model.
} }
createdTickets, err := s.database.TicketBatchCreate(ctx, *ticketFormArray) createdTickets, err := s.database.TicketBatchCreate(ctx, *ticketFormArray)
defer s.publishRequest(ctx, err, "CreateTicket", ticketIDs(createdTickets)) defer s.publishRequest(ctx, err, "CreateTicket", ticketIDs(createdTickets))
return err return err
} }
@@ -81,16 +86,19 @@ func (s *Service) GetTicket(ctx context.Context, i int64) (*model.TicketWithTick
func (s *Service) UpdateTicket(ctx context.Context, i int64, ticket *model.Ticket) (doc *model.TicketWithTickets, err error) { func (s *Service) UpdateTicket(ctx context.Context, i int64, ticket *model.Ticket) (doc *model.TicketWithTickets, err error) {
defer s.publishRequest(ctx, err, "UpdateTicket", ticketWithTicketsID(doc)) defer s.publishRequest(ctx, err, "UpdateTicket", ticketWithTicketsID(doc))
return s.database.TicketUpdate(ctx, i, ticket) return s.database.TicketUpdate(ctx, i, ticket)
} }
func (s *Service) DeleteTicket(ctx context.Context, i int64) (err error) { func (s *Service) DeleteTicket(ctx context.Context, i int64) (err error) {
defer s.publishRequest(ctx, err, "DeleteTicket", ticketID(i)) defer s.publishRequest(ctx, err, "DeleteTicket", ticketID(i))
return s.database.TicketDelete(ctx, i) return s.database.TicketDelete(ctx, i)
} }
func (s *Service) AddArtifact(ctx context.Context, i int64, artifact *model.Artifact) (doc *model.TicketWithTickets, err error) { func (s *Service) AddArtifact(ctx context.Context, i int64, artifact *model.Artifact) (doc *model.TicketWithTickets, err error) {
defer s.publishRequest(ctx, err, "AddArtifact", ticketWithTicketsID(doc)) defer s.publishRequest(ctx, err, "AddArtifact", ticketWithTicketsID(doc))
return s.database.AddArtifact(ctx, i, artifact) return s.database.AddArtifact(ctx, i, artifact)
} }
@@ -100,16 +108,19 @@ func (s *Service) GetArtifact(ctx context.Context, i int64, s2 string) (*model.A
func (s *Service) SetArtifact(ctx context.Context, i int64, s2 string, artifact *model.Artifact) (doc *model.TicketWithTickets, err error) { func (s *Service) SetArtifact(ctx context.Context, i int64, s2 string, artifact *model.Artifact) (doc *model.TicketWithTickets, err error) {
defer s.publishRequest(ctx, err, "SetArtifact", ticketWithTicketsID(doc)) defer s.publishRequest(ctx, err, "SetArtifact", ticketWithTicketsID(doc))
return s.database.ArtifactUpdate(ctx, i, s2, artifact) return s.database.ArtifactUpdate(ctx, i, s2, artifact)
} }
func (s *Service) RemoveArtifact(ctx context.Context, i int64, s2 string) (doc *model.TicketWithTickets, err error) { func (s *Service) RemoveArtifact(ctx context.Context, i int64, s2 string) (doc *model.TicketWithTickets, err error) {
defer s.publishRequest(ctx, err, "RemoveArtifact", ticketWithTicketsID(doc)) defer s.publishRequest(ctx, err, "RemoveArtifact", ticketWithTicketsID(doc))
return s.database.RemoveArtifact(ctx, i, s2) return s.database.RemoveArtifact(ctx, i, s2)
} }
func (s *Service) EnrichArtifact(ctx context.Context, i int64, s2 string, form *model.EnrichmentForm) (doc *model.TicketWithTickets, err error) { func (s *Service) EnrichArtifact(ctx context.Context, i int64, s2 string, form *model.EnrichmentForm) (doc *model.TicketWithTickets, err error) {
defer s.publishRequest(ctx, err, "EnrichArtifact", ticketWithTicketsID(doc)) defer s.publishRequest(ctx, err, "EnrichArtifact", ticketWithTicketsID(doc))
return s.database.EnrichArtifact(ctx, i, s2, form) return s.database.EnrichArtifact(ctx, i, s2, form)
} }
@@ -123,46 +134,55 @@ func (s *Service) RunArtifact(ctx context.Context, id int64, name string, automa
jobID := uuid.NewString() jobID := uuid.NewString()
origin := &model.Origin{ArtifactOrigin: &model.ArtifactOrigin{TicketId: id, Artifact: name}} origin := &model.Origin{ArtifactOrigin: &model.ArtifactOrigin{TicketId: id, Artifact: name}}
return s.bus.PublishJob(jobID, automation, map[string]string{"default": name}, &model.Context{Artifact: artifact}, origin) return s.bus.PublishJob(jobID, automation, map[string]string{"default": name}, &model.Context{Artifact: artifact}, origin)
} }
func (s *Service) AddComment(ctx context.Context, i int64, form *model.CommentForm) (doc *model.TicketWithTickets, err error) { func (s *Service) AddComment(ctx context.Context, i int64, form *model.CommentForm) (doc *model.TicketWithTickets, err error) {
defer s.publishRequest(ctx, err, "AddComment", ticketWithTicketsID(doc)) defer s.publishRequest(ctx, err, "AddComment", ticketWithTicketsID(doc))
return s.database.AddComment(ctx, i, form) return s.database.AddComment(ctx, i, form)
} }
func (s *Service) RemoveComment(ctx context.Context, i int64, i2 int) (doc *model.TicketWithTickets, err error) { func (s *Service) RemoveComment(ctx context.Context, i int64, i2 int) (doc *model.TicketWithTickets, err error) {
defer s.publishRequest(ctx, err, "RemoveComment", ticketWithTicketsID(doc)) defer s.publishRequest(ctx, err, "RemoveComment", ticketWithTicketsID(doc))
return s.database.RemoveComment(ctx, i, int64(i2)) return s.database.RemoveComment(ctx, i, int64(i2))
} }
func (s *Service) AddTicketPlaybook(ctx context.Context, i int64, form *model.PlaybookTemplateForm) (doc *model.TicketWithTickets, err error) { func (s *Service) AddTicketPlaybook(ctx context.Context, i int64, form *model.PlaybookTemplateForm) (doc *model.TicketWithTickets, err error) {
defer s.publishRequest(ctx, err, "AddTicketPlaybook", ticketWithTicketsID(doc)) defer s.publishRequest(ctx, err, "AddTicketPlaybook", ticketWithTicketsID(doc))
return s.database.AddTicketPlaybook(ctx, i, form) return s.database.AddTicketPlaybook(ctx, i, form)
} }
func (s *Service) RemoveTicketPlaybook(ctx context.Context, i int64, s2 string) (doc *model.TicketWithTickets, err error) { func (s *Service) RemoveTicketPlaybook(ctx context.Context, i int64, s2 string) (doc *model.TicketWithTickets, err error) {
defer s.publishRequest(ctx, err, "RemoveTicketPlaybook", ticketWithTicketsID(doc)) defer s.publishRequest(ctx, err, "RemoveTicketPlaybook", ticketWithTicketsID(doc))
return s.database.RemoveTicketPlaybook(ctx, i, s2) return s.database.RemoveTicketPlaybook(ctx, i, s2)
} }
func (s *Service) SetTaskData(ctx context.Context, i int64, s3 string, s2 string, data map[string]interface{}) (doc *model.TicketWithTickets, err error) { func (s *Service) SetTaskData(ctx context.Context, i int64, s3 string, s2 string, data map[string]any) (doc *model.TicketWithTickets, err error) {
defer s.publishRequest(ctx, err, "SetTask", ticketWithTicketsID(doc)) defer s.publishRequest(ctx, err, "SetTask", ticketWithTicketsID(doc))
return s.database.TaskUpdateData(ctx, i, s3, s2, data) return s.database.TaskUpdateData(ctx, i, s3, s2, data)
} }
func (s *Service) SetTaskOwner(ctx context.Context, i int64, s3 string, s2 string, owner string) (doc *model.TicketWithTickets, err error) { func (s *Service) SetTaskOwner(ctx context.Context, i int64, s3 string, s2 string, owner string) (doc *model.TicketWithTickets, err error) {
defer s.publishRequest(ctx, err, "SetTask", ticketWithTicketsID(doc)) defer s.publishRequest(ctx, err, "SetTask", ticketWithTicketsID(doc))
return s.database.TaskUpdateOwner(ctx, i, s3, s2, owner) return s.database.TaskUpdateOwner(ctx, i, s3, s2, owner)
} }
func (s *Service) CompleteTask(ctx context.Context, i int64, s3 string, s2 string, m map[string]interface{}) (doc *model.TicketWithTickets, err error) { func (s *Service) CompleteTask(ctx context.Context, i int64, s3 string, s2 string, m map[string]any) (doc *model.TicketWithTickets, err error) {
defer s.publishRequest(ctx, err, "CompleteTask", ticketWithTicketsID(doc)) defer s.publishRequest(ctx, err, "CompleteTask", ticketWithTicketsID(doc))
return s.database.TaskComplete(ctx, i, s3, s2, m) return s.database.TaskComplete(ctx, i, s3, s2, m)
} }
func (s *Service) RunTask(ctx context.Context, i int64, s3 string, s2 string) (err error) { func (s *Service) RunTask(ctx context.Context, i int64, s3 string, s2 string) (err error) {
defer s.publishRequest(ctx, err, "RunTask", ticketID(i)) defer s.publishRequest(ctx, err, "RunTask", ticketID(i))
return s.database.TaskRun(ctx, i, s3, s2) return s.database.TaskRun(ctx, i, s3, s2)
} }
@@ -171,11 +191,13 @@ func (s *Service) SetReferences(ctx context.Context, i int64, references *model.
return nil, &api.HTTPError{Status: http.StatusUnprocessableEntity, Internal: errors.New("no references given")} return nil, &api.HTTPError{Status: http.StatusUnprocessableEntity, Internal: errors.New("no references given")}
} }
defer s.publishRequest(ctx, err, "SetReferences", ticketID(i)) defer s.publishRequest(ctx, err, "SetReferences", ticketID(i))
return s.database.SetReferences(ctx, i, *references) return s.database.SetReferences(ctx, i, *references)
} }
func (s *Service) SetSchema(ctx context.Context, i int64, s2 string) (doc *model.TicketWithTickets, err error) { func (s *Service) SetSchema(ctx context.Context, i int64, s2 string) (doc *model.TicketWithTickets, err error) {
defer s.publishRequest(ctx, err, "SetSchema", ticketID(i)) defer s.publishRequest(ctx, err, "SetSchema", ticketID(i))
return s.database.SetTemplate(ctx, i, s2) return s.database.SetTemplate(ctx, i, s2)
} }
+4
View File
@@ -14,6 +14,7 @@ func ticketTypeResponseID(ticketType *model.TicketTypeResponse) []driver.Documen
if ticketType == nil { if ticketType == nil {
return nil return nil
} }
return userDataID(ticketType.ID) return userDataID(ticketType.ID)
} }
@@ -27,6 +28,7 @@ func (s *Service) ListTicketTypes(ctx context.Context) ([]*model.TicketTypeRespo
func (s *Service) CreateTicketType(ctx context.Context, form *model.TicketTypeForm) (doc *model.TicketTypeResponse, err error) { func (s *Service) CreateTicketType(ctx context.Context, form *model.TicketTypeForm) (doc *model.TicketTypeResponse, err error) {
defer s.publishRequest(ctx, err, "CreateTicketType", ticketTypeResponseID(doc)) defer s.publishRequest(ctx, err, "CreateTicketType", ticketTypeResponseID(doc))
return s.database.TicketTypeCreate(ctx, form) return s.database.TicketTypeCreate(ctx, form)
} }
@@ -36,10 +38,12 @@ func (s *Service) GetTicketType(ctx context.Context, id string) (*model.TicketTy
func (s *Service) UpdateTicketType(ctx context.Context, id string, form *model.TicketTypeForm) (doc *model.TicketTypeResponse, err error) { func (s *Service) UpdateTicketType(ctx context.Context, id string, form *model.TicketTypeForm) (doc *model.TicketTypeResponse, err error) {
defer s.publishRequest(ctx, err, "UpdateTicketType", ticketTypeResponseID(doc)) defer s.publishRequest(ctx, err, "UpdateTicketType", ticketTypeResponseID(doc))
return s.database.TicketTypeUpdate(ctx, id, form) return s.database.TicketTypeUpdate(ctx, id, form)
} }
func (s *Service) DeleteTicketType(ctx context.Context, id string) (err error) { func (s *Service) DeleteTicketType(ctx context.Context, id string) (err error) {
defer s.publishRequest(ctx, err, "DeleteTicketType", ticketTypeID(id)) defer s.publishRequest(ctx, err, "DeleteTicketType", ticketTypeID(id))
return s.database.TicketTypeDelete(ctx, id) return s.database.TicketTypeDelete(ctx, id)
} }
+6
View File
@@ -16,6 +16,7 @@ func newUserResponseID(user *model.NewUserResponse) []driver.DocumentID {
if user == nil { if user == nil {
return nil return nil
} }
return userID(user.ID) return userID(user.ID)
} }
@@ -23,6 +24,7 @@ func userResponseID(user *model.UserResponse) []driver.DocumentID {
if user == nil { if user == nil {
return nil return nil
} }
return userID(user.ID) return userID(user.ID)
} }
@@ -36,6 +38,7 @@ func (s *Service) ListUsers(ctx context.Context) ([]*model.UserResponse, error)
func (s *Service) CreateUser(ctx context.Context, form *model.UserForm) (doc *model.NewUserResponse, err error) { func (s *Service) CreateUser(ctx context.Context, form *model.UserForm) (doc *model.NewUserResponse, err error) {
defer s.publishRequest(ctx, err, "CreateUser", newUserResponseID(doc)) defer s.publishRequest(ctx, err, "CreateUser", newUserResponseID(doc))
return s.database.UserCreate(ctx, form) return s.database.UserCreate(ctx, form)
} }
@@ -45,11 +48,13 @@ func (s *Service) GetUser(ctx context.Context, s2 string) (*model.UserResponse,
func (s *Service) UpdateUser(ctx context.Context, s2 string, form *model.UserForm) (doc *model.UserResponse, err error) { func (s *Service) UpdateUser(ctx context.Context, s2 string, form *model.UserForm) (doc *model.UserResponse, err error) {
defer s.publishRequest(ctx, err, "UpdateUser", userID(s2)) defer s.publishRequest(ctx, err, "UpdateUser", userID(s2))
return s.database.UserUpdate(ctx, s2, form) return s.database.UserUpdate(ctx, s2, form)
} }
func (s *Service) DeleteUser(ctx context.Context, s2 string) (err error) { func (s *Service) DeleteUser(ctx context.Context, s2 string) (err error) {
defer s.publishRequest(ctx, err, "DeleteUser", userID(s2)) defer s.publishRequest(ctx, err, "DeleteUser", userID(s2))
return s.database.UserDelete(ctx, s2) return s.database.UserDelete(ctx, s2)
} }
@@ -59,5 +64,6 @@ func (s *Service) CurrentUser(ctx context.Context) (*model.UserResponse, error)
return nil, errors.New("no user in context") return nil, errors.New("no user in context")
} }
s.publishRequest(ctx, nil, "CurrentUser", userResponseID(user)) s.publishRequest(ctx, nil, "CurrentUser", userResponseID(user))
return user, nil return user, nil
} }
+3
View File
@@ -16,6 +16,7 @@ func userDataResponseID(userData *model.UserDataResponse) []driver.DocumentID {
if userData == nil { if userData == nil {
return nil return nil
} }
return userDataID(userData.ID) return userDataID(userData.ID)
} }
@@ -33,6 +34,7 @@ func (s *Service) GetUserData(ctx context.Context, id string) (*model.UserDataRe
func (s *Service) UpdateUserData(ctx context.Context, id string, data *model.UserData) (doc *model.UserDataResponse, err error) { func (s *Service) UpdateUserData(ctx context.Context, id string, data *model.UserData) (doc *model.UserDataResponse, err error) {
defer s.publishRequest(ctx, err, "CreateUser", userDataResponseID(doc)) defer s.publishRequest(ctx, err, "CreateUser", userDataResponseID(doc))
return s.database.UserDataUpdate(ctx, id, data) return s.database.UserDataUpdate(ctx, id, data)
} }
@@ -52,5 +54,6 @@ func (s *Service) UpdateCurrentUserData(ctx context.Context, data *model.UserDat
} }
defer s.publishRequest(ctx, err, "UpdateCurrentUserData", userDataResponseID(doc)) defer s.publishRequest(ctx, err, "UpdateCurrentUserData", userDataResponseID(doc))
return s.database.UserDataUpdate(ctx, user.ID, data) return s.database.UserDataUpdate(ctx, user.ID, data)
} }
+12 -4
View File
@@ -1,6 +1,8 @@
package storage package storage
import ( import (
"errors"
"github.com/aws/aws-sdk-go/aws" "github.com/aws/aws-sdk-go/aws"
"github.com/aws/aws-sdk-go/aws/awserr" "github.com/aws/aws-sdk-go/aws/awserr"
"github.com/aws/aws-sdk-go/aws/credentials" "github.com/aws/aws-sdk-go/aws/credentials"
@@ -29,6 +31,7 @@ func New(config *Config) (*Storage, error) {
DisableSSL: aws.Bool(true), DisableSSL: aws.Bool(true),
S3ForcePathStyle: aws.Bool(true), S3ForcePathStyle: aws.Bool(true),
}) })
return &Storage{s}, err return &Storage{s}, err
} }
@@ -39,17 +42,20 @@ func (s *Storage) S3() *s3.S3 {
func (s *Storage) Downloader() *s3manager.Downloader { func (s *Storage) Downloader() *s3manager.Downloader {
d := s3manager.NewDownloader(s.session) d := s3manager.NewDownloader(s.session)
d.Concurrency = 1 d.Concurrency = 1
return d return d
} }
func (s *Storage) Uploader() *s3manager.Uploader { func (s *Storage) Uploader() *s3manager.Uploader {
d := s3manager.NewUploader(s.session) d := s3manager.NewUploader(s.session)
d.Concurrency = 1 d.Concurrency = 1
return d return d
} }
func (s *Storage) DeleteBucket(name string) error { func (s *Storage) DeleteBucket(name string) error {
_, err := s.S3().DeleteBucket(&s3.DeleteBucketInput{Bucket: pointer.String("catalyst-" + name)}) _, err := s.S3().DeleteBucket(&s3.DeleteBucketInput{Bucket: pointer.String("catalyst-" + name)})
return err return err
} }
@@ -61,11 +67,13 @@ func CreateBucket(client *s3.S3, ticketID string) error {
return err return err
} }
} else { } else {
awsError, ok := err.(awserr.Error) var awsError awserr.Error
if !ok || (awsError.Code() != s3.ErrCodeBucketAlreadyExists && awsError.Code() != s3.ErrCodeBucketAlreadyOwnedByYou) { if errors.As(err, &awsError) && (awsError.Code() == s3.ErrCodeBucketAlreadyExists || awsError.Code() == s3.ErrCodeBucketAlreadyOwnedByYou) {
return err return nil
} }
return nil
return err
} }
return err return err
} }
+24 -2
View File
@@ -25,6 +25,8 @@ import (
) )
func TestBackupAndRestore(t *testing.T) { func TestBackupAndRestore(t *testing.T) {
t.Parallel()
log.SetFlags(log.LstdFlags | log.Lshortfile) log.SetFlags(log.LstdFlags | log.Lshortfile)
if runtime.GOARCH == "arm64" { if runtime.GOARCH == "arm64" {
@@ -41,7 +43,10 @@ func TestBackupAndRestore(t *testing.T) {
{name: "Backup", want: want{status: http.StatusOK}}, {name: "Backup", want: want{status: http.StatusOK}},
} }
for _, tt := range tests { for _, tt := range tests {
tt := tt
t.Run(tt.name, func(t *testing.T) { t.Run(tt.name, func(t *testing.T) {
t.Parallel()
ctx, _, server, err := Catalyst(t) ctx, _, server, err := Catalyst(t)
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
@@ -75,6 +80,8 @@ func TestBackupAndRestore(t *testing.T) {
} }
func assertBackup(t *testing.T, server *catalyst.Server) []byte { func assertBackup(t *testing.T, server *catalyst.Server) []byte {
t.Helper()
// setup request // setup request
req := httptest.NewRequest(http.MethodGet, "/api/backup/create", nil) req := httptest.NewRequest(http.MethodGet, "/api/backup/create", nil)
req.Header.Set("PRIVATE-TOKEN", "test") req.Header.Set("PRIVATE-TOKEN", "test")
@@ -97,6 +104,8 @@ func assertBackup(t *testing.T, server *catalyst.Server) []byte {
} }
func assertZipFile(t *testing.T, r *zip.Reader) { func assertZipFile(t *testing.T, r *zip.Reader) {
t.Helper()
var names []string var names []string
for _, f := range r.File { for _, f := range r.File {
names = append(names, f.Name) names = append(names, f.Name)
@@ -120,9 +129,11 @@ func clearAllDatabases(server *catalyst.Server) {
} }
func deleteAllBuckets(t *testing.T, server *catalyst.Server) { func deleteAllBuckets(t *testing.T, server *catalyst.Server) {
t.Helper()
buckets, err := server.Storage.S3().ListBuckets(&s3.ListBucketsInput{}) buckets, err := server.Storage.S3().ListBuckets(&s3.ListBucketsInput{})
for _, bucket := range buckets.Buckets { for _, bucket := range buckets.Buckets {
server.Storage.S3().DeleteBucket(&s3.DeleteBucketInput{ _, _ = server.Storage.S3().DeleteBucket(&s3.DeleteBucketInput{
Bucket: bucket.Name, Bucket: bucket.Name,
}) })
} }
@@ -133,6 +144,8 @@ func deleteAllBuckets(t *testing.T, server *catalyst.Server) {
} }
func assertRestore(t *testing.T, zipB []byte, server *catalyst.Server) { func assertRestore(t *testing.T, zipB []byte, server *catalyst.Server) {
t.Helper()
bodyBuf := &bytes.Buffer{} bodyBuf := &bytes.Buffer{}
bodyWriter := multipart.NewWriter(bodyBuf) bodyWriter := multipart.NewWriter(bodyBuf)
fileWriter, err := bodyWriter.CreateFormFile("backup", "backup.zip") fileWriter, err := bodyWriter.CreateFormFile("backup", "backup.zip")
@@ -166,7 +179,7 @@ func assertRestore(t *testing.T, zipB []byte, server *catalyst.Server) {
func createFile(ctx context.Context, server *catalyst.Server) { func createFile(ctx context.Context, server *catalyst.Server) {
buf := bytes.NewBufferString("test text") buf := bytes.NewBufferString("test text")
server.Storage.S3().CreateBucket(&s3.CreateBucketInput{Bucket: pointer.String("catalyst-8125")}) _, _ = server.Storage.S3().CreateBucket(&s3.CreateBucketInput{Bucket: pointer.String("catalyst-8125")})
if _, err := server.Storage.Uploader().Upload(&s3manager.UploadInput{Body: buf, Bucket: pointer.String("catalyst-8125"), Key: pointer.String("test.txt")}); err != nil { if _, err := server.Storage.Uploader().Upload(&s3manager.UploadInput{Body: buf, Bucket: pointer.String("catalyst-8125"), Key: pointer.String("test.txt")}); err != nil {
log.Fatal(err) log.Fatal(err)
@@ -178,6 +191,8 @@ func createFile(ctx context.Context, server *catalyst.Server) {
} }
func assertTicketExists(t *testing.T, server *catalyst.Server) { func assertTicketExists(t *testing.T, server *catalyst.Server) {
t.Helper()
req := httptest.NewRequest(http.MethodGet, "/api/tickets/8125", nil) req := httptest.NewRequest(http.MethodGet, "/api/tickets/8125", nil)
req.Header.Set("PRIVATE-TOKEN", "test") req.Header.Set("PRIVATE-TOKEN", "test")
@@ -202,6 +217,8 @@ func assertTicketExists(t *testing.T, server *catalyst.Server) {
} }
func assertFileExists(t *testing.T, server *catalyst.Server) { func assertFileExists(t *testing.T, server *catalyst.Server) {
t.Helper()
obj, err := server.Storage.S3().GetObject(&s3.GetObjectInput{ obj, err := server.Storage.S3().GetObject(&s3.GetObjectInput{
Bucket: aws.String("catalyst-8125"), Bucket: aws.String("catalyst-8125"),
Key: aws.String("test.txt"), Key: aws.String("test.txt"),
@@ -215,6 +232,8 @@ func assertFileExists(t *testing.T, server *catalyst.Server) {
} }
func includes(t *testing.T, names []string, s string) bool { func includes(t *testing.T, names []string, s string) bool {
t.Helper()
for _, name := range names { for _, name := range names {
match, err := regexp.MatchString(s, name) match, err := regexp.MatchString(s, name)
if err != nil { if err != nil {
@@ -225,10 +244,13 @@ func includes(t *testing.T, names []string, s string) bool {
return true return true
} }
} }
return false return false
} }
func readZipFile(t *testing.T, b []byte) *zip.Reader { func readZipFile(t *testing.T, b []byte) *zip.Reader {
t.Helper()
buf := bytes.NewReader(b) buf := bytes.NewReader(b)
zr, err := zip.NewReader(buf, int64(buf.Len())) zr, err := zip.NewReader(buf, int64(buf.Len()))
+6 -3
View File
@@ -10,9 +10,11 @@ import (
"github.com/SecurityBrewery/catalyst/generated/pointer" "github.com/SecurityBrewery/catalyst/generated/pointer"
) )
var bobSetting = &model.UserData{Email: pointer.String("bob@example.org"), Name: pointer.String("Bob Bad")} var (
var bobForm = &model.UserForm{ID: "bob", Blocked: false, Roles: []string{"admin"}} bobSetting = &model.UserData{Email: pointer.String("bob@example.org"), Name: pointer.String("Bob Bad")}
var Bob = &model.UserResponse{ID: "bob", Blocked: false, Roles: []string{"admin"}} bobForm = &model.UserForm{ID: "bob", Blocked: false, Roles: []string{"admin"}}
Bob = &model.UserResponse{ID: "bob", Blocked: false, Roles: []string{"admin"}}
)
func SetupTestData(ctx context.Context, db *database.Database) error { func SetupTestData(ctx context.Context, db *database.Database) error {
if err := db.UserDataCreate(ctx, "bob", bobSetting); err != nil { if err := db.UserDataCreate(ctx, "bob", bobSetting); err != nil {
@@ -109,5 +111,6 @@ func parse(s string) *time.Time {
} }
utc := modified.UTC() utc := modified.UTC()
return &utc return &utc
} }
+6 -1
View File
@@ -18,6 +18,8 @@ import (
) )
func TestJob(t *testing.T) { func TestJob(t *testing.T) {
t.Parallel()
log.SetFlags(log.LstdFlags | log.Lshortfile) log.SetFlags(log.LstdFlags | log.Lshortfile)
_, _, server, err := Catalyst(t) _, _, server, err := Catalyst(t)
@@ -27,7 +29,7 @@ func TestJob(t *testing.T) {
b, err := json.Marshal(model.JobForm{ b, err := json.Marshal(model.JobForm{
Automation: "hash.sha1", Automation: "hash.sha1",
Payload: map[string]interface{}{"default": "test"}, Payload: map[string]any{"default": "test"},
}) })
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
@@ -52,11 +54,14 @@ func TestJob(t *testing.T) {
output := gjson.GetBytes(job, "output.hash").String() output := gjson.GetBytes(job, "output.hash").String()
assert.Equal(t, "a94a8fe5ccb19ba61c4c0873d391e987982fbbd3", output) assert.Equal(t, "a94a8fe5ccb19ba61c4c0873d391e987982fbbd3", output)
break break
} }
} }
func request(t *testing.T, server chi.Router, method, url string, data io.Reader) []byte { func request(t *testing.T, server chi.Router, method, url string, data io.Reader) []byte {
t.Helper()
w := httptest.NewRecorder() w := httptest.NewRecorder()
// setup request // setup request
+9 -2
View File
@@ -25,10 +25,15 @@ func (testClock) Now() time.Time {
} }
func TestServer(t *testing.T) { func TestServer(t *testing.T) {
t.Parallel()
ctime.DefaultClock = testClock{} ctime.DefaultClock = testClock{}
for _, tt := range api.Tests { for _, tt := range api.Tests {
tt := tt
t.Run(tt.Name, func(t *testing.T) { t.Run(tt.Name, func(t *testing.T) {
t.Parallel()
ctx, _, _, _, _, db, _, server, cleanup, err := Server(t) ctx, _, _, _, _, db, _, server, cleanup, err := Server(t)
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
@@ -73,8 +78,10 @@ func TestServer(t *testing.T) {
} }
} }
func jsonEqual(t *testing.T, name string, got io.Reader, want interface{}) { func jsonEqual(t *testing.T, name string, got io.Reader, want any) {
var gotObject, wantObject interface{} t.Helper()
var gotObject, wantObject any
// load bytes // load bytes
wantBytes, err := json.Marshal(want) wantBytes, err := json.Marshal(want)
+18
View File
@@ -75,6 +75,8 @@ func Config(ctx context.Context) (*catalyst.Config, error) {
} }
func Index(t *testing.T) (*index.Index, func(), error) { func Index(t *testing.T) (*index.Index, func(), error) {
t.Helper()
dir, err := os.MkdirTemp("", "catalyst-test-"+cleanName(t)) dir, err := os.MkdirTemp("", "catalyst-test-"+cleanName(t))
if err != nil { if err != nil {
return nil, nil, err return nil, nil, err
@@ -84,10 +86,13 @@ func Index(t *testing.T) (*index.Index, func(), error) {
if err != nil { if err != nil {
return nil, nil, err return nil, nil, err
} }
return catalystIndex, func() { catalystIndex.Close(); os.RemoveAll(dir) }, nil return catalystIndex, func() { catalystIndex.Close(); os.RemoveAll(dir) }, nil
} }
func Bus(t *testing.T) (context.Context, *catalyst.Config, *bus.Bus, error) { func Bus(t *testing.T) (context.Context, *catalyst.Config, *bus.Bus, error) {
t.Helper()
ctx := Context() ctx := Context()
config, err := Config(ctx) config, err := Config(ctx)
@@ -99,10 +104,13 @@ func Bus(t *testing.T) (context.Context, *catalyst.Config, *bus.Bus, error) {
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
} }
return ctx, config, catalystBus, err return ctx, config, catalystBus, err
} }
func DB(t *testing.T) (context.Context, *catalyst.Config, *bus.Bus, *index.Index, *storage.Storage, *database.Database, func(), error) { func DB(t *testing.T) (context.Context, *catalyst.Config, *bus.Bus, *index.Index, *storage.Storage, *database.Database, func(), error) {
t.Helper()
ctx, config, rbus, err := Bus(t) ctx, config, rbus, err := Bus(t)
if err != nil { if err != nil {
return nil, nil, nil, nil, nil, nil, nil, err return nil, nil, nil, nil, nil, nil, nil, err
@@ -146,6 +154,8 @@ func DB(t *testing.T) (context.Context, *catalyst.Config, *bus.Bus, *index.Index
} }
func Service(t *testing.T) (context.Context, *catalyst.Config, *bus.Bus, *index.Index, *storage.Storage, *database.Database, *service.Service, func(), error) { func Service(t *testing.T) (context.Context, *catalyst.Config, *bus.Bus, *index.Index, *storage.Storage, *database.Database, *service.Service, func(), error) {
t.Helper()
ctx, config, rbus, catalystIndex, catalystStorage, db, cleanup, err := DB(t) ctx, config, rbus, catalystIndex, catalystStorage, db, cleanup, err := DB(t)
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
@@ -160,6 +170,8 @@ func Service(t *testing.T) (context.Context, *catalyst.Config, *bus.Bus, *index.
} }
func Server(t *testing.T) (context.Context, *catalyst.Config, *bus.Bus, *index.Index, *storage.Storage, *database.Database, *service.Service, chi.Router, func(), error) { func Server(t *testing.T) (context.Context, *catalyst.Config, *bus.Bus, *index.Index, *storage.Storage, *database.Database, *service.Service, chi.Router, func(), error) {
t.Helper()
ctx, config, rbus, catalystIndex, catalystStorage, db, catalystService, cleanup, err := Service(t) ctx, config, rbus, catalystIndex, catalystStorage, db, catalystService, cleanup, err := Service(t)
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
@@ -177,6 +189,8 @@ func Server(t *testing.T) (context.Context, *catalyst.Config, *bus.Bus, *index.I
} }
func Catalyst(t *testing.T) (context.Context, *catalyst.Config, *catalyst.Server, error) { func Catalyst(t *testing.T) (context.Context, *catalyst.Config, *catalyst.Server, error) {
t.Helper()
ctx := Context() ctx := Context()
config, err := Config(ctx) config, err := Config(ctx)
@@ -189,13 +203,17 @@ func Catalyst(t *testing.T) (context.Context, *catalyst.Config, *catalyst.Server
c, err := catalyst.New(&hooks.Hooks{ c, err := catalyst.New(&hooks.Hooks{
DatabaseAfterConnectFuncs: []func(ctx context.Context, client driver.Client, name string){Clear}, DatabaseAfterConnectFuncs: []func(ctx context.Context, client driver.Client, name string){Clear},
}, config) }, config)
return ctx, config, c, err return ctx, config, c, err
} }
func cleanName(t *testing.T) string { func cleanName(t *testing.T) string {
t.Helper()
name := t.Name() name := t.Name()
name = strings.ReplaceAll(name, " ", "") name = strings.ReplaceAll(name, " ", "")
name = strings.ReplaceAll(name, "/", "_") name = strings.ReplaceAll(name, "/", "_")
return strings.ReplaceAll(name, "#", "_") return strings.ReplaceAll(name, "#", "_")
} }
+7 -2
View File
@@ -9,14 +9,16 @@ import (
) )
func TestUser(t *testing.T) { func TestUser(t *testing.T) {
t.Parallel()
type args struct { type args struct {
method string method string
url string url string
data interface{} data any
} }
type want struct { type want struct {
status int status int
body interface{} body any
} }
tests := []struct { tests := []struct {
name string name string
@@ -27,7 +29,10 @@ func TestUser(t *testing.T) {
{name: "ListUsers", args: args{method: http.MethodGet, url: "/users"}, want: want{status: http.StatusOK}}, {name: "ListUsers", args: args{method: http.MethodGet, url: "/users"}, want: want{status: http.StatusOK}},
} }
for _, tt := range tests { for _, tt := range tests {
tt := tt
t.Run(tt.name, func(t *testing.T) { t.Run(tt.name, func(t *testing.T) {
t.Parallel()
_, _, _, _, _, _, _, server, cleanup, err := Server(t) _, _, _, _, _, _, _, server, cleanup, err := Server(t)
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
+16 -8
View File
@@ -3,15 +3,23 @@ package ui
import "testing" import "testing"
func TestUI(t *testing.T) { func TestUI(t *testing.T) {
requiredFiles := []string{ t.Parallel()
"dist/index.html",
"dist/favicon.ico", tests := []struct {
"dist/manifest.json", name string
"dist/img", path string
}{
{"index.html", "dist/index.html"},
{"favicon.ico", "dist/favicon.ico"},
{"manifest.json", "dist/manifest.json"},
{"img", "dist/img"},
} }
for _, requiredFile := range requiredFiles { for _, tt := range tests {
t.Run("Require "+requiredFile, func(t *testing.T) { tt := tt
f, err := UI.Open(requiredFile) t.Run(tt.name, func(t *testing.T) {
t.Parallel()
f, err := UI.Open(tt.path)
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
} }
+3 -1
View File
@@ -41,6 +41,7 @@ func (wb *websocketBroker) NewWebsocket() (string, chan []byte) {
wb.mu.Lock() wb.mu.Lock()
wb.clients[id] = channel wb.clients[id] = channel
wb.mu.Unlock() wb.mu.Unlock()
return id, channel return id, channel
} }
@@ -49,7 +50,7 @@ func handleWebSocket(catalystBus *bus.Bus) http.HandlerFunc {
// send all messages from bus to websocket // send all messages from bus to websocket
err := catalystBus.SubscribeDatabaseUpdate(func(msg *bus.DatabaseUpdateMsg) { err := catalystBus.SubscribeDatabaseUpdate(func(msg *bus.DatabaseUpdateMsg) {
b, err := json.Marshal(map[string]interface{}{ b, err := json.Marshal(map[string]any{
"action": "update", "action": "update",
"ids": msg.IDs, "ids": msg.IDs,
}) })
@@ -67,6 +68,7 @@ func handleWebSocket(catalystBus *bus.Bus) http.HandlerFunc {
conn, _, _, err := ws.UpgradeHTTP(r, w) conn, _, _, err := ws.UpgradeHTTP(r, w)
if err != nil { if err != nil {
api.JSONError(w, errors.New("upgrade failed")) api.JSONError(w, errors.New("upgrade failed"))
return return
} }