From fd8e793361264b52bf49085e0e60e3d689cb551a Mon Sep 17 00:00:00 2001 From: Jonas Plum Date: Sun, 21 Aug 2022 22:23:27 +0200 Subject: [PATCH] Fix sorting on multiple ticket fields (#412) --- database/ticket.go | 6 +-- generated/api/api.go | 52 +++++++++++++---------- generated/api/api_test.go | 86 +++++++++++++++++++++++++++++++++++++++ service/ticket.go | 20 ++++----- 4 files changed, 130 insertions(+), 34 deletions(-) create mode 100644 generated/api/api_test.go diff --git a/database/ticket.go b/database/ticket.go index 45d2154..02b1f9a 100644 --- a/database/ticket.go +++ b/database/ticket.go @@ -571,7 +571,7 @@ func (db *Database) TicketCount(ctx context.Context, typequery, filterquery stri } func sortQuery(paramsSort []string, paramsDesc []bool, bindVars map[string]any) string { - sort := "" + sortQuery := "" if len(paramsSort) > 0 { var sorts []string for i, column := range paramsSort { @@ -582,10 +582,10 @@ func sortQuery(paramsSort []string, paramsDesc []bool, bindVars map[string]any) sorts = append(sorts, colsort) bindVars[fmt.Sprintf("column%d", i)] = column } - sort = "SORT " + strings.Join(sorts, ", ") + sortQuery = "SORT " + strings.Join(sorts, ", ") } - return sort + return sortQuery } func mergeMaps(a map[string]any, b map[string]any) map[string]any { diff --git a/generated/api/api.go b/generated/api/api.go index d6f5917..6a875e3 100755 --- a/generated/api/api.go +++ b/generated/api/api.go @@ -7,6 +7,7 @@ import ( "fmt" "net/http" "strconv" + "strings" "github.com/go-chi/chi" "github.com/xeipuuv/gojsonschema" @@ -58,30 +59,12 @@ func parseQueryBool(r *http.Request, s string) (bool, error) { } func parseQueryStringArray(r *http.Request, key string) ([]string, error) { - stringArray, ok := r.URL.Query()[key] - if !ok { - return nil, nil - } - return removeEmpty(stringArray), nil -} - -func removeEmpty(l []string) []string { - var stringArray []string - for _, s := range l { - if s == "" { - continue - } - stringArray = append(stringArray, s) - } - - return stringArray + return parseQueryArray(r, key), nil } func parseQueryBoolArray(r *http.Request, key string) ([]bool, error) { - stringArray, ok := r.URL.Query()[key] - if !ok { - return nil, nil - } + stringArray := parseQueryArray(r, key) + var boolArray []bool for _, s := range stringArray { if s == "" { @@ -97,6 +80,33 @@ func parseQueryBoolArray(r *http.Request, key string) ([]bool, error) { return boolArray, nil } +func parseQueryArray(r *http.Request, key string) []string { + stringArray, ok := r.URL.Query()[key] + if !ok { + return nil + } + + if len(stringArray) == 0 { + return nil + } + + stringArray = strings.Split(stringArray[0], ",") + + return removeEmpty(stringArray) +} + +func removeEmpty(l []string) []string { + var stringArray []string + for _, s := range l { + if s == "" { + continue + } + stringArray = append(stringArray, s) + } + + return stringArray +} + func parseQueryOptionalInt(r *http.Request, key string) (*int, error) { s := r.URL.Query().Get(key) if s == "" { diff --git a/generated/api/api_test.go b/generated/api/api_test.go new file mode 100644 index 0000000..a085545 --- /dev/null +++ b/generated/api/api_test.go @@ -0,0 +1,86 @@ +package api + +import ( + "net/http" + "net/http/httptest" + "reflect" + "testing" +) + +func Test_parseQueryOptionalBoolArray(t *testing.T) { + type args struct { + r *http.Request + key string + } + tests := []struct { + name string + args args + want []bool + wantErr bool + }{ + { + name: "bool array", + args: args{ + r: httptest.NewRequest( + http.MethodGet, + "https://try.catalyst-soar.com/api/tickets?type=alert&offset=0&count=10&sort=status%2Cowner%2Ccreated&desc=true%2Cfalse%2Cfalse&query=status+%3D%3D+%27open%27+AND+%28owner+%3D%3D+%27eve%27+OR+%21owner%29", + nil, + ), + key: "desc", + }, + want: []bool{true, false, false}, + wantErr: false, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + got, err := parseQueryOptionalBoolArray(tt.args.r, tt.args.key) + if (err != nil) != tt.wantErr { + t.Errorf("parseQueryOptionalBoolArray() error = %v, wantErr %v", err, tt.wantErr) + return + } + if !reflect.DeepEqual(got, tt.want) { + t.Errorf("parseQueryOptionalBoolArray() got = %v, want %v", got, tt.want) + } + }) + } +} + +func Test_parseQueryOptionalStringArray(t *testing.T) { + type args struct { + r *http.Request + key string + } + tests := []struct { + name string + args args + want []string + wantErr bool + }{ + { + name: "string array", + args: args{ + r: httptest.NewRequest( + http.MethodGet, + "https://try.catalyst-soar.com/api/tickets?type=alert&offset=0&count=10&sort=status%2Cowner%2Ccreated&desc=true%2Cfalse%2Cfalse&query=status+%3D%3D+%27open%27+AND+%28owner+%3D%3D+%27eve%27+OR+%21owner%29", + nil, + ), + key: "sort", + }, + want: []string{"status", "owner", "created"}, + wantErr: false, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + got, err := parseQueryOptionalStringArray(tt.args.r, tt.args.key) + if (err != nil) != tt.wantErr { + t.Errorf("parseQueryOptionalStringArray() error = %v, wantErr %v", err, tt.wantErr) + return + } + if !reflect.DeepEqual(got, tt.want) { + t.Errorf("parseQueryOptionalStringArray() got = %v, want %v", got, tt.want) + } + }) + } +} \ No newline at end of file diff --git a/service/ticket.go b/service/ticket.go index 9fa3ed7..b7b6f0e 100644 --- a/service/ticket.go +++ b/service/ticket.go @@ -38,27 +38,27 @@ func ticketIDs(ticketResponses []*model.TicketResponse) []driver.DocumentID { return ids } -func (s *Service) ListTickets(ctx context.Context, s3 *string, i *int, i2 *int, strings []string, bools []bool, s2 *string) (*model.TicketList, error) { +func (s *Service) ListTickets(ctx context.Context, ticketType *string, offsetP, countP *int, sort []string, descending []bool, queryP *string) (*model.TicketList, error) { q := "" - if s2 != nil && *s2 != "" { - q = *s2 + if queryP != nil && *queryP != "" { + q = *queryP } t := "" - if s3 != nil && *s3 != "" { - t = *s3 + if ticketType != nil && *ticketType != "" { + t = *ticketType } offset := int64(0) - if i != nil { - offset = int64(*i) + if offsetP != nil { + offset = int64(*offsetP) } count := int64(25) - if i2 != nil { - count = int64(*i2) + if countP != nil { + count = int64(*countP) } - return s.database.TicketList(ctx, t, q, strings, bools, offset, count) + return s.database.TicketList(ctx, t, q, sort, descending, offset, count) } func (s *Service) CreateTicket(ctx context.Context, form *model.TicketForm) (doc *model.TicketResponse, err error) {