Migrate to Go 1.18 (#45)

* Migrate to Go 1.18 and add linters
This commit is contained in:
Jonas Plum
2022-03-20 03:17:18 +01:00
committed by GitHub
parent 03a4806d45
commit 2bad1f5f28
88 changed files with 1430 additions and 868 deletions

View File

@@ -8,7 +8,7 @@ import (
"github.com/SecurityBrewery/catalyst/generated/caql/parser"
)
var TooComplexError = errors.New("unsupported features for index queries, use advanced search instead")
var ErrTooComplex = errors.New("unsupported features for index queries, use advanced search instead")
type bleveBuilder struct {
*parser.BaseCAQLParserListener
@@ -35,8 +35,9 @@ func (s *bleveBuilder) pop() (n string) {
return
}
func (s *bleveBuilder) binaryPop() (interface{}, interface{}) {
func (s *bleveBuilder) binaryPop() (any, any) {
right, left := s.pop(), s.pop()
return left, right
}
@@ -48,9 +49,7 @@ func (s *bleveBuilder) ExitExpression(ctx *parser.ExpressionContext) {
case ctx.Reference() != nil:
// pass
case ctx.Operator_unary() != nil:
s.err = TooComplexError
return
s.err = ErrTooComplex
case ctx.T_PLUS() != nil:
fallthrough
case ctx.T_MINUS() != nil:
@@ -60,13 +59,9 @@ func (s *bleveBuilder) ExitExpression(ctx *parser.ExpressionContext) {
case ctx.T_DIV() != nil:
fallthrough
case ctx.T_MOD() != nil:
s.err = TooComplexError
return
s.err = ErrTooComplex
case ctx.T_RANGE() != nil:
s.err = TooComplexError
return
s.err = ErrTooComplex
case ctx.T_LT() != nil && ctx.GetEq_op() == nil:
left, right := s.binaryPop()
s.push(fmt.Sprintf("%s:<%s", left, right))
@@ -79,64 +74,46 @@ func (s *bleveBuilder) ExitExpression(ctx *parser.ExpressionContext) {
case ctx.T_GE() != nil && ctx.GetEq_op() == nil:
left, right := s.binaryPop()
s.push(fmt.Sprintf("%s:>=%s", left, right))
case ctx.T_IN() != nil && ctx.GetEq_op() == nil:
s.err = TooComplexError
return
s.err = ErrTooComplex
case ctx.T_EQ() != nil && ctx.GetEq_op() == nil:
left, right := s.binaryPop()
s.push(fmt.Sprintf("%s:%s", left, right))
case ctx.T_NE() != nil && ctx.GetEq_op() == nil:
left, right := s.binaryPop()
s.push(fmt.Sprintf("-%s:%s", left, right))
case ctx.T_ALL() != nil && ctx.GetEq_op() != nil:
fallthrough
case ctx.T_ANY() != nil && ctx.GetEq_op() != nil:
fallthrough
case ctx.T_NONE() != nil && ctx.GetEq_op() != nil:
s.err = TooComplexError
return
s.err = ErrTooComplex
case ctx.T_ALL() != nil && ctx.T_NOT() != nil && ctx.T_IN() != nil:
fallthrough
case ctx.T_ANY() != nil && ctx.T_NOT() != nil && ctx.T_IN() != nil:
fallthrough
case ctx.T_NONE() != nil && ctx.T_NOT() != nil && ctx.T_IN() != nil:
s.err = TooComplexError
return
s.err = ErrTooComplex
case ctx.T_LIKE() != nil:
s.err = errors.New("index queries are like queries by default")
return
case ctx.T_REGEX_MATCH() != nil:
left, right := s.binaryPop()
if ctx.T_NOT() != nil {
s.err = TooComplexError
return
s.err = ErrTooComplex
} else {
s.push(fmt.Sprintf("%s:/%s/", left, right))
}
case ctx.T_REGEX_NON_MATCH() != nil:
s.err = errors.New("index query cannot contain regex non matches, use advanced search instead")
return
case ctx.T_AND() != nil:
left, right := s.binaryPop()
s.push(fmt.Sprintf("%s %s", left, right))
case ctx.T_OR() != nil:
s.err = errors.New("index query cannot contain OR, use advanced search instead")
return
case ctx.T_QUESTION() != nil && len(ctx.AllExpression()) == 3:
s.err = errors.New("index query cannot contain ternary operations, use advanced search instead")
return
case ctx.T_QUESTION() != nil && len(ctx.AllExpression()) == 2:
s.err = errors.New("index query cannot contain ternary operations, use advanced search instead")
return
default:
panic("unknown expression")
}
@@ -152,17 +129,13 @@ func (s *bleveBuilder) ExitReference(ctx *parser.ReferenceContext) {
case ctx.T_STRING() != nil:
s.push(ctx.T_STRING().GetText())
case ctx.Compound_value() != nil:
s.err = TooComplexError
return
s.err = ErrTooComplex
case ctx.Function_call() != nil:
s.err = TooComplexError
return
s.err = ErrTooComplex
case ctx.T_OPEN() != nil:
s.err = TooComplexError
return
s.err = ErrTooComplex
case ctx.T_ARRAY_OPEN() != nil:
s.err = TooComplexError
return
s.err = ErrTooComplex
default:
panic(fmt.Sprintf("unexpected value: %s", ctx.GetText()))
}

View File

@@ -1,10 +1,14 @@
package caql
package caql_test
import (
"testing"
"github.com/SecurityBrewery/catalyst/caql"
)
func TestBleveBuilder(t *testing.T) {
t.Parallel()
tests := []struct {
name string
saql string
@@ -18,15 +22,20 @@ func TestBleveBuilder(t *testing.T) {
{name: "Search 4", saql: `title == 'malware' AND 'wannacry'`, wantBleve: `title:"malware" "wannacry"`},
}
for _, tt := range tests {
parser := &Parser{}
tt := tt
parser := &caql.Parser{}
t.Run(tt.name, func(t *testing.T) {
t.Parallel()
expr, err := parser.Parse(tt.saql)
if (err != nil) != tt.wantParseErr {
t.Errorf("Parse() error = %v, wantErr %v", err, tt.wantParseErr)
if expr != nil {
t.Error(expr.String())
}
return
}
if err != nil {
@@ -37,6 +46,7 @@ func TestBleveBuilder(t *testing.T) {
if (err != nil) != tt.wantRebuildErr {
t.Error(expr.String())
t.Errorf("String() error = %v, wantErr %v", err, tt.wantParseErr)
return
}
if err != nil {

View File

@@ -5,6 +5,8 @@ import (
"strconv"
"strings"
"golang.org/x/exp/slices"
"github.com/SecurityBrewery/catalyst/generated/caql/parser"
)
@@ -40,6 +42,7 @@ func (s *aqlBuilder) pop() (n string) {
func (s *aqlBuilder) binaryPop() (string, string) {
right, left := s.pop(), s.pop()
return left, right
}
@@ -181,8 +184,10 @@ func (s *aqlBuilder) toBoolString(v string) string {
if err != nil {
panic("invalid search " + err.Error())
}
return fmt.Sprintf(`d._key IN ["%s"]`, strings.Join(ids, `","`))
}
return v
}
@@ -246,7 +251,7 @@ func (s *aqlBuilder) ExitFunction_call(ctx *parser.Function_callContext) {
}
parameter := strings.Join(array, ", ")
if !stringSliceContains(functionNames, strings.ToUpper(ctx.T_STRING().GetText())) {
if !slices.Contains(functionNames, strings.ToUpper(ctx.T_STRING().GetText())) {
panic("unknown function")
}

View File

@@ -16,7 +16,6 @@ import (
func (s *aqlInterpreter) function(ctx *parser.Function_callContext) {
switch strings.ToUpper(ctx.T_STRING().GetText()) {
default:
s.appendErrors(errors.New("unknown function"))
@@ -26,8 +25,8 @@ func (s *aqlInterpreter) function(ctx *parser.Function_callContext) {
if len(ctx.AllExpression()) == 3 {
u = s.pop().(bool)
}
seen := map[interface{}]bool{}
values, anyArray := s.pop().([]interface{}), s.pop().([]interface{})
seen := map[any]bool{}
values, anyArray := s.pop().([]any), s.pop().([]any)
if u {
for _, e := range anyArray {
@@ -45,18 +44,18 @@ func (s *aqlInterpreter) function(ctx *parser.Function_callContext) {
s.push(anyArray)
case "COUNT_DISTINCT", "COUNT_UNIQUE":
count := 0
seen := map[interface{}]bool{}
array := s.pop().([]interface{})
seen := map[any]bool{}
array := s.pop().([]any)
for _, e := range array {
_, ok := seen[e]
if !ok {
seen[e] = true
count += 1
count++
}
}
s.push(float64(count))
case "FIRST":
array := s.pop().([]interface{})
array := s.pop().([]any)
if len(array) == 0 {
s.push(nil)
} else {
@@ -65,16 +64,16 @@ func (s *aqlInterpreter) function(ctx *parser.Function_callContext) {
// case "FLATTEN":
// case "INTERLEAVE":
case "INTERSECTION":
iset := New(s.pop().([]interface{})...)
iset := NewSet(s.pop().([]any)...)
for i := 1; i < len(ctx.AllExpression()); i++ {
iset = iset.Intersection(New(s.pop().([]interface{})...))
iset = iset.Intersection(NewSet(s.pop().([]any)...))
}
s.push(iset.Values())
// case "JACCARD":
case "LAST":
array := s.pop().([]interface{})
array := s.pop().([]any)
if len(array) == 0 {
s.push(nil)
} else {
@@ -94,9 +93,9 @@ func (s *aqlInterpreter) function(ctx *parser.Function_callContext) {
s.push(float64(len(fmt.Sprint(v))))
case string:
s.push(float64(utf8.RuneCountInString(v)))
case []interface{}:
case []any:
s.push(float64(len(v)))
case map[string]interface{}:
case map[string]any:
s.push(float64(len(v)))
default:
panic("unknown type")
@@ -104,7 +103,7 @@ func (s *aqlInterpreter) function(ctx *parser.Function_callContext) {
case "MINUS":
var sets []*Set
for i := 0; i < len(ctx.AllExpression()); i++ {
sets = append(sets, New(s.pop().([]interface{})...))
sets = append(sets, NewSet(s.pop().([]any)...))
}
iset := sets[len(sets)-1]
@@ -116,7 +115,7 @@ func (s *aqlInterpreter) function(ctx *parser.Function_callContext) {
s.push(iset.Values())
case "NTH":
pos := s.pop().(float64)
array := s.pop().([]interface{})
array := s.pop().([]any)
if int(pos) >= len(array) || pos < 0 {
s.push(nil)
} else {
@@ -124,16 +123,16 @@ func (s *aqlInterpreter) function(ctx *parser.Function_callContext) {
}
// case "OUTERSECTION":
// array := s.pop().([]interface{})
// union := New(array...)
// intersection := New(s.pop().([]interface{})...)
// union := NewSet(array...)
// intersection := NewSet(s.pop().([]interface{})...)
// for i := 1; i < len(ctx.AllExpression()); i++ {
// array = s.pop().([]interface{})
// union = union.Union(New(array...))
// intersection = intersection.Intersection(New(array...))
// union = union.Union(NewSet(array...))
// intersection = intersection.Intersection(NewSet(array...))
// }
// s.push(union.Minus(intersection).Values())
case "POP":
array := s.pop().([]interface{})
array := s.pop().([]any)
s.push(array[:len(array)-1])
case "POSITION", "CONTAINS_ARRAY":
returnIndex := false
@@ -141,7 +140,7 @@ func (s *aqlInterpreter) function(ctx *parser.Function_callContext) {
returnIndex = s.pop().(bool)
}
search := s.pop()
array := s.pop().([]interface{})
array := s.pop().([]any)
for idx, e := range array {
if e == search {
@@ -164,7 +163,7 @@ func (s *aqlInterpreter) function(ctx *parser.Function_callContext) {
u = s.pop().(bool)
}
element := s.pop()
array := s.pop().([]interface{})
array := s.pop().([]any)
if u && contains(array, element) {
s.push(array)
@@ -173,13 +172,13 @@ func (s *aqlInterpreter) function(ctx *parser.Function_callContext) {
}
case "REMOVE_NTH":
position := s.pop().(float64)
anyArray := s.pop().([]interface{})
anyArray := s.pop().([]any)
if position < 0 {
position = float64(len(anyArray) + int(position))
}
result := []interface{}{}
result := []any{}
for idx, e := range anyArray {
if idx != int(position) {
result = append(result, e)
@@ -193,7 +192,7 @@ func (s *aqlInterpreter) function(ctx *parser.Function_callContext) {
}
replaceValue := s.pop().(string)
position := s.pop().(float64)
anyArray := s.pop().([]interface{})
anyArray := s.pop().([]any)
if position < 0 {
position = float64(len(anyArray) + int(position))
@@ -224,8 +223,8 @@ func (s *aqlInterpreter) function(ctx *parser.Function_callContext) {
limit = s.pop().(float64)
}
value := s.pop()
array := s.pop().([]interface{})
result := []interface{}{}
array := s.pop().([]any)
result := []any{}
for idx, e := range array {
if e != value || float64(idx) > limit {
result = append(result, e)
@@ -233,9 +232,9 @@ func (s *aqlInterpreter) function(ctx *parser.Function_callContext) {
}
s.push(result)
case "REMOVE_VALUES":
values := s.pop().([]interface{})
array := s.pop().([]interface{})
result := []interface{}{}
values := s.pop().([]any)
array := s.pop().([]any)
result := []any{}
for _, e := range array {
if !contains(values, e) {
result = append(result, e)
@@ -243,14 +242,14 @@ func (s *aqlInterpreter) function(ctx *parser.Function_callContext) {
}
s.push(result)
case "REVERSE":
array := s.pop().([]interface{})
var reverse []interface{}
array := s.pop().([]any)
var reverse []any
for _, e := range array {
reverse = append([]interface{}{e}, reverse...)
reverse = append([]any{e}, reverse...)
}
s.push(reverse)
case "SHIFT":
s.push(s.pop().([]interface{})[1:])
s.push(s.pop().([]any)[1:])
case "SLICE":
length := float64(-1)
full := true
@@ -259,7 +258,7 @@ func (s *aqlInterpreter) function(ctx *parser.Function_callContext) {
full = false
}
start := int64(s.pop().(float64))
array := s.pop().([]interface{})
array := s.pop().([]any)
if start < 0 {
start = int64(len(array)) + start
@@ -276,43 +275,43 @@ func (s *aqlInterpreter) function(ctx *parser.Function_callContext) {
}
s.push(array[start:end])
case "SORTED":
array := s.pop().([]interface{})
array := s.pop().([]any)
sort.Slice(array, func(i, j int) bool { return lt(array[i], array[j]) })
s.push(array)
case "SORTED_UNIQUE":
array := s.pop().([]interface{})
array := s.pop().([]any)
sort.Slice(array, func(i, j int) bool { return lt(array[i], array[j]) })
s.push(unique(array))
case "UNION":
array := s.pop().([]interface{})
array := s.pop().([]any)
for i := 1; i < len(ctx.AllExpression()); i++ {
array = append(array, s.pop().([]interface{})...)
array = append(array, s.pop().([]any)...)
}
sort.Slice(array, func(i, j int) bool { return lt(array[i], array[j]) })
s.push(array)
case "UNION_DISTINCT":
iset := New(s.pop().([]interface{})...)
iset := NewSet(s.pop().([]any)...)
for i := 1; i < len(ctx.AllExpression()); i++ {
iset = iset.Union(New(s.pop().([]interface{})...))
iset = iset.Union(NewSet(s.pop().([]any)...))
}
s.push(unique(iset.Values()))
case "UNIQUE":
s.push(unique(s.pop().([]interface{})))
s.push(unique(s.pop().([]any)))
case "UNSHIFT":
u := false
if len(ctx.AllExpression()) == 3 {
u = s.pop().(bool)
}
element := s.pop()
array := s.pop().([]interface{})
array := s.pop().([]any)
if u && contains(array, element) {
s.push(array)
} else {
s.push(append([]interface{}{element}, array...))
s.push(append([]any{element}, array...))
}
// Bit https://www.arangodb.com/docs/stable/aql/functions-bit.html
@@ -367,8 +366,8 @@ func (s *aqlInterpreter) function(ctx *parser.Function_callContext) {
if len(ctx.AllExpression()) >= 2 {
removeInternal = s.pop().(bool)
}
var keys []interface{}
for k := range s.pop().(map[string]interface{}) {
var keys []any
for k := range s.pop().(map[string]any) {
isInternalKey := strings.HasPrefix(k, "_")
if !removeInternal || !isInternalKey {
keys = append(keys, k)
@@ -379,20 +378,20 @@ func (s *aqlInterpreter) function(ctx *parser.Function_callContext) {
// case "COUNT":
case "HAS":
right, left := s.pop(), s.pop()
_, ok := left.(map[string]interface{})[right.(string)]
_, ok := left.(map[string]any)[right.(string)]
s.push(ok)
// case "KEEP":
// case "LENGTH":
// case "MATCHES":
case "MERGE":
var docs []map[string]interface{}
var docs []map[string]any
if len(ctx.AllExpression()) == 1 {
for _, doc := range s.pop().([]interface{}) {
docs = append([]map[string]interface{}{doc.(map[string]interface{})}, docs...)
for _, doc := range s.pop().([]any) {
docs = append([]map[string]any{doc.(map[string]any)}, docs...)
}
} else {
for i := 0; i < len(ctx.AllExpression()); i++ {
docs = append(docs, s.pop().(map[string]interface{}))
docs = append(docs, s.pop().(map[string]any))
}
}
@@ -404,9 +403,9 @@ func (s *aqlInterpreter) function(ctx *parser.Function_callContext) {
}
s.push(doc)
case "MERGE_RECURSIVE":
var doc map[string]interface{}
var doc map[string]any
for i := 0; i < len(ctx.AllExpression()); i++ {
err := mergo.Merge(&doc, s.pop().(map[string]interface{}))
err := mergo.Merge(&doc, s.pop().(map[string]any))
if err != nil {
panic(err)
}
@@ -421,8 +420,8 @@ func (s *aqlInterpreter) function(ctx *parser.Function_callContext) {
if len(ctx.AllExpression()) == 2 {
removeInternal = s.pop().(bool)
}
var values []interface{}
for k, v := range s.pop().(map[string]interface{}) {
var values []any
for k, v := range s.pop().(map[string]any) {
isInternalKey := strings.HasPrefix(k, "_")
if !removeInternal || !isInternalKey {
values = append(values, v)
@@ -458,10 +457,10 @@ func (s *aqlInterpreter) function(ctx *parser.Function_callContext) {
case "AVERAGE", "AVG":
count := 0
sum := float64(0)
array := s.pop().([]interface{})
array := s.pop().([]any)
for _, element := range array {
if element != nil {
count += 1
count++
sum += toNumber(element)
}
}
@@ -506,7 +505,7 @@ func (s *aqlInterpreter) function(ctx *parser.Function_callContext) {
case "MAX":
var set bool
var max float64
array := s.pop().([]interface{})
array := s.pop().([]any)
for _, element := range array {
if element != nil {
if !set || toNumber(element) > max {
@@ -521,7 +520,7 @@ func (s *aqlInterpreter) function(ctx *parser.Function_callContext) {
s.push(nil)
}
case "MEDIAN":
array := s.pop().([]interface{})
array := s.pop().([]any)
var numbers []float64
for _, element := range array {
if f, ok := element.(float64); ok {
@@ -544,7 +543,7 @@ func (s *aqlInterpreter) function(ctx *parser.Function_callContext) {
case "MIN":
var set bool
var min float64
array := s.pop().([]interface{})
array := s.pop().([]any)
for _, element := range array {
if element != nil {
if !set || toNumber(element) < min {
@@ -566,7 +565,7 @@ func (s *aqlInterpreter) function(ctx *parser.Function_callContext) {
s.push(math.Pow(left.(float64), right.(float64)))
case "PRODUCT":
product := float64(1)
array := s.pop().([]interface{})
array := s.pop().([]any)
for _, element := range array {
if element != nil {
product *= toNumber(element)
@@ -578,7 +577,7 @@ func (s *aqlInterpreter) function(ctx *parser.Function_callContext) {
case "RAND":
s.push(rand.Float64())
case "RANGE":
var array []interface{}
var array []any
var start, end, step float64
if len(ctx.AllExpression()) == 2 {
right, left := s.pop(), s.pop()
@@ -612,7 +611,7 @@ func (s *aqlInterpreter) function(ctx *parser.Function_callContext) {
// case "STDDEV":
case "SUM":
sum := float64(0)
array := s.pop().([]interface{})
array := s.pop().([]any)
for _, element := range array {
sum += toNumber(element)
}
@@ -691,7 +690,6 @@ func (s *aqlInterpreter) function(ctx *parser.Function_callContext) {
// case "IS_IPV4":
// case "IS_KEY":
// case "TYPENAME":
}
}
@@ -705,6 +703,7 @@ func unique(array []interface{}) []interface{} {
filtered = append(filtered, e)
}
}
return filtered
}
@@ -714,15 +713,7 @@ func contains(values []interface{}, e interface{}) bool {
return true
}
}
return false
}
func stringSliceContains(values []string, e string) bool {
for _, v := range values {
if e == v {
return true
}
}
return false
}
@@ -747,4 +738,5 @@ var functionNames = []string{
"REGEX_REPLACE", "REVERSE", "RIGHT", "RTRIM", "SHA1", "SHA512", "SOUNDEX", "SPLIT", "STARTS_WITH", "SUBSTITUTE",
"SUBSTRING", "TOKENS", "TO_BASE64", "TO_HEX", "TRIM", "UPPER", "UUID", "TO_BOOL", "TO_NUMBER", "TO_STRING",
"TO_ARRAY", "TO_LIST", "IS_NULL", "IS_BOOL", "IS_NUMBER", "IS_STRING", "IS_ARRAY", "IS_LIST", "IS_OBJECT",
"IS_DOCUMENT", "IS_DATESTRING", "IS_IPV4", "IS_KEY", "TYPENAME"}
"IS_DOCUMENT", "IS_DATESTRING", "IS_IPV4", "IS_KEY", "TYPENAME",
}

View File

@@ -1,18 +1,22 @@
package caql
package caql_test
import (
"encoding/json"
"math"
"reflect"
"testing"
"github.com/SecurityBrewery/catalyst/caql"
)
func TestFunctions(t *testing.T) {
t.Parallel()
tests := []struct {
name string
saql string
wantRebuild string
wantValue interface{}
wantValue any
wantParseErr bool
wantRebuildErr bool
wantEvalErr bool
@@ -266,13 +270,13 @@ func TestFunctions(t *testing.T) {
{name: "RADIANS", saql: `RADIANS(0)`, wantRebuild: `RADIANS(0)`, wantValue: 0},
// {name: "RAND", saql: `RAND()`, wantRebuild: `RAND()`, wantValue: 0.3503170117504508},
// {name: "RAND", saql: `RAND()`, wantRebuild: `RAND()`, wantValue: 0.6138226173882478},
{name: "RANGE", saql: `RANGE(1, 4)`, wantRebuild: `RANGE(1, 4)`, wantValue: []interface{}{float64(1), float64(2), float64(3), float64(4)}},
{name: "RANGE", saql: `RANGE(1, 4, 2)`, wantRebuild: `RANGE(1, 4, 2)`, wantValue: []interface{}{float64(1), float64(3)}},
{name: "RANGE", saql: `RANGE(1, 4, 3)`, wantRebuild: `RANGE(1, 4, 3)`, wantValue: []interface{}{float64(1), float64(4)}},
{name: "RANGE", saql: `RANGE(1.5, 2.5)`, wantRebuild: `RANGE(1.5, 2.5)`, wantValue: []interface{}{float64(1), float64(2)}},
{name: "RANGE", saql: `RANGE(1.5, 2.5, 1)`, wantRebuild: `RANGE(1.5, 2.5, 1)`, wantValue: []interface{}{1.5, 2.5}},
{name: "RANGE", saql: `RANGE(1.5, 2.5, 0.5)`, wantRebuild: `RANGE(1.5, 2.5, 0.5)`, wantValue: []interface{}{1.5, 2.0, 2.5}},
{name: "RANGE", saql: `RANGE(-0.75, 1.1, 0.5)`, wantRebuild: `RANGE(-0.75, 1.1, 0.5)`, wantValue: []interface{}{-0.75, -0.25, 0.25, 0.75}},
{name: "RANGE", saql: `RANGE(1, 4)`, wantRebuild: `RANGE(1, 4)`, wantValue: []any{float64(1), float64(2), float64(3), float64(4)}},
{name: "RANGE", saql: `RANGE(1, 4, 2)`, wantRebuild: `RANGE(1, 4, 2)`, wantValue: []any{float64(1), float64(3)}},
{name: "RANGE", saql: `RANGE(1, 4, 3)`, wantRebuild: `RANGE(1, 4, 3)`, wantValue: []any{float64(1), float64(4)}},
{name: "RANGE", saql: `RANGE(1.5, 2.5)`, wantRebuild: `RANGE(1.5, 2.5)`, wantValue: []any{float64(1), float64(2)}},
{name: "RANGE", saql: `RANGE(1.5, 2.5, 1)`, wantRebuild: `RANGE(1.5, 2.5, 1)`, wantValue: []any{1.5, 2.5}},
{name: "RANGE", saql: `RANGE(1.5, 2.5, 0.5)`, wantRebuild: `RANGE(1.5, 2.5, 0.5)`, wantValue: []any{1.5, 2.0, 2.5}},
{name: "RANGE", saql: `RANGE(-0.75, 1.1, 0.5)`, wantRebuild: `RANGE(-0.75, 1.1, 0.5)`, wantValue: []any{-0.75, -0.25, 0.25, 0.75}},
{name: "ROUND", saql: `ROUND(2.49)`, wantRebuild: `ROUND(2.49)`, wantValue: 2},
{name: "ROUND", saql: `ROUND(2.50)`, wantRebuild: `ROUND(2.50)`, wantValue: 3},
{name: "ROUND", saql: `ROUND(-2.50)`, wantRebuild: `ROUND(-2.50)`, wantValue: -2},
@@ -299,15 +303,20 @@ func TestFunctions(t *testing.T) {
{name: "Function Error 3", saql: `ABS("abs")`, wantRebuild: `ABS("abs")`, wantEvalErr: true},
}
for _, tt := range tests {
parser := &Parser{}
tt := tt
parser := &caql.Parser{}
t.Run(tt.name, func(t *testing.T) {
t.Parallel()
expr, err := parser.Parse(tt.saql)
if (err != nil) != tt.wantParseErr {
t.Errorf("Parse() error = %v, wantErr %v", err, tt.wantParseErr)
if expr != nil {
t.Error(expr.String())
}
return
}
if err != nil {
@@ -318,6 +327,7 @@ func TestFunctions(t *testing.T) {
if (err != nil) != tt.wantRebuildErr {
t.Error(expr.String())
t.Errorf("String() error = %v, wantErr %v", err, tt.wantParseErr)
return
}
if err != nil {
@@ -327,18 +337,19 @@ func TestFunctions(t *testing.T) {
t.Errorf("String() got = %v, want %v", got, tt.wantRebuild)
}
var myJson map[string]interface{}
var myJSON map[string]any
if tt.values != "" {
err = json.Unmarshal([]byte(tt.values), &myJson)
err = json.Unmarshal([]byte(tt.values), &myJSON)
if err != nil {
t.Fatal(err)
}
}
value, err := expr.Eval(myJson)
value, err := expr.Eval(myJSON)
if (err != nil) != tt.wantEvalErr {
t.Error(expr.String())
t.Errorf("Parse() error = %v, wantErr %v", err, tt.wantParseErr)
return
}
if err != nil {
@@ -367,14 +378,15 @@ func TestFunctions(t *testing.T) {
}
}
func jsonParse(s string) interface{} {
func jsonParse(s string) any {
if s == "" {
return nil
}
var j interface{}
var j any
err := json.Unmarshal([]byte(s), &j)
if err != nil {
panic(s + err.Error())
}
return j
}

View File

@@ -10,22 +10,23 @@ import (
type aqlInterpreter struct {
*parser.BaseCAQLParserListener
values map[string]interface{}
stack []interface{}
values map[string]any
stack []any
errs []error
}
// push is a helper function for pushing new node to the listener Stack.
func (s *aqlInterpreter) push(i interface{}) {
func (s *aqlInterpreter) push(i any) {
s.stack = append(s.stack, i)
}
// pop is a helper function for poping a node from the listener Stack.
func (s *aqlInterpreter) pop() (n interface{}) {
func (s *aqlInterpreter) pop() (n any) {
// Check that we have nodes in the stack.
size := len(s.stack)
if size < 1 {
s.appendErrors(ErrStack)
return
}
@@ -35,8 +36,9 @@ func (s *aqlInterpreter) pop() (n interface{}) {
return
}
func (s *aqlInterpreter) binaryPop() (interface{}, interface{}) {
func (s *aqlInterpreter) binaryPop() (any, any) {
right, left := s.pop(), s.pop()
return left, right
}
@@ -54,17 +56,14 @@ func (s *aqlInterpreter) ExitExpression(ctx *parser.ExpressionContext) {
s.push(plus(s.binaryPop()))
case ctx.T_MINUS() != nil:
s.push(minus(s.binaryPop()))
case ctx.T_TIMES() != nil:
s.push(times(s.binaryPop()))
case ctx.T_DIV() != nil:
s.push(div(s.binaryPop()))
case ctx.T_MOD() != nil:
s.push(mod(s.binaryPop()))
case ctx.T_RANGE() != nil:
s.push(aqlrange(s.binaryPop()))
case ctx.T_LT() != nil && ctx.GetEq_op() == nil:
s.push(lt(s.binaryPop()))
case ctx.T_GT() != nil && ctx.GetEq_op() == nil:
@@ -73,35 +72,30 @@ func (s *aqlInterpreter) ExitExpression(ctx *parser.ExpressionContext) {
s.push(le(s.binaryPop()))
case ctx.T_GE() != nil && ctx.GetEq_op() == nil:
s.push(ge(s.binaryPop()))
case ctx.T_IN() != nil && ctx.GetEq_op() == nil:
s.push(maybeNot(ctx, in(s.binaryPop())))
case ctx.T_EQ() != nil && ctx.GetEq_op() == nil:
s.push(eq(s.binaryPop()))
case ctx.T_NE() != nil && ctx.GetEq_op() == nil:
s.push(ne(s.binaryPop()))
case ctx.T_ALL() != nil && ctx.GetEq_op() != nil:
right, left := s.pop(), s.pop()
s.push(all(left.([]interface{}), getOp(ctx.GetEq_op().GetTokenType()), right))
s.push(all(left.([]any), getOp(ctx.GetEq_op().GetTokenType()), right))
case ctx.T_ANY() != nil && ctx.GetEq_op() != nil:
right, left := s.pop(), s.pop()
s.push(any(left.([]interface{}), getOp(ctx.GetEq_op().GetTokenType()), right))
s.push(anyElement(left.([]any), getOp(ctx.GetEq_op().GetTokenType()), right))
case ctx.T_NONE() != nil && ctx.GetEq_op() != nil:
right, left := s.pop(), s.pop()
s.push(none(left.([]interface{}), getOp(ctx.GetEq_op().GetTokenType()), right))
s.push(none(left.([]any), getOp(ctx.GetEq_op().GetTokenType()), right))
case ctx.T_ALL() != nil && ctx.T_NOT() != nil && ctx.T_IN() != nil:
right, left := s.pop(), s.pop()
s.push(all(left.([]interface{}), in, right))
s.push(all(left.([]any), in, right))
case ctx.T_ANY() != nil && ctx.T_NOT() != nil && ctx.T_IN() != nil:
right, left := s.pop(), s.pop()
s.push(any(left.([]interface{}), in, right))
s.push(anyElement(left.([]any), in, right))
case ctx.T_NONE() != nil && ctx.T_NOT() != nil && ctx.T_IN() != nil:
right, left := s.pop(), s.pop()
s.push(none(left.([]interface{}), in, right))
s.push(none(left.([]any), in, right))
case ctx.T_LIKE() != nil:
m, err := like(s.binaryPop())
s.appendErrors(err)
@@ -114,21 +108,18 @@ func (s *aqlInterpreter) ExitExpression(ctx *parser.ExpressionContext) {
m, err := regexNonMatch(s.binaryPop())
s.appendErrors(err)
s.push(maybeNot(ctx, m))
case ctx.T_AND() != nil:
s.push(and(s.binaryPop()))
case ctx.T_OR() != nil:
s.push(or(s.binaryPop()))
case ctx.T_QUESTION() != nil && len(ctx.AllExpression()) == 3:
right, middle, left := s.pop(), s.pop(), s.pop()
s.push(ternary(left, middle, right))
case ctx.T_QUESTION() != nil && len(ctx.AllExpression()) == 2:
right, left := s.pop(), s.pop()
s.push(ternary(left, nil, right))
default:
panic("unkown expression")
panic("unknown expression")
}
}
@@ -159,7 +150,7 @@ func (s *aqlInterpreter) ExitReference(ctx *parser.ReferenceContext) {
case ctx.DOT() != nil:
reference := s.pop()
s.push(reference.(map[string]interface{})[ctx.T_STRING().GetText()])
s.push(reference.(map[string]any)[ctx.T_STRING().GetText()])
case ctx.T_STRING() != nil:
s.push(s.getVar(ctx.T_STRING().GetText()))
case ctx.Compound_value() != nil:
@@ -175,14 +166,15 @@ func (s *aqlInterpreter) ExitReference(ctx *parser.ReferenceContext) {
if f, ok := key.(float64); ok {
index := int(f)
if index < 0 {
index = len(reference.([]interface{})) + index
index = len(reference.([]any)) + index
}
s.push(reference.([]interface{})[index])
s.push(reference.([]any)[index])
return
}
s.push(reference.(map[string]interface{})[key.(string)])
s.push(reference.(map[string]any)[key.(string)])
default:
panic(fmt.Sprintf("unexpected value: %s", ctx.GetText()))
}
@@ -239,17 +231,17 @@ func (s *aqlInterpreter) ExitValue_literal(ctx *parser.Value_literalContext) {
// ExitArray is called when production array is exited.
func (s *aqlInterpreter) ExitArray(ctx *parser.ArrayContext) {
array := []interface{}{}
array := []any{}
for range ctx.AllExpression() {
// prepend element
array = append([]interface{}{s.pop()}, array...)
array = append([]any{s.pop()}, array...)
}
s.push(array)
}
// ExitObject is called when production object is exited.
func (s *aqlInterpreter) ExitObject(ctx *parser.ObjectContext) {
object := map[string]interface{}{}
object := map[string]any{}
for range ctx.AllObject_element() {
key, value := s.pop(), s.pop()
@@ -290,7 +282,7 @@ func (s *aqlInterpreter) ExitObject_element_name(ctx *parser.Object_element_name
}
}
func (s *aqlInterpreter) getVar(identifier string) interface{} {
func (s *aqlInterpreter) getVar(identifier string) any {
v, ok := s.values[identifier]
if !ok {
s.appendErrors(ErrUndefined)
@@ -303,10 +295,11 @@ func maybeNot(ctx *parser.ExpressionContext, m bool) bool {
if ctx.T_NOT() != nil {
return !m
}
return m
}
func getOp(tokenType int) func(left, right interface{}) bool {
func getOp(tokenType int) func(left, right any) bool {
switch tokenType {
case parser.CAQLLexerT_EQ:
return eq
@@ -323,33 +316,36 @@ func getOp(tokenType int) func(left, right interface{}) bool {
case parser.CAQLLexerT_IN:
return in
default:
panic("unkown token type")
panic("unknown token type")
}
}
func all(slice []interface{}, op func(interface{}, interface{}) bool, expr interface{}) bool {
func all(slice []any, op func(any, any) bool, expr any) bool {
for _, e := range slice {
if !op(e, expr) {
return false
}
}
return true
}
func any(slice []interface{}, op func(interface{}, interface{}) bool, expr interface{}) bool {
func anyElement(slice []any, op func(any, any) bool, expr any) bool {
for _, e := range slice {
if op(e, expr) {
return true
}
}
return false
}
func none(slice []interface{}, op func(interface{}, interface{}) bool, expr interface{}) bool {
func none(slice []any, op func(any, any) bool, expr any) bool {
for _, e := range slice {
if op(e, expr) {
return false
}
}
return true
}

View File

@@ -10,21 +10,23 @@ import (
// Logical operators https://www.arangodb.com/docs/3.7/aql/operators.html#logical-operators
func or(left, right interface{}) interface{} {
func or(left, right any) any {
if toBool(left) {
return left
}
return right
}
func and(left, right interface{}) interface{} {
func and(left, right any) any {
if !toBool(left) {
return left
}
return right
}
func toBool(i interface{}) bool {
func toBool(i any) bool {
switch v := i.(type) {
case nil:
return false
@@ -36,9 +38,9 @@ func toBool(i interface{}) bool {
return v != 0
case string:
return v != ""
case []interface{}:
case []any:
return true
case map[string]interface{}:
case map[string]any:
return true
default:
panic("bool conversion failed")
@@ -47,15 +49,15 @@ func toBool(i interface{}) bool {
// Arithmetic operators https://www.arangodb.com/docs/3.7/aql/operators.html#arithmetic-operators
func plus(left, right interface{}) float64 {
func plus(left, right any) float64 {
return toNumber(left) + toNumber(right)
}
func minus(left, right interface{}) float64 {
func minus(left, right any) float64 {
return toNumber(left) - toNumber(right)
}
func times(left, right interface{}) float64 {
func times(left, right any) float64 {
return round(toNumber(left) * toNumber(right))
}
@@ -63,19 +65,20 @@ func round(r float64) float64 {
return math.Round(r*100000) / 100000
}
func div(left, right interface{}) float64 {
func div(left, right any) float64 {
b := toNumber(right)
if b == 0 {
return 0
}
return round(toNumber(left) / b)
}
func mod(left, right interface{}) float64 {
func mod(left, right any) float64 {
return math.Mod(toNumber(left), toNumber(right))
}
func toNumber(i interface{}) float64 {
func toNumber(i any) float64 {
switch v := i.(type) {
case nil:
return 0
@@ -83,6 +86,7 @@ func toNumber(i interface{}) float64 {
if v {
return 1
}
return 0
case float64:
switch {
@@ -91,22 +95,25 @@ func toNumber(i interface{}) float64 {
case math.IsInf(v, 0):
return 0
}
return v
case string:
f, err := strconv.ParseFloat(strings.TrimSpace(v), 64)
if err != nil {
return 0
}
return f
case []interface{}:
case []any:
if len(v) == 0 {
return 0
}
if len(v) == 1 {
return toNumber(v[0])
}
return 0
case map[string]interface{}:
case map[string]any:
return 0
default:
panic("number conversion error")
@@ -116,7 +123,7 @@ func toNumber(i interface{}) float64 {
// Logical operators https://www.arangodb.com/docs/3.7/aql/operators.html#logical-operators
// Order https://www.arangodb.com/docs/3.7/aql/fundamentals-type-value-order.html
func eq(left, right interface{}) bool {
func eq(left, right any) bool {
leftV, rightV := typeValue(left), typeValue(right)
if leftV != rightV {
return false
@@ -126,15 +133,15 @@ func eq(left, right interface{}) bool {
return true
case bool, float64, string:
return left == right
case []interface{}:
ra := right.([]interface{})
case []any:
ra := right.([]any)
max := len(l)
if len(ra) > max {
max = len(ra)
}
for i := 0; i < max; i++ {
var li interface{} = nil
var rai interface{} = nil
var li any
var rai any
if len(l) > i {
li = l[i]
}
@@ -146,13 +153,14 @@ func eq(left, right interface{}) bool {
return false
}
}
return true
case map[string]interface{}:
ro := right.(map[string]interface{})
case map[string]any:
ro := right.(map[string]any)
for _, key := range keys(l, ro) {
var li interface{} = nil
var rai interface{} = nil
var li any
var rai any
if lv, ok := l[key]; ok {
li = lv
}
@@ -164,17 +172,18 @@ func eq(left, right interface{}) bool {
return false
}
}
return true
default:
panic("unknown type")
}
}
func ne(left, right interface{}) bool {
func ne(left, right any) bool {
return !eq(left, right)
}
func lt(left, right interface{}) bool {
func lt(left, right any) bool {
leftV, rightV := typeValue(left), typeValue(right)
if leftV != rightV {
return leftV < rightV
@@ -190,15 +199,15 @@ func lt(left, right interface{}) bool {
return l < right.(float64)
case string:
return l < right.(string)
case []interface{}:
ra := right.([]interface{})
case []any:
ra := right.([]any)
max := len(l)
if len(ra) > max {
max = len(ra)
}
for i := 0; i < max; i++ {
var li interface{} = nil
var rai interface{} = nil
var li any
var rai any
if len(l) > i {
li = l[i]
}
@@ -210,13 +219,14 @@ func lt(left, right interface{}) bool {
return lt(li, rai)
}
}
return false
case map[string]interface{}:
ro := right.(map[string]interface{})
case map[string]any:
ro := right.(map[string]any)
for _, key := range keys(l, ro) {
var li interface{} = nil
var rai interface{} = nil
var li any
var rai any
if lv, ok := l[key]; ok {
li = lv
}
@@ -228,16 +238,17 @@ func lt(left, right interface{}) bool {
return lt(li, rai)
}
}
return false
default:
panic("unknown type")
}
}
func keys(l map[string]interface{}, ro map[string]interface{}) []string {
func keys(l map[string]any, ro map[string]any) []string {
var keys []string
seen := map[string]bool{}
for _, a := range []map[string]interface{}{l, ro} {
for _, a := range []map[string]any{l, ro} {
for k := range a {
if _, ok := seen[k]; !ok {
seen[k] = true
@@ -246,10 +257,11 @@ func keys(l map[string]interface{}, ro map[string]interface{}) []string {
}
}
sort.Strings(keys)
return keys
}
func gt(left, right interface{}) bool {
func gt(left, right any) bool {
leftV, rightV := typeValue(left), typeValue(right)
if leftV != rightV {
return leftV > rightV
@@ -265,15 +277,15 @@ func gt(left, right interface{}) bool {
return l > right.(float64)
case string:
return l > right.(string)
case []interface{}:
ra := right.([]interface{})
case []any:
ra := right.([]any)
max := len(l)
if len(ra) > max {
max = len(ra)
}
for i := 0; i < max; i++ {
var li interface{} = nil
var rai interface{} = nil
var li any
var rai any
if len(l) > i {
li = l[i]
}
@@ -285,13 +297,14 @@ func gt(left, right interface{}) bool {
return gt(li, rai)
}
}
return false
case map[string]interface{}:
ro := right.(map[string]interface{})
case map[string]any:
ro := right.(map[string]any)
for _, key := range keys(l, ro) {
var li interface{} = nil
var rai interface{} = nil
var li any
var rai any
if lv, ok := l[key]; ok {
li = lv
}
@@ -303,13 +316,14 @@ func gt(left, right interface{}) bool {
return gt(li, rai)
}
}
return false
default:
panic("unknown type")
}
}
func le(left, right interface{}) bool {
func le(left, right any) bool {
leftV, rightV := typeValue(left), typeValue(right)
if leftV != rightV {
return leftV <= rightV
@@ -325,15 +339,15 @@ func le(left, right interface{}) bool {
return l <= right.(float64)
case string:
return l <= right.(string)
case []interface{}:
ra := right.([]interface{})
case []any:
ra := right.([]any)
max := len(l)
if len(ra) > max {
max = len(ra)
}
for i := 0; i < max; i++ {
var li interface{} = nil
var rai interface{} = nil
var li any
var rai any
if len(l) > i {
li = l[i]
}
@@ -345,13 +359,14 @@ func le(left, right interface{}) bool {
return le(li, rai)
}
}
return true
case map[string]interface{}:
ro := right.(map[string]interface{})
case map[string]any:
ro := right.(map[string]any)
for _, key := range keys(l, ro) {
var li interface{} = nil
var rai interface{} = nil
var li any
var rai any
if lv, ok := l[key]; ok {
li = lv
}
@@ -363,13 +378,14 @@ func le(left, right interface{}) bool {
return lt(li, rai)
}
}
return true
default:
panic("unknown type")
}
}
func ge(left, right interface{}) bool {
func ge(left, right any) bool {
leftV, rightV := typeValue(left), typeValue(right)
if leftV != rightV {
return leftV >= rightV
@@ -385,15 +401,15 @@ func ge(left, right interface{}) bool {
return l >= right.(float64)
case string:
return l >= right.(string)
case []interface{}:
ra := right.([]interface{})
case []any:
ra := right.([]any)
max := len(l)
if len(ra) > max {
max = len(ra)
}
for i := 0; i < max; i++ {
var li interface{} = nil
var rai interface{} = nil
var li any
var rai any
if len(l) > i {
li = l[i]
}
@@ -405,13 +421,14 @@ func ge(left, right interface{}) bool {
return ge(li, rai)
}
}
return true
case map[string]interface{}:
ro := right.(map[string]interface{})
case map[string]any:
ro := right.(map[string]any)
for _, key := range keys(l, ro) {
var li interface{} = nil
var rai interface{} = nil
var li any
var rai any
if lv, ok := l[key]; ok {
li = lv
}
@@ -423,14 +440,15 @@ func ge(left, right interface{}) bool {
return gt(li, rai)
}
}
return true
default:
panic("unknown type")
}
}
func in(left, right interface{}) bool {
a, ok := right.([]interface{})
func in(left, right any) bool {
a, ok := right.([]any)
if !ok {
return false
}
@@ -439,23 +457,25 @@ func in(left, right interface{}) bool {
return true
}
}
return false
}
func like(left, right interface{}) (bool, error) {
func like(left, right any) (bool, error) {
return match(right.(string), left.(string))
}
func regexMatch(left, right interface{}) (bool, error) {
func regexMatch(left, right any) (bool, error) {
return regexp.Match(right.(string), []byte(left.(string)))
}
func regexNonMatch(left, right interface{}) (bool, error) {
func regexNonMatch(left, right any) (bool, error) {
m, err := regexp.Match(right.(string), []byte(left.(string)))
return !m, err
}
func typeValue(v interface{}) int {
func typeValue(v any) int {
switch v.(type) {
case nil:
return 0
@@ -465,9 +485,9 @@ func typeValue(v interface{}) int {
return 2
case string:
return 3
case []interface{}:
case []any:
return 4
case map[string]interface{}:
case map[string]any:
return 5
default:
panic("unknown type")
@@ -476,22 +496,25 @@ func typeValue(v interface{}) int {
// Ternary operator https://www.arangodb.com/docs/3.7/aql/operators.html#ternary-operator
func ternary(left, middle, right interface{}) interface{} {
func ternary(left, middle, right any) any {
if toBool(left) {
if middle != nil {
return middle
}
return left
}
return right
}
// Range operators https://www.arangodb.com/docs/3.7/aql/operators.html#range-operator
func aqlrange(left, right interface{}) []float64 {
func aqlrange(left, right any) []float64 {
var v []float64
for i := int(left.(float64)); i <= int(right.(float64)); i++ {
v = append(v, float64(i))
}
return v
}

View File

@@ -21,7 +21,7 @@ func (p *Parser) Parse(aql string) (t *Tree, err error) {
err = fmt.Errorf("%s", r)
}
}()
// Setup the input
// Set up the input
inputStream := antlr.NewInputStream(aql)
errorListener := &errorListener{}
@@ -52,7 +52,7 @@ type Tree struct {
prefix string
}
func (t *Tree) Eval(values map[string]interface{}) (i interface{}, err error) {
func (t *Tree) Eval(values map[string]any) (i any, err error) {
defer func() {
if r := recover(); r != nil {
err = fmt.Errorf("%s", r)
@@ -65,6 +65,7 @@ func (t *Tree) Eval(values map[string]interface{}) (i interface{}, err error) {
if interpreter.errs != nil {
return nil, interpreter.errs[0]
}
return interpreter.stack[0], nil
}
@@ -103,7 +104,7 @@ type errorListener struct {
errs []error
}
func (el *errorListener) SyntaxError(recognizer antlr.Recognizer, offendingSymbol interface{}, line, column int, msg string, e antlr.RecognitionException) {
func (el *errorListener) SyntaxError(recognizer antlr.Recognizer, offendingSymbol any, line, column int, msg string, e antlr.RecognitionException) {
el.errs = append(el.errs, fmt.Errorf("line "+strconv.Itoa(line)+":"+strconv.Itoa(column)+" "+msg))
}

View File

@@ -1,9 +1,11 @@
package caql
package caql_test
import (
"encoding/json"
"reflect"
"testing"
"github.com/SecurityBrewery/catalyst/caql"
)
type MockSearcher struct{}
@@ -13,11 +15,13 @@ func (m MockSearcher) Search(_ string) (ids []string, err error) {
}
func TestParseSAQLEval(t *testing.T) {
t.Parallel()
tests := []struct {
name string
saql string
wantRebuild string
wantValue interface{}
wantValue any
wantParseErr bool
wantRebuildErr bool
wantEvalErr bool
@@ -89,15 +93,15 @@ func TestParseSAQLEval(t *testing.T) {
// {name: "String 9", saql: `'this is a longer string.'`, wantRebuild: `"this is a longer string."`, wantValue: "this is a longer string."},
// {name: "String 10", saql: `'the path separator on Windows is \\'`, wantRebuild: `"the path separator on Windows is \\"`, wantValue: `the path separator on Windows is \`},
{name: "Array 1", saql: "[]", wantRebuild: "[]", wantValue: []interface{}{}},
{name: "Array 2", saql: `[true]`, wantRebuild: `[true]`, wantValue: []interface{}{true}},
{name: "Array 3", saql: `[1, 2, 3]`, wantRebuild: `[1, 2, 3]`, wantValue: []interface{}{float64(1), float64(2), float64(3)}},
{name: "Array 1", saql: "[]", wantRebuild: "[]", wantValue: []any{}},
{name: "Array 2", saql: `[true]`, wantRebuild: `[true]`, wantValue: []any{true}},
{name: "Array 3", saql: `[1, 2, 3]`, wantRebuild: `[1, 2, 3]`, wantValue: []any{float64(1), float64(2), float64(3)}},
{
name: "Array 4", saql: `[-99, "yikes!", [false, ["no"], []], 1]`, wantRebuild: `[-99, "yikes!", [false, ["no"], []], 1]`,
wantValue: []interface{}{-99.0, "yikes!", []interface{}{false, []interface{}{"no"}, []interface{}{}}, float64(1)},
wantValue: []any{-99.0, "yikes!", []any{false, []any{"no"}, []any{}}, float64(1)},
},
{name: "Array 5", saql: `[["fox", "marshal"]]`, wantRebuild: `[["fox", "marshal"]]`, wantValue: []interface{}{[]interface{}{"fox", "marshal"}}},
{name: "Array 6", saql: `[1, 2, 3,]`, wantRebuild: `[1, 2, 3]`, wantValue: []interface{}{float64(1), float64(2), float64(3)}},
{name: "Array 5", saql: `[["fox", "marshal"]]`, wantRebuild: `[["fox", "marshal"]]`, wantValue: []any{[]any{"fox", "marshal"}}},
{name: "Array 6", saql: `[1, 2, 3,]`, wantRebuild: `[1, 2, 3]`, wantValue: []any{float64(1), float64(2), float64(3)}},
{name: "Array Error 1", saql: "(1,2,3)", wantParseErr: true},
{name: "Array Access 1", saql: "u.friends[0]", wantRebuild: "u.friends[0]", wantValue: 7, values: `{"u": {"friends": [7,8,9]}}`},
@@ -105,14 +109,14 @@ func TestParseSAQLEval(t *testing.T) {
{name: "Array Access 3", saql: "u.friends[-1]", wantRebuild: "u.friends[-1]", wantValue: 9, values: `{"u": {"friends": [7,8,9]}}`},
{name: "Array Access 4", saql: "u.friends[-2]", wantRebuild: "u.friends[-2]", wantValue: 8, values: `{"u": {"friends": [7,8,9]}}`},
{name: "Object 1", saql: "{}", wantRebuild: "{}", wantValue: map[string]interface{}{}},
{name: "Object 2", saql: `{a: 1}`, wantRebuild: "{a: 1}", wantValue: map[string]interface{}{"a": float64(1)}},
{name: "Object 3", saql: `{'a': 1}`, wantRebuild: `{'a': 1}`, wantValue: map[string]interface{}{"a": float64(1)}},
{name: "Object 4", saql: `{"a": 1}`, wantRebuild: `{"a": 1}`, wantValue: map[string]interface{}{"a": float64(1)}},
{name: "Object 5", saql: `{'return': 1}`, wantRebuild: `{'return': 1}`, wantValue: map[string]interface{}{"return": float64(1)}},
{name: "Object 6", saql: `{"return": 1}`, wantRebuild: `{"return": 1}`, wantValue: map[string]interface{}{"return": float64(1)}},
{name: "Object 9", saql: `{a: 1,}`, wantRebuild: "{a: 1}", wantValue: map[string]interface{}{"a": float64(1)}},
{name: "Object 10", saql: `{"a": 1,}`, wantRebuild: `{"a": 1}`, wantValue: map[string]interface{}{"a": float64(1)}},
{name: "Object 1", saql: "{}", wantRebuild: "{}", wantValue: map[string]any{}},
{name: "Object 2", saql: `{a: 1}`, wantRebuild: "{a: 1}", wantValue: map[string]any{"a": float64(1)}},
{name: "Object 3", saql: `{'a': 1}`, wantRebuild: `{'a': 1}`, wantValue: map[string]any{"a": float64(1)}},
{name: "Object 4", saql: `{"a": 1}`, wantRebuild: `{"a": 1}`, wantValue: map[string]any{"a": float64(1)}},
{name: "Object 5", saql: `{'return': 1}`, wantRebuild: `{'return': 1}`, wantValue: map[string]any{"return": float64(1)}},
{name: "Object 6", saql: `{"return": 1}`, wantRebuild: `{"return": 1}`, wantValue: map[string]any{"return": float64(1)}},
{name: "Object 9", saql: `{a: 1,}`, wantRebuild: "{a: 1}", wantValue: map[string]any{"a": float64(1)}},
{name: "Object 10", saql: `{"a": 1,}`, wantRebuild: `{"a": 1}`, wantValue: map[string]any{"a": float64(1)}},
// {"Object 8", "{`return`: 1}", `{"return": 1}`, true},
// {"Object 7", "{´return´: 1}", `{"return": 1}`, true},
{name: "Object Error 1: return is a keyword", saql: `{like: 1}`, wantParseErr: true},
@@ -272,7 +276,7 @@ func TestParseSAQLEval(t *testing.T) {
{name: "Arithmetic 17", saql: `23 * {}`, wantRebuild: `23 * {}`, wantValue: 0},
{name: "Arithmetic 18", saql: `5 * [7]`, wantRebuild: `5 * [7]`, wantValue: 35},
{name: "Arithmetic 19", saql: `24 / "12"`, wantRebuild: `24 / "12"`, wantValue: 2},
{name: "Arithmetic Error 1: Divison by zero", saql: `1 / 0`, wantRebuild: `1 / 0`, wantValue: 0},
{name: "Arithmetic Error 1: Division by zero", saql: `1 / 0`, wantRebuild: `1 / 0`, wantValue: 0},
// https://www.arangodb.com/docs/3.7/aql/operators.html#ternary-operator
{name: "Ternary 1", saql: `u.age > 15 || u.active == true ? u.userId : null`, wantRebuild: `u.age > 15 OR u.active == true ? u.userId : null`, wantValue: 45, values: `{"u": {"active": true, "age": 2, "userId": 45}}`},
@@ -287,20 +291,24 @@ func TestParseSAQLEval(t *testing.T) {
{name: "Security 2", saql: `doc.value == 1 || true INSERT {foo: "bar"} IN collection //`, wantParseErr: true},
// https://www.arangodb.com/docs/3.7/aql/operators.html#operator-precedence
{name: "Precendence", saql: `2 > 15 && "a" != ""`, wantRebuild: `2 > 15 AND "a" != ""`, wantValue: false},
{name: "Precedence", saql: `2 > 15 && "a" != ""`, wantRebuild: `2 > 15 AND "a" != ""`, wantValue: false},
}
for _, tt := range tests {
parser := &Parser{
tt := tt
parser := &caql.Parser{
Searcher: &MockSearcher{},
}
t.Run(tt.name, func(t *testing.T) {
t.Parallel()
expr, err := parser.Parse(tt.saql)
if (err != nil) != tt.wantParseErr {
t.Errorf("Parse() error = %v, wantErr %v", err, tt.wantParseErr)
if expr != nil {
t.Error(expr.String())
}
return
}
if err != nil {
@@ -311,6 +319,7 @@ func TestParseSAQLEval(t *testing.T) {
if (err != nil) != tt.wantRebuildErr {
t.Error(expr.String())
t.Errorf("String() error = %v, wantErr %v", err, tt.wantParseErr)
return
}
if err != nil {
@@ -320,18 +329,19 @@ func TestParseSAQLEval(t *testing.T) {
t.Errorf("String() got = %v, want %v", got, tt.wantRebuild)
}
var myJson map[string]interface{}
var myJSON map[string]any
if tt.values != "" {
err = json.Unmarshal([]byte(tt.values), &myJson)
err = json.Unmarshal([]byte(tt.values), &myJSON)
if err != nil {
t.Fatal(err)
}
}
value, err := expr.Eval(myJson)
value, err := expr.Eval(myJSON)
if (err != nil) != tt.wantEvalErr {
t.Error(expr.String())
t.Errorf("Parse() error = %v, wantErr %v", err, tt.wantParseErr)
return
}
if err != nil {

View File

@@ -22,19 +22,18 @@
package caql
import "sort"
type (
Set struct {
hash map[interface{}]nothing
}
nothing struct{}
import (
"sort"
)
// Create a new set
func New(initial ...interface{}) *Set {
s := &Set{make(map[interface{}]nothing)}
type Set struct {
hash map[any]nothing
}
type nothing struct{}
func NewSet(initial ...any) *Set {
s := &Set{make(map[any]nothing)}
for _, v := range initial {
s.Insert(v)
@@ -43,9 +42,8 @@ func New(initial ...interface{}) *Set {
return s
}
// Find the difference between two sets
func (s *Set) Difference(set *Set) *Set {
n := make(map[interface{}]nothing)
n := make(map[any]nothing)
for k := range s.hash {
if _, exists := set.hash[k]; !exists {
@@ -56,27 +54,18 @@ func (s *Set) Difference(set *Set) *Set {
return &Set{n}
}
// Call f for each item in the set
func (s *Set) Do(f func(interface{})) {
for k := range s.hash {
f(k)
}
}
// Test to see whether or not the element is in the set
func (s *Set) Has(element interface{}) bool {
func (s *Set) Has(element any) bool {
_, exists := s.hash[element]
return exists
}
// Add an element to the set
func (s *Set) Insert(element interface{}) {
func (s *Set) Insert(element any) {
s.hash[element] = nothing{}
}
// Find the intersection of two sets
func (s *Set) Intersection(set *Set) *Set {
n := make(map[interface{}]nothing)
n := make(map[any]nothing)
for k := range s.hash {
if _, exists := set.hash[k]; exists {
@@ -87,23 +76,20 @@ func (s *Set) Intersection(set *Set) *Set {
return &Set{n}
}
// Return the number of items in the set
func (s *Set) Len() int {
return len(s.hash)
}
// Test whether or not this set is a proper subset of "set"
func (s *Set) ProperSubsetOf(set *Set) bool {
return s.SubsetOf(set) && s.Len() < set.Len()
}
// Remove an element from the set
func (s *Set) Remove(element interface{}) {
func (s *Set) Remove(element any) {
delete(s.hash, element)
}
func (s *Set) Minus(set *Set) *Set {
n := make(map[interface{}]nothing)
n := make(map[any]nothing)
for k := range s.hash {
n[k] = nothing{}
}
@@ -115,7 +101,6 @@ func (s *Set) Minus(set *Set) *Set {
return &Set{n}
}
// Test whether or not this set is a subset of "set"
func (s *Set) SubsetOf(set *Set) bool {
if s.Len() > set.Len() {
return false
@@ -125,12 +110,12 @@ func (s *Set) SubsetOf(set *Set) bool {
return false
}
}
return true
}
// Find the union of two sets
func (s *Set) Union(set *Set) *Set {
n := make(map[interface{}]nothing)
n := make(map[any]nothing)
for k := range s.hash {
n[k] = nothing{}
@@ -142,8 +127,8 @@ func (s *Set) Union(set *Set) *Set {
return &Set{n}
}
func (s *Set) Values() []interface{} {
values := []interface{}{}
func (s *Set) Values() []any {
values := []any{}
for k := range s.hash {
values = append(values, k)

View File

@@ -27,7 +27,9 @@ import (
)
func Test(t *testing.T) {
s := New()
t.Parallel()
s := NewSet()
s.Insert(5)
@@ -50,8 +52,8 @@ func Test(t *testing.T) {
}
// Difference
s1 := New(1, 2, 3, 4, 5, 6)
s2 := New(4, 5, 6)
s1 := NewSet(1, 2, 3, 4, 5, 6)
s2 := NewSet(4, 5, 6)
s3 := s1.Difference(s2)
if s3.Len() != 3 {
@@ -73,7 +75,7 @@ func Test(t *testing.T) {
}
// Union
s4 := New(7, 8, 9)
s4 := NewSet(7, 8, 9)
s3 = s2.Union(s4)
if s3.Len() != 6 {
@@ -92,5 +94,4 @@ func Test(t *testing.T) {
if s1.ProperSubsetOf(s1) {
t.Errorf("set should not be a subset of itself")
}
}

View File

@@ -39,8 +39,10 @@ func unquote(s string) (string, error) {
buf = append(buf, s[i])
}
}
return string(buf), nil
}
return s, nil
}
if quote != '"' && quote != '\'' {
@@ -75,5 +77,6 @@ func unquote(s string) (string, error) {
buf = append(buf, runeTmp[:n]...)
}
}
return string(buf), nil
}

View File

@@ -8,26 +8,25 @@
package caql
import (
"errors"
"strconv"
"testing"
)
type quoteTest struct {
in string
out string
ascii string
graphic string
in string
out string
}
var quotetests = []quoteTest{
{in: "\a\b\f\r\n\t\v", out: `"\a\b\f\r\n\t\v"`, ascii: `"\a\b\f\r\n\t\v"`, graphic: `"\a\b\f\r\n\t\v"`},
{"\\", `"\\"`, `"\\"`, `"\\"`},
{"abc\xffdef", `"abc\xffdef"`, `"abc\xffdef"`, `"abc\xffdef"`},
{"\u263a", `"☺"`, `"\u263a"`, `"☺"`},
{"\U0010ffff", `"\U0010ffff"`, `"\U0010ffff"`, `"\U0010ffff"`},
{"\x04", `"\x04"`, `"\x04"`, `"\x04"`},
{in: "\a\b\f\r\n\t\v", out: `"\a\b\f\r\n\t\v"`},
{"\\", `"\\"`},
{"abc\xffdef", `"abc\xffdef"`},
{"\u263a", `"☺"`},
{"\U0010ffff", `"\U0010ffff"`},
{"\x04", `"\x04"`},
// Some non-printable but graphic runes. Final column is double-quoted.
{"!\u00a0!\u2000!\u3000!", `"!\u00a0!\u2000!\u3000!"`, `"!\u00a0!\u2000!\u3000!"`, "\"!\u00a0!\u2000!\u3000!\""},
{"!\u00a0!\u2000!\u3000!", `"!\u00a0!\u2000!\u3000!"`},
}
type unQuoteTest struct {
@@ -104,6 +103,8 @@ var misquoted = []string{
}
func TestUnquote(t *testing.T) {
t.Parallel()
for _, tt := range unquotetests {
if out, err := unquote(tt.in); err != nil || out != tt.out {
t.Errorf("unquote(%#q) = %q, %v want %q, nil", tt.in, out, err, tt.out)
@@ -118,7 +119,7 @@ func TestUnquote(t *testing.T) {
}
for _, s := range misquoted {
if out, err := unquote(s); out != "" || err != strconv.ErrSyntax {
if out, err := unquote(s); out != "" || !errors.Is(err, strconv.ErrSyntax) {
t.Errorf("unquote(%#q) = %q, %v want %q, %v", s, out, err, "", strconv.ErrSyntax)
}
}

View File

@@ -48,6 +48,7 @@ Pattern:
// using the star
if ok && (len(t) == 0 || len(pattern) > 0) {
name = t
continue
}
if err != nil {
@@ -64,6 +65,7 @@ Pattern:
continue
}
name = t
continue Pattern
}
if err != nil {
@@ -79,8 +81,10 @@ Pattern:
return false, err
}
}
return false, nil
}
return len(name) == 0, nil
}
@@ -104,6 +108,7 @@ Scan:
break Scan
}
}
return star, pattern[0:i], pattern[i:]
}
@@ -120,7 +125,6 @@ func matchChunk(chunk, s string) (rest string, ok bool, err error) {
failed = true
}
switch chunk[0] {
case '_':
if !failed {
if s[0] == '/' {
@@ -130,14 +134,13 @@ func matchChunk(chunk, s string) (rest string, ok bool, err error) {
s = s[n:]
}
chunk = chunk[1:]
case '\\':
chunk = chunk[1:]
if len(chunk) == 0 {
return "", false, ErrBadPattern
}
fallthrough
fallthrough
default:
if !failed {
if chunk[0] != s[0] {
@@ -151,5 +154,6 @@ func matchChunk(chunk, s string) (rest string, ok bool, err error) {
if failed {
return "", false, nil
}
return s, true, nil
}

View File

@@ -7,7 +7,10 @@
package caql
import "testing"
import (
"errors"
"testing"
)
type MatchTest struct {
pattern, s string
@@ -41,9 +44,11 @@ var matchTests = []MatchTest{
}
func TestMatch(t *testing.T) {
t.Parallel()
for _, tt := range matchTests {
ok, err := match(tt.pattern, tt.s)
if ok != tt.match || err != tt.err {
if ok != tt.match || !errors.Is(err, tt.err) {
t.Errorf("match(%#q, %#q) = %v, %v want %v, %v", tt.pattern, tt.s, ok, err, tt.match, tt.err)
}
}