mirror of
https://github.com/SecurityBrewery/catalyst.git
synced 2025-12-06 15:22:47 +01:00
@@ -8,7 +8,7 @@ import (
|
||||
"github.com/SecurityBrewery/catalyst/generated/caql/parser"
|
||||
)
|
||||
|
||||
var TooComplexError = errors.New("unsupported features for index queries, use advanced search instead")
|
||||
var ErrTooComplex = errors.New("unsupported features for index queries, use advanced search instead")
|
||||
|
||||
type bleveBuilder struct {
|
||||
*parser.BaseCAQLParserListener
|
||||
@@ -35,8 +35,9 @@ func (s *bleveBuilder) pop() (n string) {
|
||||
return
|
||||
}
|
||||
|
||||
func (s *bleveBuilder) binaryPop() (interface{}, interface{}) {
|
||||
func (s *bleveBuilder) binaryPop() (any, any) {
|
||||
right, left := s.pop(), s.pop()
|
||||
|
||||
return left, right
|
||||
}
|
||||
|
||||
@@ -48,9 +49,7 @@ func (s *bleveBuilder) ExitExpression(ctx *parser.ExpressionContext) {
|
||||
case ctx.Reference() != nil:
|
||||
// pass
|
||||
case ctx.Operator_unary() != nil:
|
||||
s.err = TooComplexError
|
||||
return
|
||||
|
||||
s.err = ErrTooComplex
|
||||
case ctx.T_PLUS() != nil:
|
||||
fallthrough
|
||||
case ctx.T_MINUS() != nil:
|
||||
@@ -60,13 +59,9 @@ func (s *bleveBuilder) ExitExpression(ctx *parser.ExpressionContext) {
|
||||
case ctx.T_DIV() != nil:
|
||||
fallthrough
|
||||
case ctx.T_MOD() != nil:
|
||||
s.err = TooComplexError
|
||||
return
|
||||
|
||||
s.err = ErrTooComplex
|
||||
case ctx.T_RANGE() != nil:
|
||||
s.err = TooComplexError
|
||||
return
|
||||
|
||||
s.err = ErrTooComplex
|
||||
case ctx.T_LT() != nil && ctx.GetEq_op() == nil:
|
||||
left, right := s.binaryPop()
|
||||
s.push(fmt.Sprintf("%s:<%s", left, right))
|
||||
@@ -79,64 +74,46 @@ func (s *bleveBuilder) ExitExpression(ctx *parser.ExpressionContext) {
|
||||
case ctx.T_GE() != nil && ctx.GetEq_op() == nil:
|
||||
left, right := s.binaryPop()
|
||||
s.push(fmt.Sprintf("%s:>=%s", left, right))
|
||||
|
||||
case ctx.T_IN() != nil && ctx.GetEq_op() == nil:
|
||||
s.err = TooComplexError
|
||||
return
|
||||
|
||||
s.err = ErrTooComplex
|
||||
case ctx.T_EQ() != nil && ctx.GetEq_op() == nil:
|
||||
left, right := s.binaryPop()
|
||||
s.push(fmt.Sprintf("%s:%s", left, right))
|
||||
case ctx.T_NE() != nil && ctx.GetEq_op() == nil:
|
||||
left, right := s.binaryPop()
|
||||
s.push(fmt.Sprintf("-%s:%s", left, right))
|
||||
|
||||
case ctx.T_ALL() != nil && ctx.GetEq_op() != nil:
|
||||
fallthrough
|
||||
case ctx.T_ANY() != nil && ctx.GetEq_op() != nil:
|
||||
fallthrough
|
||||
case ctx.T_NONE() != nil && ctx.GetEq_op() != nil:
|
||||
s.err = TooComplexError
|
||||
return
|
||||
|
||||
s.err = ErrTooComplex
|
||||
case ctx.T_ALL() != nil && ctx.T_NOT() != nil && ctx.T_IN() != nil:
|
||||
fallthrough
|
||||
case ctx.T_ANY() != nil && ctx.T_NOT() != nil && ctx.T_IN() != nil:
|
||||
fallthrough
|
||||
case ctx.T_NONE() != nil && ctx.T_NOT() != nil && ctx.T_IN() != nil:
|
||||
s.err = TooComplexError
|
||||
return
|
||||
|
||||
s.err = ErrTooComplex
|
||||
case ctx.T_LIKE() != nil:
|
||||
s.err = errors.New("index queries are like queries by default")
|
||||
return
|
||||
|
||||
case ctx.T_REGEX_MATCH() != nil:
|
||||
left, right := s.binaryPop()
|
||||
if ctx.T_NOT() != nil {
|
||||
s.err = TooComplexError
|
||||
return
|
||||
s.err = ErrTooComplex
|
||||
} else {
|
||||
s.push(fmt.Sprintf("%s:/%s/", left, right))
|
||||
}
|
||||
case ctx.T_REGEX_NON_MATCH() != nil:
|
||||
s.err = errors.New("index query cannot contain regex non matches, use advanced search instead")
|
||||
return
|
||||
|
||||
case ctx.T_AND() != nil:
|
||||
left, right := s.binaryPop()
|
||||
s.push(fmt.Sprintf("%s %s", left, right))
|
||||
case ctx.T_OR() != nil:
|
||||
s.err = errors.New("index query cannot contain OR, use advanced search instead")
|
||||
return
|
||||
|
||||
case ctx.T_QUESTION() != nil && len(ctx.AllExpression()) == 3:
|
||||
s.err = errors.New("index query cannot contain ternary operations, use advanced search instead")
|
||||
return
|
||||
case ctx.T_QUESTION() != nil && len(ctx.AllExpression()) == 2:
|
||||
s.err = errors.New("index query cannot contain ternary operations, use advanced search instead")
|
||||
return
|
||||
|
||||
default:
|
||||
panic("unknown expression")
|
||||
}
|
||||
@@ -152,17 +129,13 @@ func (s *bleveBuilder) ExitReference(ctx *parser.ReferenceContext) {
|
||||
case ctx.T_STRING() != nil:
|
||||
s.push(ctx.T_STRING().GetText())
|
||||
case ctx.Compound_value() != nil:
|
||||
s.err = TooComplexError
|
||||
return
|
||||
s.err = ErrTooComplex
|
||||
case ctx.Function_call() != nil:
|
||||
s.err = TooComplexError
|
||||
return
|
||||
s.err = ErrTooComplex
|
||||
case ctx.T_OPEN() != nil:
|
||||
s.err = TooComplexError
|
||||
return
|
||||
s.err = ErrTooComplex
|
||||
case ctx.T_ARRAY_OPEN() != nil:
|
||||
s.err = TooComplexError
|
||||
return
|
||||
s.err = ErrTooComplex
|
||||
default:
|
||||
panic(fmt.Sprintf("unexpected value: %s", ctx.GetText()))
|
||||
}
|
||||
|
||||
@@ -1,10 +1,14 @@
|
||||
package caql
|
||||
package caql_test
|
||||
|
||||
import (
|
||||
"testing"
|
||||
|
||||
"github.com/SecurityBrewery/catalyst/caql"
|
||||
)
|
||||
|
||||
func TestBleveBuilder(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
saql string
|
||||
@@ -18,15 +22,20 @@ func TestBleveBuilder(t *testing.T) {
|
||||
{name: "Search 4", saql: `title == 'malware' AND 'wannacry'`, wantBleve: `title:"malware" "wannacry"`},
|
||||
}
|
||||
for _, tt := range tests {
|
||||
parser := &Parser{}
|
||||
tt := tt
|
||||
|
||||
parser := &caql.Parser{}
|
||||
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
expr, err := parser.Parse(tt.saql)
|
||||
if (err != nil) != tt.wantParseErr {
|
||||
t.Errorf("Parse() error = %v, wantErr %v", err, tt.wantParseErr)
|
||||
if expr != nil {
|
||||
t.Error(expr.String())
|
||||
}
|
||||
|
||||
return
|
||||
}
|
||||
if err != nil {
|
||||
@@ -37,6 +46,7 @@ func TestBleveBuilder(t *testing.T) {
|
||||
if (err != nil) != tt.wantRebuildErr {
|
||||
t.Error(expr.String())
|
||||
t.Errorf("String() error = %v, wantErr %v", err, tt.wantParseErr)
|
||||
|
||||
return
|
||||
}
|
||||
if err != nil {
|
||||
|
||||
@@ -5,6 +5,8 @@ import (
|
||||
"strconv"
|
||||
"strings"
|
||||
|
||||
"golang.org/x/exp/slices"
|
||||
|
||||
"github.com/SecurityBrewery/catalyst/generated/caql/parser"
|
||||
)
|
||||
|
||||
@@ -40,6 +42,7 @@ func (s *aqlBuilder) pop() (n string) {
|
||||
|
||||
func (s *aqlBuilder) binaryPop() (string, string) {
|
||||
right, left := s.pop(), s.pop()
|
||||
|
||||
return left, right
|
||||
}
|
||||
|
||||
@@ -181,8 +184,10 @@ func (s *aqlBuilder) toBoolString(v string) string {
|
||||
if err != nil {
|
||||
panic("invalid search " + err.Error())
|
||||
}
|
||||
|
||||
return fmt.Sprintf(`d._key IN ["%s"]`, strings.Join(ids, `","`))
|
||||
}
|
||||
|
||||
return v
|
||||
}
|
||||
|
||||
@@ -246,7 +251,7 @@ func (s *aqlBuilder) ExitFunction_call(ctx *parser.Function_callContext) {
|
||||
}
|
||||
parameter := strings.Join(array, ", ")
|
||||
|
||||
if !stringSliceContains(functionNames, strings.ToUpper(ctx.T_STRING().GetText())) {
|
||||
if !slices.Contains(functionNames, strings.ToUpper(ctx.T_STRING().GetText())) {
|
||||
panic("unknown function")
|
||||
}
|
||||
|
||||
|
||||
136
caql/function.go
136
caql/function.go
@@ -16,7 +16,6 @@ import (
|
||||
|
||||
func (s *aqlInterpreter) function(ctx *parser.Function_callContext) {
|
||||
switch strings.ToUpper(ctx.T_STRING().GetText()) {
|
||||
|
||||
default:
|
||||
s.appendErrors(errors.New("unknown function"))
|
||||
|
||||
@@ -26,8 +25,8 @@ func (s *aqlInterpreter) function(ctx *parser.Function_callContext) {
|
||||
if len(ctx.AllExpression()) == 3 {
|
||||
u = s.pop().(bool)
|
||||
}
|
||||
seen := map[interface{}]bool{}
|
||||
values, anyArray := s.pop().([]interface{}), s.pop().([]interface{})
|
||||
seen := map[any]bool{}
|
||||
values, anyArray := s.pop().([]any), s.pop().([]any)
|
||||
|
||||
if u {
|
||||
for _, e := range anyArray {
|
||||
@@ -45,18 +44,18 @@ func (s *aqlInterpreter) function(ctx *parser.Function_callContext) {
|
||||
s.push(anyArray)
|
||||
case "COUNT_DISTINCT", "COUNT_UNIQUE":
|
||||
count := 0
|
||||
seen := map[interface{}]bool{}
|
||||
array := s.pop().([]interface{})
|
||||
seen := map[any]bool{}
|
||||
array := s.pop().([]any)
|
||||
for _, e := range array {
|
||||
_, ok := seen[e]
|
||||
if !ok {
|
||||
seen[e] = true
|
||||
count += 1
|
||||
count++
|
||||
}
|
||||
}
|
||||
s.push(float64(count))
|
||||
case "FIRST":
|
||||
array := s.pop().([]interface{})
|
||||
array := s.pop().([]any)
|
||||
if len(array) == 0 {
|
||||
s.push(nil)
|
||||
} else {
|
||||
@@ -65,16 +64,16 @@ func (s *aqlInterpreter) function(ctx *parser.Function_callContext) {
|
||||
// case "FLATTEN":
|
||||
// case "INTERLEAVE":
|
||||
case "INTERSECTION":
|
||||
iset := New(s.pop().([]interface{})...)
|
||||
iset := NewSet(s.pop().([]any)...)
|
||||
|
||||
for i := 1; i < len(ctx.AllExpression()); i++ {
|
||||
iset = iset.Intersection(New(s.pop().([]interface{})...))
|
||||
iset = iset.Intersection(NewSet(s.pop().([]any)...))
|
||||
}
|
||||
|
||||
s.push(iset.Values())
|
||||
// case "JACCARD":
|
||||
case "LAST":
|
||||
array := s.pop().([]interface{})
|
||||
array := s.pop().([]any)
|
||||
if len(array) == 0 {
|
||||
s.push(nil)
|
||||
} else {
|
||||
@@ -94,9 +93,9 @@ func (s *aqlInterpreter) function(ctx *parser.Function_callContext) {
|
||||
s.push(float64(len(fmt.Sprint(v))))
|
||||
case string:
|
||||
s.push(float64(utf8.RuneCountInString(v)))
|
||||
case []interface{}:
|
||||
case []any:
|
||||
s.push(float64(len(v)))
|
||||
case map[string]interface{}:
|
||||
case map[string]any:
|
||||
s.push(float64(len(v)))
|
||||
default:
|
||||
panic("unknown type")
|
||||
@@ -104,7 +103,7 @@ func (s *aqlInterpreter) function(ctx *parser.Function_callContext) {
|
||||
case "MINUS":
|
||||
var sets []*Set
|
||||
for i := 0; i < len(ctx.AllExpression()); i++ {
|
||||
sets = append(sets, New(s.pop().([]interface{})...))
|
||||
sets = append(sets, NewSet(s.pop().([]any)...))
|
||||
}
|
||||
|
||||
iset := sets[len(sets)-1]
|
||||
@@ -116,7 +115,7 @@ func (s *aqlInterpreter) function(ctx *parser.Function_callContext) {
|
||||
s.push(iset.Values())
|
||||
case "NTH":
|
||||
pos := s.pop().(float64)
|
||||
array := s.pop().([]interface{})
|
||||
array := s.pop().([]any)
|
||||
if int(pos) >= len(array) || pos < 0 {
|
||||
s.push(nil)
|
||||
} else {
|
||||
@@ -124,16 +123,16 @@ func (s *aqlInterpreter) function(ctx *parser.Function_callContext) {
|
||||
}
|
||||
// case "OUTERSECTION":
|
||||
// array := s.pop().([]interface{})
|
||||
// union := New(array...)
|
||||
// intersection := New(s.pop().([]interface{})...)
|
||||
// union := NewSet(array...)
|
||||
// intersection := NewSet(s.pop().([]interface{})...)
|
||||
// for i := 1; i < len(ctx.AllExpression()); i++ {
|
||||
// array = s.pop().([]interface{})
|
||||
// union = union.Union(New(array...))
|
||||
// intersection = intersection.Intersection(New(array...))
|
||||
// union = union.Union(NewSet(array...))
|
||||
// intersection = intersection.Intersection(NewSet(array...))
|
||||
// }
|
||||
// s.push(union.Minus(intersection).Values())
|
||||
case "POP":
|
||||
array := s.pop().([]interface{})
|
||||
array := s.pop().([]any)
|
||||
s.push(array[:len(array)-1])
|
||||
case "POSITION", "CONTAINS_ARRAY":
|
||||
returnIndex := false
|
||||
@@ -141,7 +140,7 @@ func (s *aqlInterpreter) function(ctx *parser.Function_callContext) {
|
||||
returnIndex = s.pop().(bool)
|
||||
}
|
||||
search := s.pop()
|
||||
array := s.pop().([]interface{})
|
||||
array := s.pop().([]any)
|
||||
|
||||
for idx, e := range array {
|
||||
if e == search {
|
||||
@@ -164,7 +163,7 @@ func (s *aqlInterpreter) function(ctx *parser.Function_callContext) {
|
||||
u = s.pop().(bool)
|
||||
}
|
||||
element := s.pop()
|
||||
array := s.pop().([]interface{})
|
||||
array := s.pop().([]any)
|
||||
|
||||
if u && contains(array, element) {
|
||||
s.push(array)
|
||||
@@ -173,13 +172,13 @@ func (s *aqlInterpreter) function(ctx *parser.Function_callContext) {
|
||||
}
|
||||
case "REMOVE_NTH":
|
||||
position := s.pop().(float64)
|
||||
anyArray := s.pop().([]interface{})
|
||||
anyArray := s.pop().([]any)
|
||||
|
||||
if position < 0 {
|
||||
position = float64(len(anyArray) + int(position))
|
||||
}
|
||||
|
||||
result := []interface{}{}
|
||||
result := []any{}
|
||||
for idx, e := range anyArray {
|
||||
if idx != int(position) {
|
||||
result = append(result, e)
|
||||
@@ -193,7 +192,7 @@ func (s *aqlInterpreter) function(ctx *parser.Function_callContext) {
|
||||
}
|
||||
replaceValue := s.pop().(string)
|
||||
position := s.pop().(float64)
|
||||
anyArray := s.pop().([]interface{})
|
||||
anyArray := s.pop().([]any)
|
||||
|
||||
if position < 0 {
|
||||
position = float64(len(anyArray) + int(position))
|
||||
@@ -224,8 +223,8 @@ func (s *aqlInterpreter) function(ctx *parser.Function_callContext) {
|
||||
limit = s.pop().(float64)
|
||||
}
|
||||
value := s.pop()
|
||||
array := s.pop().([]interface{})
|
||||
result := []interface{}{}
|
||||
array := s.pop().([]any)
|
||||
result := []any{}
|
||||
for idx, e := range array {
|
||||
if e != value || float64(idx) > limit {
|
||||
result = append(result, e)
|
||||
@@ -233,9 +232,9 @@ func (s *aqlInterpreter) function(ctx *parser.Function_callContext) {
|
||||
}
|
||||
s.push(result)
|
||||
case "REMOVE_VALUES":
|
||||
values := s.pop().([]interface{})
|
||||
array := s.pop().([]interface{})
|
||||
result := []interface{}{}
|
||||
values := s.pop().([]any)
|
||||
array := s.pop().([]any)
|
||||
result := []any{}
|
||||
for _, e := range array {
|
||||
if !contains(values, e) {
|
||||
result = append(result, e)
|
||||
@@ -243,14 +242,14 @@ func (s *aqlInterpreter) function(ctx *parser.Function_callContext) {
|
||||
}
|
||||
s.push(result)
|
||||
case "REVERSE":
|
||||
array := s.pop().([]interface{})
|
||||
var reverse []interface{}
|
||||
array := s.pop().([]any)
|
||||
var reverse []any
|
||||
for _, e := range array {
|
||||
reverse = append([]interface{}{e}, reverse...)
|
||||
reverse = append([]any{e}, reverse...)
|
||||
}
|
||||
s.push(reverse)
|
||||
case "SHIFT":
|
||||
s.push(s.pop().([]interface{})[1:])
|
||||
s.push(s.pop().([]any)[1:])
|
||||
case "SLICE":
|
||||
length := float64(-1)
|
||||
full := true
|
||||
@@ -259,7 +258,7 @@ func (s *aqlInterpreter) function(ctx *parser.Function_callContext) {
|
||||
full = false
|
||||
}
|
||||
start := int64(s.pop().(float64))
|
||||
array := s.pop().([]interface{})
|
||||
array := s.pop().([]any)
|
||||
|
||||
if start < 0 {
|
||||
start = int64(len(array)) + start
|
||||
@@ -276,43 +275,43 @@ func (s *aqlInterpreter) function(ctx *parser.Function_callContext) {
|
||||
}
|
||||
s.push(array[start:end])
|
||||
case "SORTED":
|
||||
array := s.pop().([]interface{})
|
||||
array := s.pop().([]any)
|
||||
sort.Slice(array, func(i, j int) bool { return lt(array[i], array[j]) })
|
||||
s.push(array)
|
||||
case "SORTED_UNIQUE":
|
||||
array := s.pop().([]interface{})
|
||||
array := s.pop().([]any)
|
||||
sort.Slice(array, func(i, j int) bool { return lt(array[i], array[j]) })
|
||||
s.push(unique(array))
|
||||
case "UNION":
|
||||
array := s.pop().([]interface{})
|
||||
array := s.pop().([]any)
|
||||
|
||||
for i := 1; i < len(ctx.AllExpression()); i++ {
|
||||
array = append(array, s.pop().([]interface{})...)
|
||||
array = append(array, s.pop().([]any)...)
|
||||
}
|
||||
|
||||
sort.Slice(array, func(i, j int) bool { return lt(array[i], array[j]) })
|
||||
s.push(array)
|
||||
case "UNION_DISTINCT":
|
||||
iset := New(s.pop().([]interface{})...)
|
||||
iset := NewSet(s.pop().([]any)...)
|
||||
|
||||
for i := 1; i < len(ctx.AllExpression()); i++ {
|
||||
iset = iset.Union(New(s.pop().([]interface{})...))
|
||||
iset = iset.Union(NewSet(s.pop().([]any)...))
|
||||
}
|
||||
|
||||
s.push(unique(iset.Values()))
|
||||
case "UNIQUE":
|
||||
s.push(unique(s.pop().([]interface{})))
|
||||
s.push(unique(s.pop().([]any)))
|
||||
case "UNSHIFT":
|
||||
u := false
|
||||
if len(ctx.AllExpression()) == 3 {
|
||||
u = s.pop().(bool)
|
||||
}
|
||||
element := s.pop()
|
||||
array := s.pop().([]interface{})
|
||||
array := s.pop().([]any)
|
||||
if u && contains(array, element) {
|
||||
s.push(array)
|
||||
} else {
|
||||
s.push(append([]interface{}{element}, array...))
|
||||
s.push(append([]any{element}, array...))
|
||||
}
|
||||
|
||||
// Bit https://www.arangodb.com/docs/stable/aql/functions-bit.html
|
||||
@@ -367,8 +366,8 @@ func (s *aqlInterpreter) function(ctx *parser.Function_callContext) {
|
||||
if len(ctx.AllExpression()) >= 2 {
|
||||
removeInternal = s.pop().(bool)
|
||||
}
|
||||
var keys []interface{}
|
||||
for k := range s.pop().(map[string]interface{}) {
|
||||
var keys []any
|
||||
for k := range s.pop().(map[string]any) {
|
||||
isInternalKey := strings.HasPrefix(k, "_")
|
||||
if !removeInternal || !isInternalKey {
|
||||
keys = append(keys, k)
|
||||
@@ -379,20 +378,20 @@ func (s *aqlInterpreter) function(ctx *parser.Function_callContext) {
|
||||
// case "COUNT":
|
||||
case "HAS":
|
||||
right, left := s.pop(), s.pop()
|
||||
_, ok := left.(map[string]interface{})[right.(string)]
|
||||
_, ok := left.(map[string]any)[right.(string)]
|
||||
s.push(ok)
|
||||
// case "KEEP":
|
||||
// case "LENGTH":
|
||||
// case "MATCHES":
|
||||
case "MERGE":
|
||||
var docs []map[string]interface{}
|
||||
var docs []map[string]any
|
||||
if len(ctx.AllExpression()) == 1 {
|
||||
for _, doc := range s.pop().([]interface{}) {
|
||||
docs = append([]map[string]interface{}{doc.(map[string]interface{})}, docs...)
|
||||
for _, doc := range s.pop().([]any) {
|
||||
docs = append([]map[string]any{doc.(map[string]any)}, docs...)
|
||||
}
|
||||
} else {
|
||||
for i := 0; i < len(ctx.AllExpression()); i++ {
|
||||
docs = append(docs, s.pop().(map[string]interface{}))
|
||||
docs = append(docs, s.pop().(map[string]any))
|
||||
}
|
||||
}
|
||||
|
||||
@@ -404,9 +403,9 @@ func (s *aqlInterpreter) function(ctx *parser.Function_callContext) {
|
||||
}
|
||||
s.push(doc)
|
||||
case "MERGE_RECURSIVE":
|
||||
var doc map[string]interface{}
|
||||
var doc map[string]any
|
||||
for i := 0; i < len(ctx.AllExpression()); i++ {
|
||||
err := mergo.Merge(&doc, s.pop().(map[string]interface{}))
|
||||
err := mergo.Merge(&doc, s.pop().(map[string]any))
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
@@ -421,8 +420,8 @@ func (s *aqlInterpreter) function(ctx *parser.Function_callContext) {
|
||||
if len(ctx.AllExpression()) == 2 {
|
||||
removeInternal = s.pop().(bool)
|
||||
}
|
||||
var values []interface{}
|
||||
for k, v := range s.pop().(map[string]interface{}) {
|
||||
var values []any
|
||||
for k, v := range s.pop().(map[string]any) {
|
||||
isInternalKey := strings.HasPrefix(k, "_")
|
||||
if !removeInternal || !isInternalKey {
|
||||
values = append(values, v)
|
||||
@@ -458,10 +457,10 @@ func (s *aqlInterpreter) function(ctx *parser.Function_callContext) {
|
||||
case "AVERAGE", "AVG":
|
||||
count := 0
|
||||
sum := float64(0)
|
||||
array := s.pop().([]interface{})
|
||||
array := s.pop().([]any)
|
||||
for _, element := range array {
|
||||
if element != nil {
|
||||
count += 1
|
||||
count++
|
||||
sum += toNumber(element)
|
||||
}
|
||||
}
|
||||
@@ -506,7 +505,7 @@ func (s *aqlInterpreter) function(ctx *parser.Function_callContext) {
|
||||
case "MAX":
|
||||
var set bool
|
||||
var max float64
|
||||
array := s.pop().([]interface{})
|
||||
array := s.pop().([]any)
|
||||
for _, element := range array {
|
||||
if element != nil {
|
||||
if !set || toNumber(element) > max {
|
||||
@@ -521,7 +520,7 @@ func (s *aqlInterpreter) function(ctx *parser.Function_callContext) {
|
||||
s.push(nil)
|
||||
}
|
||||
case "MEDIAN":
|
||||
array := s.pop().([]interface{})
|
||||
array := s.pop().([]any)
|
||||
var numbers []float64
|
||||
for _, element := range array {
|
||||
if f, ok := element.(float64); ok {
|
||||
@@ -544,7 +543,7 @@ func (s *aqlInterpreter) function(ctx *parser.Function_callContext) {
|
||||
case "MIN":
|
||||
var set bool
|
||||
var min float64
|
||||
array := s.pop().([]interface{})
|
||||
array := s.pop().([]any)
|
||||
for _, element := range array {
|
||||
if element != nil {
|
||||
if !set || toNumber(element) < min {
|
||||
@@ -566,7 +565,7 @@ func (s *aqlInterpreter) function(ctx *parser.Function_callContext) {
|
||||
s.push(math.Pow(left.(float64), right.(float64)))
|
||||
case "PRODUCT":
|
||||
product := float64(1)
|
||||
array := s.pop().([]interface{})
|
||||
array := s.pop().([]any)
|
||||
for _, element := range array {
|
||||
if element != nil {
|
||||
product *= toNumber(element)
|
||||
@@ -578,7 +577,7 @@ func (s *aqlInterpreter) function(ctx *parser.Function_callContext) {
|
||||
case "RAND":
|
||||
s.push(rand.Float64())
|
||||
case "RANGE":
|
||||
var array []interface{}
|
||||
var array []any
|
||||
var start, end, step float64
|
||||
if len(ctx.AllExpression()) == 2 {
|
||||
right, left := s.pop(), s.pop()
|
||||
@@ -612,7 +611,7 @@ func (s *aqlInterpreter) function(ctx *parser.Function_callContext) {
|
||||
// case "STDDEV":
|
||||
case "SUM":
|
||||
sum := float64(0)
|
||||
array := s.pop().([]interface{})
|
||||
array := s.pop().([]any)
|
||||
for _, element := range array {
|
||||
sum += toNumber(element)
|
||||
}
|
||||
@@ -691,7 +690,6 @@ func (s *aqlInterpreter) function(ctx *parser.Function_callContext) {
|
||||
// case "IS_IPV4":
|
||||
// case "IS_KEY":
|
||||
// case "TYPENAME":
|
||||
|
||||
}
|
||||
}
|
||||
|
||||
@@ -705,6 +703,7 @@ func unique(array []interface{}) []interface{} {
|
||||
filtered = append(filtered, e)
|
||||
}
|
||||
}
|
||||
|
||||
return filtered
|
||||
}
|
||||
|
||||
@@ -714,15 +713,7 @@ func contains(values []interface{}, e interface{}) bool {
|
||||
return true
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
func stringSliceContains(values []string, e string) bool {
|
||||
for _, v := range values {
|
||||
if e == v {
|
||||
return true
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
@@ -747,4 +738,5 @@ var functionNames = []string{
|
||||
"REGEX_REPLACE", "REVERSE", "RIGHT", "RTRIM", "SHA1", "SHA512", "SOUNDEX", "SPLIT", "STARTS_WITH", "SUBSTITUTE",
|
||||
"SUBSTRING", "TOKENS", "TO_BASE64", "TO_HEX", "TRIM", "UPPER", "UUID", "TO_BOOL", "TO_NUMBER", "TO_STRING",
|
||||
"TO_ARRAY", "TO_LIST", "IS_NULL", "IS_BOOL", "IS_NUMBER", "IS_STRING", "IS_ARRAY", "IS_LIST", "IS_OBJECT",
|
||||
"IS_DOCUMENT", "IS_DATESTRING", "IS_IPV4", "IS_KEY", "TYPENAME"}
|
||||
"IS_DOCUMENT", "IS_DATESTRING", "IS_IPV4", "IS_KEY", "TYPENAME",
|
||||
}
|
||||
|
||||
@@ -1,18 +1,22 @@
|
||||
package caql
|
||||
package caql_test
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"math"
|
||||
"reflect"
|
||||
"testing"
|
||||
|
||||
"github.com/SecurityBrewery/catalyst/caql"
|
||||
)
|
||||
|
||||
func TestFunctions(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
saql string
|
||||
wantRebuild string
|
||||
wantValue interface{}
|
||||
wantValue any
|
||||
wantParseErr bool
|
||||
wantRebuildErr bool
|
||||
wantEvalErr bool
|
||||
@@ -266,13 +270,13 @@ func TestFunctions(t *testing.T) {
|
||||
{name: "RADIANS", saql: `RADIANS(0)`, wantRebuild: `RADIANS(0)`, wantValue: 0},
|
||||
// {name: "RAND", saql: `RAND()`, wantRebuild: `RAND()`, wantValue: 0.3503170117504508},
|
||||
// {name: "RAND", saql: `RAND()`, wantRebuild: `RAND()`, wantValue: 0.6138226173882478},
|
||||
{name: "RANGE", saql: `RANGE(1, 4)`, wantRebuild: `RANGE(1, 4)`, wantValue: []interface{}{float64(1), float64(2), float64(3), float64(4)}},
|
||||
{name: "RANGE", saql: `RANGE(1, 4, 2)`, wantRebuild: `RANGE(1, 4, 2)`, wantValue: []interface{}{float64(1), float64(3)}},
|
||||
{name: "RANGE", saql: `RANGE(1, 4, 3)`, wantRebuild: `RANGE(1, 4, 3)`, wantValue: []interface{}{float64(1), float64(4)}},
|
||||
{name: "RANGE", saql: `RANGE(1.5, 2.5)`, wantRebuild: `RANGE(1.5, 2.5)`, wantValue: []interface{}{float64(1), float64(2)}},
|
||||
{name: "RANGE", saql: `RANGE(1.5, 2.5, 1)`, wantRebuild: `RANGE(1.5, 2.5, 1)`, wantValue: []interface{}{1.5, 2.5}},
|
||||
{name: "RANGE", saql: `RANGE(1.5, 2.5, 0.5)`, wantRebuild: `RANGE(1.5, 2.5, 0.5)`, wantValue: []interface{}{1.5, 2.0, 2.5}},
|
||||
{name: "RANGE", saql: `RANGE(-0.75, 1.1, 0.5)`, wantRebuild: `RANGE(-0.75, 1.1, 0.5)`, wantValue: []interface{}{-0.75, -0.25, 0.25, 0.75}},
|
||||
{name: "RANGE", saql: `RANGE(1, 4)`, wantRebuild: `RANGE(1, 4)`, wantValue: []any{float64(1), float64(2), float64(3), float64(4)}},
|
||||
{name: "RANGE", saql: `RANGE(1, 4, 2)`, wantRebuild: `RANGE(1, 4, 2)`, wantValue: []any{float64(1), float64(3)}},
|
||||
{name: "RANGE", saql: `RANGE(1, 4, 3)`, wantRebuild: `RANGE(1, 4, 3)`, wantValue: []any{float64(1), float64(4)}},
|
||||
{name: "RANGE", saql: `RANGE(1.5, 2.5)`, wantRebuild: `RANGE(1.5, 2.5)`, wantValue: []any{float64(1), float64(2)}},
|
||||
{name: "RANGE", saql: `RANGE(1.5, 2.5, 1)`, wantRebuild: `RANGE(1.5, 2.5, 1)`, wantValue: []any{1.5, 2.5}},
|
||||
{name: "RANGE", saql: `RANGE(1.5, 2.5, 0.5)`, wantRebuild: `RANGE(1.5, 2.5, 0.5)`, wantValue: []any{1.5, 2.0, 2.5}},
|
||||
{name: "RANGE", saql: `RANGE(-0.75, 1.1, 0.5)`, wantRebuild: `RANGE(-0.75, 1.1, 0.5)`, wantValue: []any{-0.75, -0.25, 0.25, 0.75}},
|
||||
{name: "ROUND", saql: `ROUND(2.49)`, wantRebuild: `ROUND(2.49)`, wantValue: 2},
|
||||
{name: "ROUND", saql: `ROUND(2.50)`, wantRebuild: `ROUND(2.50)`, wantValue: 3},
|
||||
{name: "ROUND", saql: `ROUND(-2.50)`, wantRebuild: `ROUND(-2.50)`, wantValue: -2},
|
||||
@@ -299,15 +303,20 @@ func TestFunctions(t *testing.T) {
|
||||
{name: "Function Error 3", saql: `ABS("abs")`, wantRebuild: `ABS("abs")`, wantEvalErr: true},
|
||||
}
|
||||
for _, tt := range tests {
|
||||
parser := &Parser{}
|
||||
tt := tt
|
||||
|
||||
parser := &caql.Parser{}
|
||||
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
expr, err := parser.Parse(tt.saql)
|
||||
if (err != nil) != tt.wantParseErr {
|
||||
t.Errorf("Parse() error = %v, wantErr %v", err, tt.wantParseErr)
|
||||
if expr != nil {
|
||||
t.Error(expr.String())
|
||||
}
|
||||
|
||||
return
|
||||
}
|
||||
if err != nil {
|
||||
@@ -318,6 +327,7 @@ func TestFunctions(t *testing.T) {
|
||||
if (err != nil) != tt.wantRebuildErr {
|
||||
t.Error(expr.String())
|
||||
t.Errorf("String() error = %v, wantErr %v", err, tt.wantParseErr)
|
||||
|
||||
return
|
||||
}
|
||||
if err != nil {
|
||||
@@ -327,18 +337,19 @@ func TestFunctions(t *testing.T) {
|
||||
t.Errorf("String() got = %v, want %v", got, tt.wantRebuild)
|
||||
}
|
||||
|
||||
var myJson map[string]interface{}
|
||||
var myJSON map[string]any
|
||||
if tt.values != "" {
|
||||
err = json.Unmarshal([]byte(tt.values), &myJson)
|
||||
err = json.Unmarshal([]byte(tt.values), &myJSON)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
}
|
||||
|
||||
value, err := expr.Eval(myJson)
|
||||
value, err := expr.Eval(myJSON)
|
||||
if (err != nil) != tt.wantEvalErr {
|
||||
t.Error(expr.String())
|
||||
t.Errorf("Parse() error = %v, wantErr %v", err, tt.wantParseErr)
|
||||
|
||||
return
|
||||
}
|
||||
if err != nil {
|
||||
@@ -367,14 +378,15 @@ func TestFunctions(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
func jsonParse(s string) interface{} {
|
||||
func jsonParse(s string) any {
|
||||
if s == "" {
|
||||
return nil
|
||||
}
|
||||
var j interface{}
|
||||
var j any
|
||||
err := json.Unmarshal([]byte(s), &j)
|
||||
if err != nil {
|
||||
panic(s + err.Error())
|
||||
}
|
||||
|
||||
return j
|
||||
}
|
||||
|
||||
@@ -10,22 +10,23 @@ import (
|
||||
|
||||
type aqlInterpreter struct {
|
||||
*parser.BaseCAQLParserListener
|
||||
values map[string]interface{}
|
||||
stack []interface{}
|
||||
values map[string]any
|
||||
stack []any
|
||||
errs []error
|
||||
}
|
||||
|
||||
// push is a helper function for pushing new node to the listener Stack.
|
||||
func (s *aqlInterpreter) push(i interface{}) {
|
||||
func (s *aqlInterpreter) push(i any) {
|
||||
s.stack = append(s.stack, i)
|
||||
}
|
||||
|
||||
// pop is a helper function for poping a node from the listener Stack.
|
||||
func (s *aqlInterpreter) pop() (n interface{}) {
|
||||
func (s *aqlInterpreter) pop() (n any) {
|
||||
// Check that we have nodes in the stack.
|
||||
size := len(s.stack)
|
||||
if size < 1 {
|
||||
s.appendErrors(ErrStack)
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
@@ -35,8 +36,9 @@ func (s *aqlInterpreter) pop() (n interface{}) {
|
||||
return
|
||||
}
|
||||
|
||||
func (s *aqlInterpreter) binaryPop() (interface{}, interface{}) {
|
||||
func (s *aqlInterpreter) binaryPop() (any, any) {
|
||||
right, left := s.pop(), s.pop()
|
||||
|
||||
return left, right
|
||||
}
|
||||
|
||||
@@ -54,17 +56,14 @@ func (s *aqlInterpreter) ExitExpression(ctx *parser.ExpressionContext) {
|
||||
s.push(plus(s.binaryPop()))
|
||||
case ctx.T_MINUS() != nil:
|
||||
s.push(minus(s.binaryPop()))
|
||||
|
||||
case ctx.T_TIMES() != nil:
|
||||
s.push(times(s.binaryPop()))
|
||||
case ctx.T_DIV() != nil:
|
||||
s.push(div(s.binaryPop()))
|
||||
case ctx.T_MOD() != nil:
|
||||
s.push(mod(s.binaryPop()))
|
||||
|
||||
case ctx.T_RANGE() != nil:
|
||||
s.push(aqlrange(s.binaryPop()))
|
||||
|
||||
case ctx.T_LT() != nil && ctx.GetEq_op() == nil:
|
||||
s.push(lt(s.binaryPop()))
|
||||
case ctx.T_GT() != nil && ctx.GetEq_op() == nil:
|
||||
@@ -73,35 +72,30 @@ func (s *aqlInterpreter) ExitExpression(ctx *parser.ExpressionContext) {
|
||||
s.push(le(s.binaryPop()))
|
||||
case ctx.T_GE() != nil && ctx.GetEq_op() == nil:
|
||||
s.push(ge(s.binaryPop()))
|
||||
|
||||
case ctx.T_IN() != nil && ctx.GetEq_op() == nil:
|
||||
s.push(maybeNot(ctx, in(s.binaryPop())))
|
||||
|
||||
case ctx.T_EQ() != nil && ctx.GetEq_op() == nil:
|
||||
s.push(eq(s.binaryPop()))
|
||||
case ctx.T_NE() != nil && ctx.GetEq_op() == nil:
|
||||
s.push(ne(s.binaryPop()))
|
||||
|
||||
case ctx.T_ALL() != nil && ctx.GetEq_op() != nil:
|
||||
right, left := s.pop(), s.pop()
|
||||
s.push(all(left.([]interface{}), getOp(ctx.GetEq_op().GetTokenType()), right))
|
||||
s.push(all(left.([]any), getOp(ctx.GetEq_op().GetTokenType()), right))
|
||||
case ctx.T_ANY() != nil && ctx.GetEq_op() != nil:
|
||||
right, left := s.pop(), s.pop()
|
||||
s.push(any(left.([]interface{}), getOp(ctx.GetEq_op().GetTokenType()), right))
|
||||
s.push(anyElement(left.([]any), getOp(ctx.GetEq_op().GetTokenType()), right))
|
||||
case ctx.T_NONE() != nil && ctx.GetEq_op() != nil:
|
||||
right, left := s.pop(), s.pop()
|
||||
s.push(none(left.([]interface{}), getOp(ctx.GetEq_op().GetTokenType()), right))
|
||||
|
||||
s.push(none(left.([]any), getOp(ctx.GetEq_op().GetTokenType()), right))
|
||||
case ctx.T_ALL() != nil && ctx.T_NOT() != nil && ctx.T_IN() != nil:
|
||||
right, left := s.pop(), s.pop()
|
||||
s.push(all(left.([]interface{}), in, right))
|
||||
s.push(all(left.([]any), in, right))
|
||||
case ctx.T_ANY() != nil && ctx.T_NOT() != nil && ctx.T_IN() != nil:
|
||||
right, left := s.pop(), s.pop()
|
||||
s.push(any(left.([]interface{}), in, right))
|
||||
s.push(anyElement(left.([]any), in, right))
|
||||
case ctx.T_NONE() != nil && ctx.T_NOT() != nil && ctx.T_IN() != nil:
|
||||
right, left := s.pop(), s.pop()
|
||||
s.push(none(left.([]interface{}), in, right))
|
||||
|
||||
s.push(none(left.([]any), in, right))
|
||||
case ctx.T_LIKE() != nil:
|
||||
m, err := like(s.binaryPop())
|
||||
s.appendErrors(err)
|
||||
@@ -114,21 +108,18 @@ func (s *aqlInterpreter) ExitExpression(ctx *parser.ExpressionContext) {
|
||||
m, err := regexNonMatch(s.binaryPop())
|
||||
s.appendErrors(err)
|
||||
s.push(maybeNot(ctx, m))
|
||||
|
||||
case ctx.T_AND() != nil:
|
||||
s.push(and(s.binaryPop()))
|
||||
case ctx.T_OR() != nil:
|
||||
s.push(or(s.binaryPop()))
|
||||
|
||||
case ctx.T_QUESTION() != nil && len(ctx.AllExpression()) == 3:
|
||||
right, middle, left := s.pop(), s.pop(), s.pop()
|
||||
s.push(ternary(left, middle, right))
|
||||
case ctx.T_QUESTION() != nil && len(ctx.AllExpression()) == 2:
|
||||
right, left := s.pop(), s.pop()
|
||||
s.push(ternary(left, nil, right))
|
||||
|
||||
default:
|
||||
panic("unkown expression")
|
||||
panic("unknown expression")
|
||||
}
|
||||
}
|
||||
|
||||
@@ -159,7 +150,7 @@ func (s *aqlInterpreter) ExitReference(ctx *parser.ReferenceContext) {
|
||||
case ctx.DOT() != nil:
|
||||
reference := s.pop()
|
||||
|
||||
s.push(reference.(map[string]interface{})[ctx.T_STRING().GetText()])
|
||||
s.push(reference.(map[string]any)[ctx.T_STRING().GetText()])
|
||||
case ctx.T_STRING() != nil:
|
||||
s.push(s.getVar(ctx.T_STRING().GetText()))
|
||||
case ctx.Compound_value() != nil:
|
||||
@@ -175,14 +166,15 @@ func (s *aqlInterpreter) ExitReference(ctx *parser.ReferenceContext) {
|
||||
if f, ok := key.(float64); ok {
|
||||
index := int(f)
|
||||
if index < 0 {
|
||||
index = len(reference.([]interface{})) + index
|
||||
index = len(reference.([]any)) + index
|
||||
}
|
||||
|
||||
s.push(reference.([]interface{})[index])
|
||||
s.push(reference.([]any)[index])
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
s.push(reference.(map[string]interface{})[key.(string)])
|
||||
s.push(reference.(map[string]any)[key.(string)])
|
||||
default:
|
||||
panic(fmt.Sprintf("unexpected value: %s", ctx.GetText()))
|
||||
}
|
||||
@@ -239,17 +231,17 @@ func (s *aqlInterpreter) ExitValue_literal(ctx *parser.Value_literalContext) {
|
||||
|
||||
// ExitArray is called when production array is exited.
|
||||
func (s *aqlInterpreter) ExitArray(ctx *parser.ArrayContext) {
|
||||
array := []interface{}{}
|
||||
array := []any{}
|
||||
for range ctx.AllExpression() {
|
||||
// prepend element
|
||||
array = append([]interface{}{s.pop()}, array...)
|
||||
array = append([]any{s.pop()}, array...)
|
||||
}
|
||||
s.push(array)
|
||||
}
|
||||
|
||||
// ExitObject is called when production object is exited.
|
||||
func (s *aqlInterpreter) ExitObject(ctx *parser.ObjectContext) {
|
||||
object := map[string]interface{}{}
|
||||
object := map[string]any{}
|
||||
for range ctx.AllObject_element() {
|
||||
key, value := s.pop(), s.pop()
|
||||
|
||||
@@ -290,7 +282,7 @@ func (s *aqlInterpreter) ExitObject_element_name(ctx *parser.Object_element_name
|
||||
}
|
||||
}
|
||||
|
||||
func (s *aqlInterpreter) getVar(identifier string) interface{} {
|
||||
func (s *aqlInterpreter) getVar(identifier string) any {
|
||||
v, ok := s.values[identifier]
|
||||
if !ok {
|
||||
s.appendErrors(ErrUndefined)
|
||||
@@ -303,10 +295,11 @@ func maybeNot(ctx *parser.ExpressionContext, m bool) bool {
|
||||
if ctx.T_NOT() != nil {
|
||||
return !m
|
||||
}
|
||||
|
||||
return m
|
||||
}
|
||||
|
||||
func getOp(tokenType int) func(left, right interface{}) bool {
|
||||
func getOp(tokenType int) func(left, right any) bool {
|
||||
switch tokenType {
|
||||
case parser.CAQLLexerT_EQ:
|
||||
return eq
|
||||
@@ -323,33 +316,36 @@ func getOp(tokenType int) func(left, right interface{}) bool {
|
||||
case parser.CAQLLexerT_IN:
|
||||
return in
|
||||
default:
|
||||
panic("unkown token type")
|
||||
panic("unknown token type")
|
||||
}
|
||||
}
|
||||
|
||||
func all(slice []interface{}, op func(interface{}, interface{}) bool, expr interface{}) bool {
|
||||
func all(slice []any, op func(any, any) bool, expr any) bool {
|
||||
for _, e := range slice {
|
||||
if !op(e, expr) {
|
||||
return false
|
||||
}
|
||||
}
|
||||
|
||||
return true
|
||||
}
|
||||
|
||||
func any(slice []interface{}, op func(interface{}, interface{}) bool, expr interface{}) bool {
|
||||
func anyElement(slice []any, op func(any, any) bool, expr any) bool {
|
||||
for _, e := range slice {
|
||||
if op(e, expr) {
|
||||
return true
|
||||
}
|
||||
}
|
||||
|
||||
return false
|
||||
}
|
||||
|
||||
func none(slice []interface{}, op func(interface{}, interface{}) bool, expr interface{}) bool {
|
||||
func none(slice []any, op func(any, any) bool, expr any) bool {
|
||||
for _, e := range slice {
|
||||
if op(e, expr) {
|
||||
return false
|
||||
}
|
||||
}
|
||||
|
||||
return true
|
||||
}
|
||||
|
||||
@@ -10,21 +10,23 @@ import (
|
||||
|
||||
// Logical operators https://www.arangodb.com/docs/3.7/aql/operators.html#logical-operators
|
||||
|
||||
func or(left, right interface{}) interface{} {
|
||||
func or(left, right any) any {
|
||||
if toBool(left) {
|
||||
return left
|
||||
}
|
||||
|
||||
return right
|
||||
}
|
||||
|
||||
func and(left, right interface{}) interface{} {
|
||||
func and(left, right any) any {
|
||||
if !toBool(left) {
|
||||
return left
|
||||
}
|
||||
|
||||
return right
|
||||
}
|
||||
|
||||
func toBool(i interface{}) bool {
|
||||
func toBool(i any) bool {
|
||||
switch v := i.(type) {
|
||||
case nil:
|
||||
return false
|
||||
@@ -36,9 +38,9 @@ func toBool(i interface{}) bool {
|
||||
return v != 0
|
||||
case string:
|
||||
return v != ""
|
||||
case []interface{}:
|
||||
case []any:
|
||||
return true
|
||||
case map[string]interface{}:
|
||||
case map[string]any:
|
||||
return true
|
||||
default:
|
||||
panic("bool conversion failed")
|
||||
@@ -47,15 +49,15 @@ func toBool(i interface{}) bool {
|
||||
|
||||
// Arithmetic operators https://www.arangodb.com/docs/3.7/aql/operators.html#arithmetic-operators
|
||||
|
||||
func plus(left, right interface{}) float64 {
|
||||
func plus(left, right any) float64 {
|
||||
return toNumber(left) + toNumber(right)
|
||||
}
|
||||
|
||||
func minus(left, right interface{}) float64 {
|
||||
func minus(left, right any) float64 {
|
||||
return toNumber(left) - toNumber(right)
|
||||
}
|
||||
|
||||
func times(left, right interface{}) float64 {
|
||||
func times(left, right any) float64 {
|
||||
return round(toNumber(left) * toNumber(right))
|
||||
}
|
||||
|
||||
@@ -63,19 +65,20 @@ func round(r float64) float64 {
|
||||
return math.Round(r*100000) / 100000
|
||||
}
|
||||
|
||||
func div(left, right interface{}) float64 {
|
||||
func div(left, right any) float64 {
|
||||
b := toNumber(right)
|
||||
if b == 0 {
|
||||
return 0
|
||||
}
|
||||
|
||||
return round(toNumber(left) / b)
|
||||
}
|
||||
|
||||
func mod(left, right interface{}) float64 {
|
||||
func mod(left, right any) float64 {
|
||||
return math.Mod(toNumber(left), toNumber(right))
|
||||
}
|
||||
|
||||
func toNumber(i interface{}) float64 {
|
||||
func toNumber(i any) float64 {
|
||||
switch v := i.(type) {
|
||||
case nil:
|
||||
return 0
|
||||
@@ -83,6 +86,7 @@ func toNumber(i interface{}) float64 {
|
||||
if v {
|
||||
return 1
|
||||
}
|
||||
|
||||
return 0
|
||||
case float64:
|
||||
switch {
|
||||
@@ -91,22 +95,25 @@ func toNumber(i interface{}) float64 {
|
||||
case math.IsInf(v, 0):
|
||||
return 0
|
||||
}
|
||||
|
||||
return v
|
||||
case string:
|
||||
f, err := strconv.ParseFloat(strings.TrimSpace(v), 64)
|
||||
if err != nil {
|
||||
return 0
|
||||
}
|
||||
|
||||
return f
|
||||
case []interface{}:
|
||||
case []any:
|
||||
if len(v) == 0 {
|
||||
return 0
|
||||
}
|
||||
if len(v) == 1 {
|
||||
return toNumber(v[0])
|
||||
}
|
||||
|
||||
return 0
|
||||
case map[string]interface{}:
|
||||
case map[string]any:
|
||||
return 0
|
||||
default:
|
||||
panic("number conversion error")
|
||||
@@ -116,7 +123,7 @@ func toNumber(i interface{}) float64 {
|
||||
// Logical operators https://www.arangodb.com/docs/3.7/aql/operators.html#logical-operators
|
||||
// Order https://www.arangodb.com/docs/3.7/aql/fundamentals-type-value-order.html
|
||||
|
||||
func eq(left, right interface{}) bool {
|
||||
func eq(left, right any) bool {
|
||||
leftV, rightV := typeValue(left), typeValue(right)
|
||||
if leftV != rightV {
|
||||
return false
|
||||
@@ -126,15 +133,15 @@ func eq(left, right interface{}) bool {
|
||||
return true
|
||||
case bool, float64, string:
|
||||
return left == right
|
||||
case []interface{}:
|
||||
ra := right.([]interface{})
|
||||
case []any:
|
||||
ra := right.([]any)
|
||||
max := len(l)
|
||||
if len(ra) > max {
|
||||
max = len(ra)
|
||||
}
|
||||
for i := 0; i < max; i++ {
|
||||
var li interface{} = nil
|
||||
var rai interface{} = nil
|
||||
var li any
|
||||
var rai any
|
||||
if len(l) > i {
|
||||
li = l[i]
|
||||
}
|
||||
@@ -146,13 +153,14 @@ func eq(left, right interface{}) bool {
|
||||
return false
|
||||
}
|
||||
}
|
||||
|
||||
return true
|
||||
case map[string]interface{}:
|
||||
ro := right.(map[string]interface{})
|
||||
case map[string]any:
|
||||
ro := right.(map[string]any)
|
||||
|
||||
for _, key := range keys(l, ro) {
|
||||
var li interface{} = nil
|
||||
var rai interface{} = nil
|
||||
var li any
|
||||
var rai any
|
||||
if lv, ok := l[key]; ok {
|
||||
li = lv
|
||||
}
|
||||
@@ -164,17 +172,18 @@ func eq(left, right interface{}) bool {
|
||||
return false
|
||||
}
|
||||
}
|
||||
|
||||
return true
|
||||
default:
|
||||
panic("unknown type")
|
||||
}
|
||||
}
|
||||
|
||||
func ne(left, right interface{}) bool {
|
||||
func ne(left, right any) bool {
|
||||
return !eq(left, right)
|
||||
}
|
||||
|
||||
func lt(left, right interface{}) bool {
|
||||
func lt(left, right any) bool {
|
||||
leftV, rightV := typeValue(left), typeValue(right)
|
||||
if leftV != rightV {
|
||||
return leftV < rightV
|
||||
@@ -190,15 +199,15 @@ func lt(left, right interface{}) bool {
|
||||
return l < right.(float64)
|
||||
case string:
|
||||
return l < right.(string)
|
||||
case []interface{}:
|
||||
ra := right.([]interface{})
|
||||
case []any:
|
||||
ra := right.([]any)
|
||||
max := len(l)
|
||||
if len(ra) > max {
|
||||
max = len(ra)
|
||||
}
|
||||
for i := 0; i < max; i++ {
|
||||
var li interface{} = nil
|
||||
var rai interface{} = nil
|
||||
var li any
|
||||
var rai any
|
||||
if len(l) > i {
|
||||
li = l[i]
|
||||
}
|
||||
@@ -210,13 +219,14 @@ func lt(left, right interface{}) bool {
|
||||
return lt(li, rai)
|
||||
}
|
||||
}
|
||||
|
||||
return false
|
||||
case map[string]interface{}:
|
||||
ro := right.(map[string]interface{})
|
||||
case map[string]any:
|
||||
ro := right.(map[string]any)
|
||||
|
||||
for _, key := range keys(l, ro) {
|
||||
var li interface{} = nil
|
||||
var rai interface{} = nil
|
||||
var li any
|
||||
var rai any
|
||||
if lv, ok := l[key]; ok {
|
||||
li = lv
|
||||
}
|
||||
@@ -228,16 +238,17 @@ func lt(left, right interface{}) bool {
|
||||
return lt(li, rai)
|
||||
}
|
||||
}
|
||||
|
||||
return false
|
||||
default:
|
||||
panic("unknown type")
|
||||
}
|
||||
}
|
||||
|
||||
func keys(l map[string]interface{}, ro map[string]interface{}) []string {
|
||||
func keys(l map[string]any, ro map[string]any) []string {
|
||||
var keys []string
|
||||
seen := map[string]bool{}
|
||||
for _, a := range []map[string]interface{}{l, ro} {
|
||||
for _, a := range []map[string]any{l, ro} {
|
||||
for k := range a {
|
||||
if _, ok := seen[k]; !ok {
|
||||
seen[k] = true
|
||||
@@ -246,10 +257,11 @@ func keys(l map[string]interface{}, ro map[string]interface{}) []string {
|
||||
}
|
||||
}
|
||||
sort.Strings(keys)
|
||||
|
||||
return keys
|
||||
}
|
||||
|
||||
func gt(left, right interface{}) bool {
|
||||
func gt(left, right any) bool {
|
||||
leftV, rightV := typeValue(left), typeValue(right)
|
||||
if leftV != rightV {
|
||||
return leftV > rightV
|
||||
@@ -265,15 +277,15 @@ func gt(left, right interface{}) bool {
|
||||
return l > right.(float64)
|
||||
case string:
|
||||
return l > right.(string)
|
||||
case []interface{}:
|
||||
ra := right.([]interface{})
|
||||
case []any:
|
||||
ra := right.([]any)
|
||||
max := len(l)
|
||||
if len(ra) > max {
|
||||
max = len(ra)
|
||||
}
|
||||
for i := 0; i < max; i++ {
|
||||
var li interface{} = nil
|
||||
var rai interface{} = nil
|
||||
var li any
|
||||
var rai any
|
||||
if len(l) > i {
|
||||
li = l[i]
|
||||
}
|
||||
@@ -285,13 +297,14 @@ func gt(left, right interface{}) bool {
|
||||
return gt(li, rai)
|
||||
}
|
||||
}
|
||||
|
||||
return false
|
||||
case map[string]interface{}:
|
||||
ro := right.(map[string]interface{})
|
||||
case map[string]any:
|
||||
ro := right.(map[string]any)
|
||||
|
||||
for _, key := range keys(l, ro) {
|
||||
var li interface{} = nil
|
||||
var rai interface{} = nil
|
||||
var li any
|
||||
var rai any
|
||||
if lv, ok := l[key]; ok {
|
||||
li = lv
|
||||
}
|
||||
@@ -303,13 +316,14 @@ func gt(left, right interface{}) bool {
|
||||
return gt(li, rai)
|
||||
}
|
||||
}
|
||||
|
||||
return false
|
||||
default:
|
||||
panic("unknown type")
|
||||
}
|
||||
}
|
||||
|
||||
func le(left, right interface{}) bool {
|
||||
func le(left, right any) bool {
|
||||
leftV, rightV := typeValue(left), typeValue(right)
|
||||
if leftV != rightV {
|
||||
return leftV <= rightV
|
||||
@@ -325,15 +339,15 @@ func le(left, right interface{}) bool {
|
||||
return l <= right.(float64)
|
||||
case string:
|
||||
return l <= right.(string)
|
||||
case []interface{}:
|
||||
ra := right.([]interface{})
|
||||
case []any:
|
||||
ra := right.([]any)
|
||||
max := len(l)
|
||||
if len(ra) > max {
|
||||
max = len(ra)
|
||||
}
|
||||
for i := 0; i < max; i++ {
|
||||
var li interface{} = nil
|
||||
var rai interface{} = nil
|
||||
var li any
|
||||
var rai any
|
||||
if len(l) > i {
|
||||
li = l[i]
|
||||
}
|
||||
@@ -345,13 +359,14 @@ func le(left, right interface{}) bool {
|
||||
return le(li, rai)
|
||||
}
|
||||
}
|
||||
|
||||
return true
|
||||
case map[string]interface{}:
|
||||
ro := right.(map[string]interface{})
|
||||
case map[string]any:
|
||||
ro := right.(map[string]any)
|
||||
|
||||
for _, key := range keys(l, ro) {
|
||||
var li interface{} = nil
|
||||
var rai interface{} = nil
|
||||
var li any
|
||||
var rai any
|
||||
if lv, ok := l[key]; ok {
|
||||
li = lv
|
||||
}
|
||||
@@ -363,13 +378,14 @@ func le(left, right interface{}) bool {
|
||||
return lt(li, rai)
|
||||
}
|
||||
}
|
||||
|
||||
return true
|
||||
default:
|
||||
panic("unknown type")
|
||||
}
|
||||
}
|
||||
|
||||
func ge(left, right interface{}) bool {
|
||||
func ge(left, right any) bool {
|
||||
leftV, rightV := typeValue(left), typeValue(right)
|
||||
if leftV != rightV {
|
||||
return leftV >= rightV
|
||||
@@ -385,15 +401,15 @@ func ge(left, right interface{}) bool {
|
||||
return l >= right.(float64)
|
||||
case string:
|
||||
return l >= right.(string)
|
||||
case []interface{}:
|
||||
ra := right.([]interface{})
|
||||
case []any:
|
||||
ra := right.([]any)
|
||||
max := len(l)
|
||||
if len(ra) > max {
|
||||
max = len(ra)
|
||||
}
|
||||
for i := 0; i < max; i++ {
|
||||
var li interface{} = nil
|
||||
var rai interface{} = nil
|
||||
var li any
|
||||
var rai any
|
||||
if len(l) > i {
|
||||
li = l[i]
|
||||
}
|
||||
@@ -405,13 +421,14 @@ func ge(left, right interface{}) bool {
|
||||
return ge(li, rai)
|
||||
}
|
||||
}
|
||||
|
||||
return true
|
||||
case map[string]interface{}:
|
||||
ro := right.(map[string]interface{})
|
||||
case map[string]any:
|
||||
ro := right.(map[string]any)
|
||||
|
||||
for _, key := range keys(l, ro) {
|
||||
var li interface{} = nil
|
||||
var rai interface{} = nil
|
||||
var li any
|
||||
var rai any
|
||||
if lv, ok := l[key]; ok {
|
||||
li = lv
|
||||
}
|
||||
@@ -423,14 +440,15 @@ func ge(left, right interface{}) bool {
|
||||
return gt(li, rai)
|
||||
}
|
||||
}
|
||||
|
||||
return true
|
||||
default:
|
||||
panic("unknown type")
|
||||
}
|
||||
}
|
||||
|
||||
func in(left, right interface{}) bool {
|
||||
a, ok := right.([]interface{})
|
||||
func in(left, right any) bool {
|
||||
a, ok := right.([]any)
|
||||
if !ok {
|
||||
return false
|
||||
}
|
||||
@@ -439,23 +457,25 @@ func in(left, right interface{}) bool {
|
||||
return true
|
||||
}
|
||||
}
|
||||
|
||||
return false
|
||||
}
|
||||
|
||||
func like(left, right interface{}) (bool, error) {
|
||||
func like(left, right any) (bool, error) {
|
||||
return match(right.(string), left.(string))
|
||||
}
|
||||
|
||||
func regexMatch(left, right interface{}) (bool, error) {
|
||||
func regexMatch(left, right any) (bool, error) {
|
||||
return regexp.Match(right.(string), []byte(left.(string)))
|
||||
}
|
||||
|
||||
func regexNonMatch(left, right interface{}) (bool, error) {
|
||||
func regexNonMatch(left, right any) (bool, error) {
|
||||
m, err := regexp.Match(right.(string), []byte(left.(string)))
|
||||
|
||||
return !m, err
|
||||
}
|
||||
|
||||
func typeValue(v interface{}) int {
|
||||
func typeValue(v any) int {
|
||||
switch v.(type) {
|
||||
case nil:
|
||||
return 0
|
||||
@@ -465,9 +485,9 @@ func typeValue(v interface{}) int {
|
||||
return 2
|
||||
case string:
|
||||
return 3
|
||||
case []interface{}:
|
||||
case []any:
|
||||
return 4
|
||||
case map[string]interface{}:
|
||||
case map[string]any:
|
||||
return 5
|
||||
default:
|
||||
panic("unknown type")
|
||||
@@ -476,22 +496,25 @@ func typeValue(v interface{}) int {
|
||||
|
||||
// Ternary operator https://www.arangodb.com/docs/3.7/aql/operators.html#ternary-operator
|
||||
|
||||
func ternary(left, middle, right interface{}) interface{} {
|
||||
func ternary(left, middle, right any) any {
|
||||
if toBool(left) {
|
||||
if middle != nil {
|
||||
return middle
|
||||
}
|
||||
|
||||
return left
|
||||
}
|
||||
|
||||
return right
|
||||
}
|
||||
|
||||
// Range operators https://www.arangodb.com/docs/3.7/aql/operators.html#range-operator
|
||||
|
||||
func aqlrange(left, right interface{}) []float64 {
|
||||
func aqlrange(left, right any) []float64 {
|
||||
var v []float64
|
||||
for i := int(left.(float64)); i <= int(right.(float64)); i++ {
|
||||
v = append(v, float64(i))
|
||||
}
|
||||
|
||||
return v
|
||||
}
|
||||
|
||||
@@ -21,7 +21,7 @@ func (p *Parser) Parse(aql string) (t *Tree, err error) {
|
||||
err = fmt.Errorf("%s", r)
|
||||
}
|
||||
}()
|
||||
// Setup the input
|
||||
// Set up the input
|
||||
inputStream := antlr.NewInputStream(aql)
|
||||
|
||||
errorListener := &errorListener{}
|
||||
@@ -52,7 +52,7 @@ type Tree struct {
|
||||
prefix string
|
||||
}
|
||||
|
||||
func (t *Tree) Eval(values map[string]interface{}) (i interface{}, err error) {
|
||||
func (t *Tree) Eval(values map[string]any) (i any, err error) {
|
||||
defer func() {
|
||||
if r := recover(); r != nil {
|
||||
err = fmt.Errorf("%s", r)
|
||||
@@ -65,6 +65,7 @@ func (t *Tree) Eval(values map[string]interface{}) (i interface{}, err error) {
|
||||
if interpreter.errs != nil {
|
||||
return nil, interpreter.errs[0]
|
||||
}
|
||||
|
||||
return interpreter.stack[0], nil
|
||||
}
|
||||
|
||||
@@ -103,7 +104,7 @@ type errorListener struct {
|
||||
errs []error
|
||||
}
|
||||
|
||||
func (el *errorListener) SyntaxError(recognizer antlr.Recognizer, offendingSymbol interface{}, line, column int, msg string, e antlr.RecognitionException) {
|
||||
func (el *errorListener) SyntaxError(recognizer antlr.Recognizer, offendingSymbol any, line, column int, msg string, e antlr.RecognitionException) {
|
||||
el.errs = append(el.errs, fmt.Errorf("line "+strconv.Itoa(line)+":"+strconv.Itoa(column)+" "+msg))
|
||||
}
|
||||
|
||||
|
||||
@@ -1,9 +1,11 @@
|
||||
package caql
|
||||
package caql_test
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"reflect"
|
||||
"testing"
|
||||
|
||||
"github.com/SecurityBrewery/catalyst/caql"
|
||||
)
|
||||
|
||||
type MockSearcher struct{}
|
||||
@@ -13,11 +15,13 @@ func (m MockSearcher) Search(_ string) (ids []string, err error) {
|
||||
}
|
||||
|
||||
func TestParseSAQLEval(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
saql string
|
||||
wantRebuild string
|
||||
wantValue interface{}
|
||||
wantValue any
|
||||
wantParseErr bool
|
||||
wantRebuildErr bool
|
||||
wantEvalErr bool
|
||||
@@ -89,15 +93,15 @@ func TestParseSAQLEval(t *testing.T) {
|
||||
// {name: "String 9", saql: `'this is a longer string.'`, wantRebuild: `"this is a longer string."`, wantValue: "this is a longer string."},
|
||||
// {name: "String 10", saql: `'the path separator on Windows is \\'`, wantRebuild: `"the path separator on Windows is \\"`, wantValue: `the path separator on Windows is \`},
|
||||
|
||||
{name: "Array 1", saql: "[]", wantRebuild: "[]", wantValue: []interface{}{}},
|
||||
{name: "Array 2", saql: `[true]`, wantRebuild: `[true]`, wantValue: []interface{}{true}},
|
||||
{name: "Array 3", saql: `[1, 2, 3]`, wantRebuild: `[1, 2, 3]`, wantValue: []interface{}{float64(1), float64(2), float64(3)}},
|
||||
{name: "Array 1", saql: "[]", wantRebuild: "[]", wantValue: []any{}},
|
||||
{name: "Array 2", saql: `[true]`, wantRebuild: `[true]`, wantValue: []any{true}},
|
||||
{name: "Array 3", saql: `[1, 2, 3]`, wantRebuild: `[1, 2, 3]`, wantValue: []any{float64(1), float64(2), float64(3)}},
|
||||
{
|
||||
name: "Array 4", saql: `[-99, "yikes!", [false, ["no"], []], 1]`, wantRebuild: `[-99, "yikes!", [false, ["no"], []], 1]`,
|
||||
wantValue: []interface{}{-99.0, "yikes!", []interface{}{false, []interface{}{"no"}, []interface{}{}}, float64(1)},
|
||||
wantValue: []any{-99.0, "yikes!", []any{false, []any{"no"}, []any{}}, float64(1)},
|
||||
},
|
||||
{name: "Array 5", saql: `[["fox", "marshal"]]`, wantRebuild: `[["fox", "marshal"]]`, wantValue: []interface{}{[]interface{}{"fox", "marshal"}}},
|
||||
{name: "Array 6", saql: `[1, 2, 3,]`, wantRebuild: `[1, 2, 3]`, wantValue: []interface{}{float64(1), float64(2), float64(3)}},
|
||||
{name: "Array 5", saql: `[["fox", "marshal"]]`, wantRebuild: `[["fox", "marshal"]]`, wantValue: []any{[]any{"fox", "marshal"}}},
|
||||
{name: "Array 6", saql: `[1, 2, 3,]`, wantRebuild: `[1, 2, 3]`, wantValue: []any{float64(1), float64(2), float64(3)}},
|
||||
|
||||
{name: "Array Error 1", saql: "(1,2,3)", wantParseErr: true},
|
||||
{name: "Array Access 1", saql: "u.friends[0]", wantRebuild: "u.friends[0]", wantValue: 7, values: `{"u": {"friends": [7,8,9]}}`},
|
||||
@@ -105,14 +109,14 @@ func TestParseSAQLEval(t *testing.T) {
|
||||
{name: "Array Access 3", saql: "u.friends[-1]", wantRebuild: "u.friends[-1]", wantValue: 9, values: `{"u": {"friends": [7,8,9]}}`},
|
||||
{name: "Array Access 4", saql: "u.friends[-2]", wantRebuild: "u.friends[-2]", wantValue: 8, values: `{"u": {"friends": [7,8,9]}}`},
|
||||
|
||||
{name: "Object 1", saql: "{}", wantRebuild: "{}", wantValue: map[string]interface{}{}},
|
||||
{name: "Object 2", saql: `{a: 1}`, wantRebuild: "{a: 1}", wantValue: map[string]interface{}{"a": float64(1)}},
|
||||
{name: "Object 3", saql: `{'a': 1}`, wantRebuild: `{'a': 1}`, wantValue: map[string]interface{}{"a": float64(1)}},
|
||||
{name: "Object 4", saql: `{"a": 1}`, wantRebuild: `{"a": 1}`, wantValue: map[string]interface{}{"a": float64(1)}},
|
||||
{name: "Object 5", saql: `{'return': 1}`, wantRebuild: `{'return': 1}`, wantValue: map[string]interface{}{"return": float64(1)}},
|
||||
{name: "Object 6", saql: `{"return": 1}`, wantRebuild: `{"return": 1}`, wantValue: map[string]interface{}{"return": float64(1)}},
|
||||
{name: "Object 9", saql: `{a: 1,}`, wantRebuild: "{a: 1}", wantValue: map[string]interface{}{"a": float64(1)}},
|
||||
{name: "Object 10", saql: `{"a": 1,}`, wantRebuild: `{"a": 1}`, wantValue: map[string]interface{}{"a": float64(1)}},
|
||||
{name: "Object 1", saql: "{}", wantRebuild: "{}", wantValue: map[string]any{}},
|
||||
{name: "Object 2", saql: `{a: 1}`, wantRebuild: "{a: 1}", wantValue: map[string]any{"a": float64(1)}},
|
||||
{name: "Object 3", saql: `{'a': 1}`, wantRebuild: `{'a': 1}`, wantValue: map[string]any{"a": float64(1)}},
|
||||
{name: "Object 4", saql: `{"a": 1}`, wantRebuild: `{"a": 1}`, wantValue: map[string]any{"a": float64(1)}},
|
||||
{name: "Object 5", saql: `{'return': 1}`, wantRebuild: `{'return': 1}`, wantValue: map[string]any{"return": float64(1)}},
|
||||
{name: "Object 6", saql: `{"return": 1}`, wantRebuild: `{"return": 1}`, wantValue: map[string]any{"return": float64(1)}},
|
||||
{name: "Object 9", saql: `{a: 1,}`, wantRebuild: "{a: 1}", wantValue: map[string]any{"a": float64(1)}},
|
||||
{name: "Object 10", saql: `{"a": 1,}`, wantRebuild: `{"a": 1}`, wantValue: map[string]any{"a": float64(1)}},
|
||||
// {"Object 8", "{`return`: 1}", `{"return": 1}`, true},
|
||||
// {"Object 7", "{´return´: 1}", `{"return": 1}`, true},
|
||||
{name: "Object Error 1: return is a keyword", saql: `{like: 1}`, wantParseErr: true},
|
||||
@@ -272,7 +276,7 @@ func TestParseSAQLEval(t *testing.T) {
|
||||
{name: "Arithmetic 17", saql: `23 * {}`, wantRebuild: `23 * {}`, wantValue: 0},
|
||||
{name: "Arithmetic 18", saql: `5 * [7]`, wantRebuild: `5 * [7]`, wantValue: 35},
|
||||
{name: "Arithmetic 19", saql: `24 / "12"`, wantRebuild: `24 / "12"`, wantValue: 2},
|
||||
{name: "Arithmetic Error 1: Divison by zero", saql: `1 / 0`, wantRebuild: `1 / 0`, wantValue: 0},
|
||||
{name: "Arithmetic Error 1: Division by zero", saql: `1 / 0`, wantRebuild: `1 / 0`, wantValue: 0},
|
||||
|
||||
// https://www.arangodb.com/docs/3.7/aql/operators.html#ternary-operator
|
||||
{name: "Ternary 1", saql: `u.age > 15 || u.active == true ? u.userId : null`, wantRebuild: `u.age > 15 OR u.active == true ? u.userId : null`, wantValue: 45, values: `{"u": {"active": true, "age": 2, "userId": 45}}`},
|
||||
@@ -287,20 +291,24 @@ func TestParseSAQLEval(t *testing.T) {
|
||||
{name: "Security 2", saql: `doc.value == 1 || true INSERT {foo: "bar"} IN collection //`, wantParseErr: true},
|
||||
|
||||
// https://www.arangodb.com/docs/3.7/aql/operators.html#operator-precedence
|
||||
{name: "Precendence", saql: `2 > 15 && "a" != ""`, wantRebuild: `2 > 15 AND "a" != ""`, wantValue: false},
|
||||
{name: "Precedence", saql: `2 > 15 && "a" != ""`, wantRebuild: `2 > 15 AND "a" != ""`, wantValue: false},
|
||||
}
|
||||
for _, tt := range tests {
|
||||
parser := &Parser{
|
||||
tt := tt
|
||||
parser := &caql.Parser{
|
||||
Searcher: &MockSearcher{},
|
||||
}
|
||||
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
expr, err := parser.Parse(tt.saql)
|
||||
if (err != nil) != tt.wantParseErr {
|
||||
t.Errorf("Parse() error = %v, wantErr %v", err, tt.wantParseErr)
|
||||
if expr != nil {
|
||||
t.Error(expr.String())
|
||||
}
|
||||
|
||||
return
|
||||
}
|
||||
if err != nil {
|
||||
@@ -311,6 +319,7 @@ func TestParseSAQLEval(t *testing.T) {
|
||||
if (err != nil) != tt.wantRebuildErr {
|
||||
t.Error(expr.String())
|
||||
t.Errorf("String() error = %v, wantErr %v", err, tt.wantParseErr)
|
||||
|
||||
return
|
||||
}
|
||||
if err != nil {
|
||||
@@ -320,18 +329,19 @@ func TestParseSAQLEval(t *testing.T) {
|
||||
t.Errorf("String() got = %v, want %v", got, tt.wantRebuild)
|
||||
}
|
||||
|
||||
var myJson map[string]interface{}
|
||||
var myJSON map[string]any
|
||||
if tt.values != "" {
|
||||
err = json.Unmarshal([]byte(tt.values), &myJson)
|
||||
err = json.Unmarshal([]byte(tt.values), &myJSON)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
}
|
||||
|
||||
value, err := expr.Eval(myJson)
|
||||
value, err := expr.Eval(myJSON)
|
||||
if (err != nil) != tt.wantEvalErr {
|
||||
t.Error(expr.String())
|
||||
t.Errorf("Parse() error = %v, wantErr %v", err, tt.wantParseErr)
|
||||
|
||||
return
|
||||
}
|
||||
if err != nil {
|
||||
|
||||
57
caql/set.go
57
caql/set.go
@@ -22,19 +22,18 @@
|
||||
|
||||
package caql
|
||||
|
||||
import "sort"
|
||||
|
||||
type (
|
||||
Set struct {
|
||||
hash map[interface{}]nothing
|
||||
}
|
||||
|
||||
nothing struct{}
|
||||
import (
|
||||
"sort"
|
||||
)
|
||||
|
||||
// Create a new set
|
||||
func New(initial ...interface{}) *Set {
|
||||
s := &Set{make(map[interface{}]nothing)}
|
||||
type Set struct {
|
||||
hash map[any]nothing
|
||||
}
|
||||
|
||||
type nothing struct{}
|
||||
|
||||
func NewSet(initial ...any) *Set {
|
||||
s := &Set{make(map[any]nothing)}
|
||||
|
||||
for _, v := range initial {
|
||||
s.Insert(v)
|
||||
@@ -43,9 +42,8 @@ func New(initial ...interface{}) *Set {
|
||||
return s
|
||||
}
|
||||
|
||||
// Find the difference between two sets
|
||||
func (s *Set) Difference(set *Set) *Set {
|
||||
n := make(map[interface{}]nothing)
|
||||
n := make(map[any]nothing)
|
||||
|
||||
for k := range s.hash {
|
||||
if _, exists := set.hash[k]; !exists {
|
||||
@@ -56,27 +54,18 @@ func (s *Set) Difference(set *Set) *Set {
|
||||
return &Set{n}
|
||||
}
|
||||
|
||||
// Call f for each item in the set
|
||||
func (s *Set) Do(f func(interface{})) {
|
||||
for k := range s.hash {
|
||||
f(k)
|
||||
}
|
||||
}
|
||||
|
||||
// Test to see whether or not the element is in the set
|
||||
func (s *Set) Has(element interface{}) bool {
|
||||
func (s *Set) Has(element any) bool {
|
||||
_, exists := s.hash[element]
|
||||
|
||||
return exists
|
||||
}
|
||||
|
||||
// Add an element to the set
|
||||
func (s *Set) Insert(element interface{}) {
|
||||
func (s *Set) Insert(element any) {
|
||||
s.hash[element] = nothing{}
|
||||
}
|
||||
|
||||
// Find the intersection of two sets
|
||||
func (s *Set) Intersection(set *Set) *Set {
|
||||
n := make(map[interface{}]nothing)
|
||||
n := make(map[any]nothing)
|
||||
|
||||
for k := range s.hash {
|
||||
if _, exists := set.hash[k]; exists {
|
||||
@@ -87,23 +76,20 @@ func (s *Set) Intersection(set *Set) *Set {
|
||||
return &Set{n}
|
||||
}
|
||||
|
||||
// Return the number of items in the set
|
||||
func (s *Set) Len() int {
|
||||
return len(s.hash)
|
||||
}
|
||||
|
||||
// Test whether or not this set is a proper subset of "set"
|
||||
func (s *Set) ProperSubsetOf(set *Set) bool {
|
||||
return s.SubsetOf(set) && s.Len() < set.Len()
|
||||
}
|
||||
|
||||
// Remove an element from the set
|
||||
func (s *Set) Remove(element interface{}) {
|
||||
func (s *Set) Remove(element any) {
|
||||
delete(s.hash, element)
|
||||
}
|
||||
|
||||
func (s *Set) Minus(set *Set) *Set {
|
||||
n := make(map[interface{}]nothing)
|
||||
n := make(map[any]nothing)
|
||||
for k := range s.hash {
|
||||
n[k] = nothing{}
|
||||
}
|
||||
@@ -115,7 +101,6 @@ func (s *Set) Minus(set *Set) *Set {
|
||||
return &Set{n}
|
||||
}
|
||||
|
||||
// Test whether or not this set is a subset of "set"
|
||||
func (s *Set) SubsetOf(set *Set) bool {
|
||||
if s.Len() > set.Len() {
|
||||
return false
|
||||
@@ -125,12 +110,12 @@ func (s *Set) SubsetOf(set *Set) bool {
|
||||
return false
|
||||
}
|
||||
}
|
||||
|
||||
return true
|
||||
}
|
||||
|
||||
// Find the union of two sets
|
||||
func (s *Set) Union(set *Set) *Set {
|
||||
n := make(map[interface{}]nothing)
|
||||
n := make(map[any]nothing)
|
||||
|
||||
for k := range s.hash {
|
||||
n[k] = nothing{}
|
||||
@@ -142,8 +127,8 @@ func (s *Set) Union(set *Set) *Set {
|
||||
return &Set{n}
|
||||
}
|
||||
|
||||
func (s *Set) Values() []interface{} {
|
||||
values := []interface{}{}
|
||||
func (s *Set) Values() []any {
|
||||
values := []any{}
|
||||
|
||||
for k := range s.hash {
|
||||
values = append(values, k)
|
||||
|
||||
@@ -27,7 +27,9 @@ import (
|
||||
)
|
||||
|
||||
func Test(t *testing.T) {
|
||||
s := New()
|
||||
t.Parallel()
|
||||
|
||||
s := NewSet()
|
||||
|
||||
s.Insert(5)
|
||||
|
||||
@@ -50,8 +52,8 @@ func Test(t *testing.T) {
|
||||
}
|
||||
|
||||
// Difference
|
||||
s1 := New(1, 2, 3, 4, 5, 6)
|
||||
s2 := New(4, 5, 6)
|
||||
s1 := NewSet(1, 2, 3, 4, 5, 6)
|
||||
s2 := NewSet(4, 5, 6)
|
||||
s3 := s1.Difference(s2)
|
||||
|
||||
if s3.Len() != 3 {
|
||||
@@ -73,7 +75,7 @@ func Test(t *testing.T) {
|
||||
}
|
||||
|
||||
// Union
|
||||
s4 := New(7, 8, 9)
|
||||
s4 := NewSet(7, 8, 9)
|
||||
s3 = s2.Union(s4)
|
||||
|
||||
if s3.Len() != 6 {
|
||||
@@ -92,5 +94,4 @@ func Test(t *testing.T) {
|
||||
if s1.ProperSubsetOf(s1) {
|
||||
t.Errorf("set should not be a subset of itself")
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
@@ -39,8 +39,10 @@ func unquote(s string) (string, error) {
|
||||
buf = append(buf, s[i])
|
||||
}
|
||||
}
|
||||
|
||||
return string(buf), nil
|
||||
}
|
||||
|
||||
return s, nil
|
||||
}
|
||||
if quote != '"' && quote != '\'' {
|
||||
@@ -75,5 +77,6 @@ func unquote(s string) (string, error) {
|
||||
buf = append(buf, runeTmp[:n]...)
|
||||
}
|
||||
}
|
||||
|
||||
return string(buf), nil
|
||||
}
|
||||
|
||||
@@ -8,26 +8,25 @@
|
||||
package caql
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"strconv"
|
||||
"testing"
|
||||
)
|
||||
|
||||
type quoteTest struct {
|
||||
in string
|
||||
out string
|
||||
ascii string
|
||||
graphic string
|
||||
in string
|
||||
out string
|
||||
}
|
||||
|
||||
var quotetests = []quoteTest{
|
||||
{in: "\a\b\f\r\n\t\v", out: `"\a\b\f\r\n\t\v"`, ascii: `"\a\b\f\r\n\t\v"`, graphic: `"\a\b\f\r\n\t\v"`},
|
||||
{"\\", `"\\"`, `"\\"`, `"\\"`},
|
||||
{"abc\xffdef", `"abc\xffdef"`, `"abc\xffdef"`, `"abc\xffdef"`},
|
||||
{"\u263a", `"☺"`, `"\u263a"`, `"☺"`},
|
||||
{"\U0010ffff", `"\U0010ffff"`, `"\U0010ffff"`, `"\U0010ffff"`},
|
||||
{"\x04", `"\x04"`, `"\x04"`, `"\x04"`},
|
||||
{in: "\a\b\f\r\n\t\v", out: `"\a\b\f\r\n\t\v"`},
|
||||
{"\\", `"\\"`},
|
||||
{"abc\xffdef", `"abc\xffdef"`},
|
||||
{"\u263a", `"☺"`},
|
||||
{"\U0010ffff", `"\U0010ffff"`},
|
||||
{"\x04", `"\x04"`},
|
||||
// Some non-printable but graphic runes. Final column is double-quoted.
|
||||
{"!\u00a0!\u2000!\u3000!", `"!\u00a0!\u2000!\u3000!"`, `"!\u00a0!\u2000!\u3000!"`, "\"!\u00a0!\u2000!\u3000!\""},
|
||||
{"!\u00a0!\u2000!\u3000!", `"!\u00a0!\u2000!\u3000!"`},
|
||||
}
|
||||
|
||||
type unQuoteTest struct {
|
||||
@@ -104,6 +103,8 @@ var misquoted = []string{
|
||||
}
|
||||
|
||||
func TestUnquote(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
for _, tt := range unquotetests {
|
||||
if out, err := unquote(tt.in); err != nil || out != tt.out {
|
||||
t.Errorf("unquote(%#q) = %q, %v want %q, nil", tt.in, out, err, tt.out)
|
||||
@@ -118,7 +119,7 @@ func TestUnquote(t *testing.T) {
|
||||
}
|
||||
|
||||
for _, s := range misquoted {
|
||||
if out, err := unquote(s); out != "" || err != strconv.ErrSyntax {
|
||||
if out, err := unquote(s); out != "" || !errors.Is(err, strconv.ErrSyntax) {
|
||||
t.Errorf("unquote(%#q) = %q, %v want %q, %v", s, out, err, "", strconv.ErrSyntax)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -48,6 +48,7 @@ Pattern:
|
||||
// using the star
|
||||
if ok && (len(t) == 0 || len(pattern) > 0) {
|
||||
name = t
|
||||
|
||||
continue
|
||||
}
|
||||
if err != nil {
|
||||
@@ -64,6 +65,7 @@ Pattern:
|
||||
continue
|
||||
}
|
||||
name = t
|
||||
|
||||
continue Pattern
|
||||
}
|
||||
if err != nil {
|
||||
@@ -79,8 +81,10 @@ Pattern:
|
||||
return false, err
|
||||
}
|
||||
}
|
||||
|
||||
return false, nil
|
||||
}
|
||||
|
||||
return len(name) == 0, nil
|
||||
}
|
||||
|
||||
@@ -104,6 +108,7 @@ Scan:
|
||||
break Scan
|
||||
}
|
||||
}
|
||||
|
||||
return star, pattern[0:i], pattern[i:]
|
||||
}
|
||||
|
||||
@@ -120,7 +125,6 @@ func matchChunk(chunk, s string) (rest string, ok bool, err error) {
|
||||
failed = true
|
||||
}
|
||||
switch chunk[0] {
|
||||
|
||||
case '_':
|
||||
if !failed {
|
||||
if s[0] == '/' {
|
||||
@@ -130,14 +134,13 @@ func matchChunk(chunk, s string) (rest string, ok bool, err error) {
|
||||
s = s[n:]
|
||||
}
|
||||
chunk = chunk[1:]
|
||||
|
||||
case '\\':
|
||||
chunk = chunk[1:]
|
||||
if len(chunk) == 0 {
|
||||
return "", false, ErrBadPattern
|
||||
}
|
||||
fallthrough
|
||||
|
||||
fallthrough
|
||||
default:
|
||||
if !failed {
|
||||
if chunk[0] != s[0] {
|
||||
@@ -151,5 +154,6 @@ func matchChunk(chunk, s string) (rest string, ok bool, err error) {
|
||||
if failed {
|
||||
return "", false, nil
|
||||
}
|
||||
|
||||
return s, true, nil
|
||||
}
|
||||
|
||||
@@ -7,7 +7,10 @@
|
||||
|
||||
package caql
|
||||
|
||||
import "testing"
|
||||
import (
|
||||
"errors"
|
||||
"testing"
|
||||
)
|
||||
|
||||
type MatchTest struct {
|
||||
pattern, s string
|
||||
@@ -41,9 +44,11 @@ var matchTests = []MatchTest{
|
||||
}
|
||||
|
||||
func TestMatch(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
for _, tt := range matchTests {
|
||||
ok, err := match(tt.pattern, tt.s)
|
||||
if ok != tt.match || err != tt.err {
|
||||
if ok != tt.match || !errors.Is(err, tt.err) {
|
||||
t.Errorf("match(%#q, %#q) = %v, %v want %v, %v", tt.pattern, tt.s, ok, err, tt.match, tt.err)
|
||||
}
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user