mirror of
https://github.com/SecurityBrewery/catalyst.git
synced 2025-12-06 07:12:46 +01:00
189 lines
4.3 KiB
Go
189 lines
4.3 KiB
Go
package auth
|
|
|
|
import (
|
|
"bytes"
|
|
"context"
|
|
"fmt"
|
|
"net/http"
|
|
"net/http/httptest"
|
|
"testing"
|
|
|
|
"github.com/stretchr/testify/assert"
|
|
|
|
"github.com/SecurityBrewery/catalyst/app/auth/usercontext"
|
|
"github.com/SecurityBrewery/catalyst/app/openapi"
|
|
)
|
|
|
|
// mockHandler is a stand-in downstream handler that always answers with the
// fixed JSON body {"message":"OK"}. It never calls WriteHeader itself, so the
// status is whatever the ResponseWriter reports by default.
func mockHandler(w http.ResponseWriter, _ *http.Request) {
	payload := []byte(`{"message":"OK"}`)

	// The write result is deliberately discarded: this is a test double.
	_, _ = w.Write(payload)
}
|
|
|
|
// TestService_ValidateScopes checks the ValidateScopesStrict middleware.
// Requests whose permissions do not cover the route's required scopes must
// receive a 401 JSON error. Requests whose permissions do cover them must
// reach the next handler.
func TestService_ValidateScopes(t *testing.T) {
	t.Parallel()

	type args struct {
		requiredScopes []string
		permissions    []string
		next           http.HandlerFunc
	}

	tests := []struct {
		name string
		args args
		want httptest.ResponseRecorder
	}{
		{
			name: "no scopes",
			args: args{
				requiredScopes: []string{"user:read"},
				permissions:    []string{},
				next:           mockHandler,
			},
			want: httptest.ResponseRecorder{
				Code: http.StatusUnauthorized,
				Body: bytes.NewBufferString(`{"error": "Unauthorized", "message": "missing required scopes", "status": 401}`),
			},
		},
		{
			name: "insufficient scopes",
			args: args{
				requiredScopes: []string{"user:write"},
				permissions:    []string{"user:read"},
				next:           mockHandler,
			},
			want: httptest.ResponseRecorder{
				Code: http.StatusUnauthorized,
				Body: bytes.NewBufferString(`{"error": "Unauthorized", "message": "missing required scopes", "status": 401}`),
			},
		},
		{
			name: "sufficient scopes",
			args: args{
				requiredScopes: []string{"user:read"},
				permissions:    []string{"user:read", "user:write"},
				next:           mockHandler,
			},
			want: httptest.ResponseRecorder{
				Code: http.StatusOK,
				Body: bytes.NewBufferString(`{"message":"OK"}`),
			},
		},
	}
	for _, tt := range tests {
		t.Run(tt.name, func(t *testing.T) {
			t.Parallel()

			// Adapt the plain http.HandlerFunc to the strict-handler signature
			// that ValidateScopesStrict wraps.
			handler := ValidateScopesStrict(func(_ context.Context, w http.ResponseWriter, r *http.Request, _ any) (response any, err error) {
				tt.args.next(w, r)

				return w, nil
			}, "")

			w := httptest.NewRecorder()
			r := httptest.NewRequest(http.MethodGet, "/", nil)

			// Put the route's required scopes into the request context.
			// usercontext.PermissionRequest then attaches the caller's
			// permissions; presumably it also stores them in the context
			// (confirm against usercontext).
			//nolint: staticcheck
			r = r.WithContext(context.WithValue(r.Context(), openapi.OAuth2Scopes, tt.args.requiredScopes))
			r = usercontext.PermissionRequest(r, tt.args.permissions)

			// Every expected outcome is expressed through the recorded
			// response. An unexpected handler error must fail the subtest
			// instead of silently skipping the assertions below.
			_, err := handler(r.Context(), w, r, r)
			if !assert.NoError(t, err, "handler should not return an error") {
				return
			}

			assert.Equal(t, tt.want.Code, w.Code, "response code should match expected value")
			assert.JSONEq(t, tt.want.Body.String(), w.Body.String(), "response body should match expected value")
		})
	}
}
|
|
|
|
// Test_hasScope is a table test for hasScope. It checks a caller's scopes
// against a route's required scopes.
func Test_hasScope(t *testing.T) {
	t.Parallel()

	type args struct {
		scopes         []string
		requiredScopes []string
	}

	tests := []struct {
		name string
		args args
		want bool
	}{
		{
			name: "no scopes",
			args: args{
				scopes:         []string{},
				requiredScopes: []string{"user:read"},
			},
			want: false,
		},
		{
			name: "missing required scope",
			args: args{
				scopes:         []string{"user:read"},
				requiredScopes: []string{"user:write"},
			},
			// Stated explicitly, like the other cases, rather than relying
			// on the bool zero value.
			want: false,
		},
		{
			name: "has required scope",
			args: args{
				scopes:         []string{"user:read", "user:write"},
				requiredScopes: []string{"user:read"},
			},
			want: true,
		},
	}
	for _, tt := range tests {
		t.Run(tt.name, func(t *testing.T) {
			t.Parallel()

			assert.Equalf(t, tt.want, hasScope(tt.args.scopes, tt.args.requiredScopes), "hasScope(%v, %v)", tt.args.scopes, tt.args.requiredScopes)
		})
	}
}
|
|
|
|
// Test_requiredScopes checks that requiredScopes extracts the route's OAuth2
// scopes from a request context. A plain request context is expected to yield
// nil without error. A context carrying openapi.OAuth2Scopes is expected to
// yield that slice.
func Test_requiredScopes(t *testing.T) {
	t.Parallel()

	type args struct {
		r *http.Request
	}

	plainRequest := httptest.NewRequest(http.MethodGet, "/", nil)

	//nolint: staticcheck
	scopedCtx := context.WithValue(t.Context(), openapi.OAuth2Scopes, []string{"user:read", "user:write"})
	scopedRequest := httptest.NewRequest(http.MethodGet, "/", nil).WithContext(scopedCtx)

	tests := []struct {
		name    string
		args    args
		want    []string
		wantErr assert.ErrorAssertionFunc
	}{
		{
			name:    "no required scopes",
			args:    args{r: plainRequest},
			want:    nil,
			wantErr: assert.NoError,
		},
		{
			name:    "valid required scopes",
			args:    args{r: scopedRequest},
			want:    []string{"user:read", "user:write"},
			wantErr: assert.NoError,
		},
	}

	for _, tc := range tests {
		t.Run(tc.name, func(t *testing.T) {
			t.Parallel()

			got, err := requiredScopes(tc.args.r.Context())

			// Stop early when the error expectation fails; comparing the
			// result would only add noise.
			if ok := tc.wantErr(t, err, fmt.Sprintf("requiredScopes(%v)", tc.args.r)); !ok {
				return
			}

			assert.Equalf(t, tc.want, got, "requiredScopes(%v)", tc.args.r)
		})
	}
}
|