Files
catalyst/app/auth/middleware_test.go
2025-09-02 21:58:08 +02:00

189 lines
4.3 KiB
Go

package auth
import (
"bytes"
"context"
"fmt"
"net/http"
"net/http/httptest"
"testing"
"github.com/stretchr/testify/assert"
"github.com/SecurityBrewery/catalyst/app/auth/usercontext"
"github.com/SecurityBrewery/catalyst/app/openapi"
)
// mockHandler is a stub downstream handler that always writes the fixed JSON
// payload {"message":"OK"} and leaves the status at the recorder's default.
func mockHandler(w http.ResponseWriter, _ *http.Request) {
	const okPayload = `{"message":"OK"}`

	// A failed write is irrelevant for a test stub, so the error is dropped on purpose.
	_, _ = w.Write([]byte(okPayload))
}
// TestService_ValidateScopes wraps mockHandler with ValidateScopesStrict,
// stores the case's required scopes in the request context under
// openapi.OAuth2Scopes, attaches the granted permissions via
// usercontext.PermissionRequest, and checks the recorded status code and
// JSON body against each case's expectation.
func TestService_ValidateScopes(t *testing.T) {
	t.Parallel()

	type args struct {
		requiredScopes []string
		permissions    []string
		next           http.HandlerFunc
	}

	tests := []struct {
		name string
		args args
		// want only carries the expected Code and Body; it is never written to.
		want httptest.ResponseRecorder
	}{
		{
			name: "no scopes",
			args: args{
				requiredScopes: []string{"user:read"},
				permissions:    []string{},
				next:           mockHandler,
			},
			want: httptest.ResponseRecorder{
				Code: http.StatusUnauthorized,
				Body: bytes.NewBufferString(`{"error": "Unauthorized", "message": "missing required scopes", "status": 401}`),
			},
		},
		{
			name: "insufficient scopes",
			args: args{
				requiredScopes: []string{"user:write"},
				permissions:    []string{"user:read"},
				next:           mockHandler,
			},
			want: httptest.ResponseRecorder{
				Code: http.StatusUnauthorized,
				Body: bytes.NewBufferString(`{"error": "Unauthorized", "message": "missing required scopes", "status": 401}`),
			},
		},
		{
			name: "sufficient scopes",
			args: args{
				requiredScopes: []string{"user:read"},
				permissions:    []string{"user:read", "user:write"},
				next:           mockHandler,
			},
			want: httptest.ResponseRecorder{
				Code: http.StatusOK,
				Body: bytes.NewBufferString(`{"message":"OK"}`),
			},
		},
	}

	for _, tt := range tests {
		t.Run(tt.name, func(t *testing.T) {
			t.Parallel()

			// Adapt the plain http.HandlerFunc to the strict-handler signature
			// that ValidateScopesStrict wraps.
			handler := ValidateScopesStrict(func(_ context.Context, w http.ResponseWriter, r *http.Request, _ any) (response any, err error) {
				tt.args.next(w, r)

				return w, nil
			}, "")

			w := httptest.NewRecorder()
			r := httptest.NewRequest(http.MethodGet, "/", nil)
			//nolint: staticcheck
			r = r.WithContext(context.WithValue(r.Context(), openapi.OAuth2Scopes, tt.args.requiredScopes))
			r = usercontext.PermissionRequest(r, tt.args.permissions)

			if _, err := handler(r.Context(), w, r, r); err != nil {
				// The middleware may reject a request by returning an error
				// instead of writing a response. That is acceptable for cases
				// expecting a rejection. For a case expecting success it must
				// fail the test rather than end the subtest with no assertions.
				if tt.want.Code < http.StatusBadRequest {
					t.Fatalf("handler returned unexpected error: %v", err)
				}

				return
			}

			assert.Equal(t, tt.want.Code, w.Code, "response code should match expected value")
			assert.JSONEq(t, tt.want.Body.String(), w.Body.String(), "response body should match expected value")
		})
	}
}
func Test_hasScope(t *testing.T) {
t.Parallel()
type args struct {
scopes []string
requiredScopes []string
}
tests := []struct {
name string
args args
want bool
}{
{
name: "no scopes",
args: args{
scopes: []string{},
requiredScopes: []string{"user:read"},
},
want: false,
},
{
name: "missing required scope",
args: args{
scopes: []string{"user:read"},
requiredScopes: []string{"user:write"},
},
},
{
name: "has required scope",
args: args{
scopes: []string{"user:read", "user:write"},
requiredScopes: []string{"user:read"},
},
want: true,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
t.Parallel()
assert.Equalf(t, tt.want, hasScope(tt.args.scopes, tt.args.requiredScopes), "hasScope(%v, %v)", tt.args.scopes, tt.args.requiredScopes)
})
}
}
func Test_requiredScopes(t *testing.T) {
t.Parallel()
type args struct {
r *http.Request
}
tests := []struct {
name string
args args
want []string
wantErr assert.ErrorAssertionFunc
}{
{
name: "no required scopes",
args: args{
r: httptest.NewRequest(http.MethodGet, "/", nil),
},
want: nil,
wantErr: assert.NoError,
},
{
name: "valid required scopes",
args: args{
//nolint: staticcheck
r: httptest.NewRequest(http.MethodGet, "/", nil).WithContext(context.WithValue(t.Context(), openapi.OAuth2Scopes, []string{"user:read", "user:write"})),
},
want: []string{"user:read", "user:write"},
wantErr: assert.NoError,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
t.Parallel()
got, err := requiredScopes(tt.args.r.Context())
if !tt.wantErr(t, err, fmt.Sprintf("requiredScopes(%v)", tt.args.r)) {
return
}
assert.Equalf(t, tt.want, got, "requiredScopes(%v)", tt.args.r)
})
}
}