mirror of
https://github.com/SecurityBrewery/catalyst.git
synced 2025-12-06 23:32:47 +01:00
126 lines
2.5 KiB
Go
126 lines
2.5 KiB
Go
package test
|
|
|
|
import (
|
|
"bytes"
|
|
"encoding/json"
|
|
"io"
|
|
"net/http"
|
|
"net/http/httptest"
|
|
"strings"
|
|
"testing"
|
|
"time"
|
|
|
|
"github.com/stretchr/testify/assert"
|
|
"github.com/tidwall/gjson"
|
|
"github.com/tidwall/sjson"
|
|
|
|
"github.com/SecurityBrewery/catalyst/generated/api"
|
|
ctime "github.com/SecurityBrewery/catalyst/generated/time"
|
|
)
|
|
|
|
type testClock struct{}
|
|
|
|
func (testClock) Now() time.Time {
|
|
return time.Date(2021, 12, 12, 12, 12, 12, 12, time.UTC)
|
|
}
|
|
|
|
func TestServer(t *testing.T) {
|
|
t.Parallel()
|
|
|
|
ctime.DefaultClock = testClock{}
|
|
|
|
for _, tt := range api.Tests {
|
|
tt := tt
|
|
t.Run(tt.Name, func(t *testing.T) {
|
|
t.Parallel()
|
|
|
|
ctx, _, _, _, _, db, _, server, cleanup, err := Server(t)
|
|
if err != nil {
|
|
t.Fatal(err)
|
|
}
|
|
defer cleanup()
|
|
|
|
if err := SetupTestData(ctx, db); err != nil {
|
|
t.Fatal(err)
|
|
}
|
|
|
|
w := httptest.NewRecorder()
|
|
|
|
// setup request
|
|
var req *http.Request
|
|
if tt.Args.Data != nil {
|
|
b, err := json.Marshal(tt.Args.Data)
|
|
if err != nil {
|
|
t.Fatal(err)
|
|
}
|
|
|
|
req = httptest.NewRequest(strings.ToUpper(tt.Args.Method), tt.Args.URL, bytes.NewBuffer(b))
|
|
req.Header.Set("Content-Type", "application/json")
|
|
} else {
|
|
req = httptest.NewRequest(strings.ToUpper(tt.Args.Method), tt.Args.URL, nil)
|
|
}
|
|
|
|
// run request
|
|
server.ServeHTTP(w, req)
|
|
|
|
result := w.Result()
|
|
|
|
// assert results
|
|
if result.StatusCode != tt.Want.Status {
|
|
msg, _ := io.ReadAll(result.Body)
|
|
|
|
t.Fatalf("Status got = %v (%s), want %v", result.Status, msg, tt.Want.Status)
|
|
}
|
|
if tt.Want.Status != http.StatusNoContent {
|
|
jsonEqual(t, tt.Name, result.Body, tt.Want.Body)
|
|
}
|
|
})
|
|
}
|
|
}
|
|
|
|
func jsonEqual(t *testing.T, name string, got io.Reader, want any) {
|
|
t.Helper()
|
|
|
|
var gotObject, wantObject any
|
|
|
|
// load bytes
|
|
wantBytes, err := json.Marshal(want)
|
|
if err != nil {
|
|
t.Fatal(err)
|
|
}
|
|
gotBytes, err := io.ReadAll(got)
|
|
if err != nil {
|
|
t.Fatal(err)
|
|
}
|
|
|
|
var fields []string
|
|
|
|
if name == "CreateUser" {
|
|
fields = append(fields, "secret")
|
|
}
|
|
if name == "RunJob" {
|
|
fields = append(fields, "id", "status")
|
|
}
|
|
|
|
for _, field := range fields {
|
|
gField := gjson.GetBytes(wantBytes, field)
|
|
if gField.Exists() && gjson.GetBytes(gotBytes, field).Exists() {
|
|
gotBytes, err = sjson.SetBytes(gotBytes, field, gField.Value())
|
|
if err != nil {
|
|
t.Fatal(err)
|
|
}
|
|
}
|
|
}
|
|
|
|
// normalize bytes
|
|
if err = json.Unmarshal(wantBytes, &wantObject); err != nil {
|
|
t.Fatal(err)
|
|
}
|
|
if err := json.Unmarshal(gotBytes, &gotObject); err != nil {
|
|
t.Fatal(string(gotBytes), err)
|
|
}
|
|
|
|
// compare
|
|
assert.Equal(t, wantObject, gotObject)
|
|
}
|