mirror of
https://github.com/SecurityBrewery/catalyst.git
synced 2026-01-12 09:11:23 +01:00
refactor: remove pocketbase (#1138)
This commit is contained in:
16
app/auth/errorjson.go
Normal file
16
app/auth/errorjson.go
Normal file
@@ -0,0 +1,16 @@
|
||||
package auth
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"net/http"
|
||||
)
|
||||
|
||||
func unauthorizedJSON(w http.ResponseWriter, msg string) {
|
||||
errorJSON(w, http.StatusUnauthorized, msg)
|
||||
}
|
||||
|
||||
func errorJSON(w http.ResponseWriter, status int, msg string) {
|
||||
w.WriteHeader(status)
|
||||
w.Header().Set("Content-Type", "application/json; charset=UTF-8")
|
||||
_, _ = fmt.Fprintf(w, `{"status": %d, "error": %q, "message": %q}`, status, http.StatusText(status), msg)
|
||||
}
|
||||
158
app/auth/middleware.go
Normal file
158
app/auth/middleware.go
Normal file
@@ -0,0 +1,158 @@
|
||||
package auth
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"fmt"
|
||||
"log/slog"
|
||||
"net/http"
|
||||
"slices"
|
||||
"strings"
|
||||
|
||||
strictnethttp "github.com/oapi-codegen/runtime/strictmiddleware/nethttp"
|
||||
|
||||
"github.com/SecurityBrewery/catalyst/app/auth/usercontext"
|
||||
"github.com/SecurityBrewery/catalyst/app/database/sqlc"
|
||||
"github.com/SecurityBrewery/catalyst/app/openapi"
|
||||
)
|
||||
|
||||
const bearerPrefix = "Bearer "
|
||||
|
||||
func Middleware(queries *sqlc.Queries) func(next http.Handler) http.Handler {
|
||||
return func(next http.Handler) http.Handler {
|
||||
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
if r.URL.Path == "/api/config" {
|
||||
next.ServeHTTP(w, r)
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
authorizationHeader := r.Header.Get("Authorization")
|
||||
bearerToken := strings.TrimPrefix(authorizationHeader, bearerPrefix)
|
||||
|
||||
user, claims, err := verifyAccessToken(r.Context(), bearerToken, queries)
|
||||
if err != nil {
|
||||
slog.ErrorContext(r.Context(), "invalid bearer token", "error", err)
|
||||
|
||||
unauthorizedJSON(w, "invalid bearer token")
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
scopes, err := scopes(claims)
|
||||
if err != nil {
|
||||
slog.ErrorContext(r.Context(), "failed to get scopes from token", "error", err)
|
||||
|
||||
unauthorizedJSON(w, "failed to get scopes")
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
// Set the user in the context
|
||||
r = usercontext.UserRequest(r, user)
|
||||
r = usercontext.PermissionRequest(r, scopes)
|
||||
|
||||
next.ServeHTTP(w, r)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func ValidateFileScopes(next http.Handler) http.Handler {
|
||||
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
requiredScopes := []string{"file:read"}
|
||||
if slices.Contains([]string{http.MethodPost, http.MethodPatch, http.MethodPut, http.MethodDelete}, r.Method) {
|
||||
requiredScopes = []string{"file:write"}
|
||||
}
|
||||
|
||||
if err := validateScopes(r.Context(), requiredScopes); err != nil {
|
||||
slog.ErrorContext(r.Context(), "failed to validate scopes", "error", err)
|
||||
unauthorizedJSON(w, "missing required scopes")
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
next.ServeHTTP(w, r)
|
||||
})
|
||||
}
|
||||
|
||||
func ValidateScopesStrict(next strictnethttp.StrictHTTPHandlerFunc, _ string) strictnethttp.StrictHTTPHandlerFunc {
|
||||
return func(ctx context.Context, w http.ResponseWriter, r *http.Request, request any) (response any, err error) {
|
||||
requiredScopes, err := requiredScopes(ctx)
|
||||
if err != nil {
|
||||
slog.ErrorContext(ctx, "failed to get required scopes", "error", err)
|
||||
unauthorizedJSON(w, "failed to get required scopes")
|
||||
|
||||
return nil, fmt.Errorf("failed to get required scopes: %w", err)
|
||||
}
|
||||
|
||||
if err := validateScopes(ctx, requiredScopes); err != nil {
|
||||
slog.ErrorContext(ctx, "failed to validate scopes", "error", err)
|
||||
unauthorizedJSON(w, "missing required scopes")
|
||||
|
||||
return nil, fmt.Errorf("missing required scopes: %w", err)
|
||||
}
|
||||
|
||||
return next(ctx, w, r, request)
|
||||
}
|
||||
}
|
||||
|
||||
func LogError(next strictnethttp.StrictHTTPHandlerFunc, _ string) strictnethttp.StrictHTTPHandlerFunc {
|
||||
return func(ctx context.Context, w http.ResponseWriter, r *http.Request, request any) (response any, err error) {
|
||||
re, err := next(ctx, w, r, request)
|
||||
if err != nil {
|
||||
if err.Error() == "context canceled" {
|
||||
// This is a common error when the request is canceled, e.g., by the client.
|
||||
// We can ignore this error as it does not indicate a problem with the handler.
|
||||
return re, nil
|
||||
}
|
||||
|
||||
slog.ErrorContext(ctx, "handler error", "error", err, "method", r.Method, "path", r.URL.Path)
|
||||
}
|
||||
|
||||
return re, err
|
||||
}
|
||||
}
|
||||
|
||||
func validateScopes(ctx context.Context, requiredScopes []string) error {
|
||||
if len(requiredScopes) > 0 {
|
||||
permissions, ok := usercontext.PermissionFromContext(ctx)
|
||||
if !ok {
|
||||
return errors.New("missing permissions")
|
||||
}
|
||||
|
||||
if !hasScope(permissions, requiredScopes) {
|
||||
return fmt.Errorf("missing required scopes: %v", requiredScopes)
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func requiredScopes(ctx context.Context) ([]string, error) {
|
||||
requiredScopesValue := ctx.Value(openapi.OAuth2Scopes)
|
||||
if requiredScopesValue == nil {
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
requiredScopes, ok := requiredScopesValue.([]string)
|
||||
if !ok {
|
||||
return nil, fmt.Errorf("invalid required scopes type: %T", requiredScopesValue)
|
||||
}
|
||||
|
||||
return requiredScopes, nil
|
||||
}
|
||||
|
||||
func hasScope(scopes []string, requiredScopes []string) bool {
|
||||
if slices.Contains(scopes, "admin") {
|
||||
// If the user has admin scope, they can access everything
|
||||
return true
|
||||
}
|
||||
|
||||
for _, s := range requiredScopes {
|
||||
if !slices.Contains(scopes, s) {
|
||||
return false
|
||||
}
|
||||
}
|
||||
|
||||
return true
|
||||
}
|
||||
188
app/auth/middleware_test.go
Normal file
188
app/auth/middleware_test.go
Normal file
@@ -0,0 +1,188 @@
|
||||
package auth
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"context"
|
||||
"fmt"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
|
||||
"github.com/SecurityBrewery/catalyst/app/auth/usercontext"
|
||||
"github.com/SecurityBrewery/catalyst/app/openapi"
|
||||
)
|
||||
|
||||
func mockHandler(w http.ResponseWriter, _ *http.Request) {
|
||||
_, _ = w.Write([]byte(`{"message":"OK"}`))
|
||||
}
|
||||
|
||||
func TestService_ValidateScopes(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
type args struct {
|
||||
requiredScopes []string
|
||||
permissions []string
|
||||
next http.HandlerFunc
|
||||
}
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
args args
|
||||
want httptest.ResponseRecorder
|
||||
}{
|
||||
{
|
||||
name: "no scopes",
|
||||
args: args{
|
||||
requiredScopes: []string{"user:read"},
|
||||
permissions: []string{},
|
||||
next: mockHandler,
|
||||
},
|
||||
want: httptest.ResponseRecorder{
|
||||
Code: http.StatusUnauthorized,
|
||||
Body: bytes.NewBufferString(`{"error": "Unauthorized", "message": "missing required scopes", "status": 401}`),
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "insufficient scopes",
|
||||
args: args{
|
||||
requiredScopes: []string{"user:write"},
|
||||
permissions: []string{"user:read"},
|
||||
next: mockHandler,
|
||||
},
|
||||
want: httptest.ResponseRecorder{
|
||||
Code: http.StatusUnauthorized,
|
||||
Body: bytes.NewBufferString(`{"error": "Unauthorized", "message": "missing required scopes", "status": 401}`),
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "sufficient scopes",
|
||||
args: args{
|
||||
requiredScopes: []string{"user:read"},
|
||||
permissions: []string{"user:read", "user:write"},
|
||||
next: mockHandler,
|
||||
},
|
||||
want: httptest.ResponseRecorder{
|
||||
Code: http.StatusOK,
|
||||
Body: bytes.NewBufferString(`{"message":"OK"}`),
|
||||
},
|
||||
},
|
||||
}
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
handler := ValidateScopesStrict(func(_ context.Context, w http.ResponseWriter, r *http.Request, _ any) (response any, err error) {
|
||||
tt.args.next(w, r)
|
||||
|
||||
return w, nil
|
||||
}, "")
|
||||
|
||||
w := httptest.NewRecorder()
|
||||
r := httptest.NewRequest(http.MethodGet, "/", nil)
|
||||
|
||||
//nolint: staticcheck
|
||||
r = r.WithContext(context.WithValue(r.Context(), openapi.OAuth2Scopes, tt.args.requiredScopes))
|
||||
r = usercontext.PermissionRequest(r, tt.args.permissions)
|
||||
|
||||
if _, err := handler(r.Context(), w, r, r); err != nil {
|
||||
return
|
||||
}
|
||||
|
||||
assert.Equal(t, tt.want.Code, w.Code, "response code should match expected value")
|
||||
assert.JSONEq(t, tt.want.Body.String(), w.Body.String(), "response body should match expected value")
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func Test_hasScope(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
type args struct {
|
||||
scopes []string
|
||||
requiredScopes []string
|
||||
}
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
args args
|
||||
want bool
|
||||
}{
|
||||
{
|
||||
name: "no scopes",
|
||||
args: args{
|
||||
scopes: []string{},
|
||||
requiredScopes: []string{"user:read"},
|
||||
},
|
||||
want: false,
|
||||
},
|
||||
{
|
||||
name: "missing required scope",
|
||||
args: args{
|
||||
scopes: []string{"user:read"},
|
||||
requiredScopes: []string{"user:write"},
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "has required scope",
|
||||
args: args{
|
||||
scopes: []string{"user:read", "user:write"},
|
||||
requiredScopes: []string{"user:read"},
|
||||
},
|
||||
want: true,
|
||||
},
|
||||
}
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
assert.Equalf(t, tt.want, hasScope(tt.args.scopes, tt.args.requiredScopes), "hasScope(%v, %v)", tt.args.scopes, tt.args.requiredScopes)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func Test_requiredScopes(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
type args struct {
|
||||
r *http.Request
|
||||
}
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
args args
|
||||
want []string
|
||||
wantErr assert.ErrorAssertionFunc
|
||||
}{
|
||||
{
|
||||
name: "no required scopes",
|
||||
args: args{
|
||||
r: httptest.NewRequest(http.MethodGet, "/", nil),
|
||||
},
|
||||
want: nil,
|
||||
wantErr: assert.NoError,
|
||||
},
|
||||
{
|
||||
name: "valid required scopes",
|
||||
args: args{
|
||||
//nolint: staticcheck
|
||||
r: httptest.NewRequest(http.MethodGet, "/", nil).WithContext(context.WithValue(t.Context(), openapi.OAuth2Scopes, []string{"user:read", "user:write"})),
|
||||
},
|
||||
want: []string{"user:read", "user:write"},
|
||||
wantErr: assert.NoError,
|
||||
},
|
||||
}
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
got, err := requiredScopes(tt.args.r.Context())
|
||||
if !tt.wantErr(t, err, fmt.Sprintf("requiredScopes(%v)", tt.args.r)) {
|
||||
return
|
||||
}
|
||||
|
||||
assert.Equalf(t, tt.want, got, "requiredScopes(%v)", tt.args.r)
|
||||
})
|
||||
}
|
||||
}
|
||||
32
app/auth/password/password.go
Normal file
32
app/auth/password/password.go
Normal file
@@ -0,0 +1,32 @@
|
||||
package password
|
||||
|
||||
import (
|
||||
"crypto/rand"
|
||||
"encoding/base64"
|
||||
"fmt"
|
||||
|
||||
"golang.org/x/crypto/bcrypt"
|
||||
)
|
||||
|
||||
func Hash(password string) (hashedPassword, tokenKey string, err error) {
|
||||
hashedPasswordB, err := bcrypt.GenerateFromPassword([]byte(password), bcrypt.DefaultCost)
|
||||
if err != nil {
|
||||
return "", "", fmt.Errorf("failed to hash password: %w", err)
|
||||
}
|
||||
|
||||
tokenKey, err = GenerateTokenKey()
|
||||
if err != nil {
|
||||
return "", "", err
|
||||
}
|
||||
|
||||
return string(hashedPasswordB), tokenKey, nil
|
||||
}
|
||||
|
||||
func GenerateTokenKey() (string, error) {
|
||||
b := make([]byte, 32)
|
||||
if _, err := rand.Read(b); err != nil {
|
||||
return "", err
|
||||
}
|
||||
|
||||
return base64.URLEncoding.EncodeToString(b), nil
|
||||
}
|
||||
67
app/auth/password/password_test.go
Normal file
67
app/auth/password/password_test.go
Normal file
@@ -0,0 +1,67 @@
|
||||
package password
|
||||
|
||||
import (
|
||||
"strings"
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
"golang.org/x/crypto/bcrypt"
|
||||
)
|
||||
|
||||
func TestHash(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
type args struct {
|
||||
password string
|
||||
}
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
args args
|
||||
wantErr require.ErrorAssertionFunc
|
||||
}{
|
||||
{
|
||||
name: "Hash valid password",
|
||||
args: args{
|
||||
password: "securePassword123!",
|
||||
},
|
||||
wantErr: require.NoError,
|
||||
},
|
||||
{
|
||||
name: "Long password",
|
||||
args: args{
|
||||
password: strings.Repeat("a", 75),
|
||||
},
|
||||
wantErr: require.Error,
|
||||
},
|
||||
}
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
gotHashedPassword, gotTokenKey, err := Hash(tt.args.password)
|
||||
tt.wantErr(t, err, "Hash() should not return an error")
|
||||
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
|
||||
assert.NotEmpty(t, gotHashedPassword, "Hash() gotHashedPassword should not be empty")
|
||||
assert.NotEmpty(t, gotTokenKey, "Hash() gotTokenKey should not be empty")
|
||||
|
||||
require.NoError(t, bcrypt.CompareHashAndPassword([]byte(gotHashedPassword), []byte(tt.args.password)), "Hash() hashed password does not match original password")
|
||||
|
||||
assert.GreaterOrEqual(t, len(gotTokenKey), 43, "Hash() gotTokenKey should be at least 43 characters long")
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestGenerateTokenKey(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
tokenKey, err := GenerateTokenKey()
|
||||
require.NoError(t, err)
|
||||
assert.NotEmpty(t, tokenKey, "GenerateTokenKey() tokenKey should not be empty")
|
||||
assert.GreaterOrEqual(t, len(tokenKey), 43, "GenerateTokenKey() tokenKey should be at least 43 characters long")
|
||||
}
|
||||
73
app/auth/permission.go
Normal file
73
app/auth/permission.go
Normal file
@@ -0,0 +1,73 @@
|
||||
package auth
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"log/slog"
|
||||
)
|
||||
|
||||
var (
|
||||
TicketReadPermission = "ticket:read"
|
||||
TicketWritePermission = "ticket:write"
|
||||
FileReadPermission = "file:read"
|
||||
FileWritePermission = "file:write"
|
||||
TypeReadPermission = "type:read"
|
||||
TypeWritePermission = "type:write"
|
||||
UserReadPermission = "user:read"
|
||||
UserWritePermission = "user:write"
|
||||
GroupReadPermission = "group:read"
|
||||
GroupWritePermission = "group:write"
|
||||
ReactionReadPermission = "reaction:read"
|
||||
ReactionWritePermission = "reaction:write"
|
||||
WebhookReadPermission = "webhook:read"
|
||||
WebhookWritePermission = "webhook:write"
|
||||
SettingsReadPermission = "settings:read"
|
||||
SettingsWritePermission = "settings:write"
|
||||
)
|
||||
|
||||
func All() []string {
|
||||
return []string{
|
||||
TicketReadPermission,
|
||||
TicketWritePermission,
|
||||
FileReadPermission,
|
||||
FileWritePermission,
|
||||
TypeReadPermission,
|
||||
TypeWritePermission,
|
||||
UserReadPermission,
|
||||
UserWritePermission,
|
||||
GroupReadPermission,
|
||||
GroupWritePermission,
|
||||
ReactionReadPermission,
|
||||
ReactionWritePermission,
|
||||
WebhookReadPermission,
|
||||
WebhookWritePermission,
|
||||
SettingsReadPermission,
|
||||
SettingsWritePermission,
|
||||
}
|
||||
}
|
||||
|
||||
func FromJSONArray(ctx context.Context, permissions string) []string {
|
||||
var result []string
|
||||
if err := json.Unmarshal([]byte(permissions), &result); err != nil {
|
||||
slog.ErrorContext(ctx, "Failed to unmarshal permissions", "error", err)
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
return result
|
||||
}
|
||||
|
||||
func ToJSONArray(ctx context.Context, permissions []string) string {
|
||||
if len(permissions) == 0 {
|
||||
return "[]"
|
||||
}
|
||||
|
||||
data, err := json.Marshal(permissions)
|
||||
if err != nil {
|
||||
slog.ErrorContext(ctx, "Failed to marshal permissions", "error", err)
|
||||
|
||||
return "[]"
|
||||
}
|
||||
|
||||
return string(data)
|
||||
}
|
||||
84
app/auth/permission_test.go
Normal file
84
app/auth/permission_test.go
Normal file
@@ -0,0 +1,84 @@
|
||||
package auth
|
||||
|
||||
import (
|
||||
"reflect"
|
||||
"testing"
|
||||
)
|
||||
|
||||
func TestFromJSONArray(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
input string
|
||||
want []string
|
||||
shouldError bool
|
||||
}{
|
||||
{
|
||||
name: "Valid JSON array",
|
||||
input: `["ticket:read", "ticket:write"]`,
|
||||
want: []string{"ticket:read", "ticket:write"},
|
||||
shouldError: false,
|
||||
},
|
||||
{
|
||||
name: "Empty array",
|
||||
input: "[]",
|
||||
want: []string{},
|
||||
shouldError: false,
|
||||
},
|
||||
{
|
||||
name: "Invalid JSON",
|
||||
input: "not json",
|
||||
want: nil,
|
||||
shouldError: true,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
got := FromJSONArray(t.Context(), tt.input)
|
||||
if !reflect.DeepEqual(got, tt.want) {
|
||||
t.Errorf("FromJSONArray() = %v, want %v", got, tt.want)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestToJSONArray(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
input []string
|
||||
want string
|
||||
}{
|
||||
{
|
||||
name: "Valid permissions array",
|
||||
input: []string{"ticket:read", "ticket:write"},
|
||||
want: `["ticket:read","ticket:write"]`,
|
||||
},
|
||||
{
|
||||
name: "Empty array",
|
||||
input: []string{},
|
||||
want: "[]",
|
||||
},
|
||||
{
|
||||
name: "Nil array",
|
||||
input: nil,
|
||||
want: "[]",
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
got := ToJSONArray(t.Context(), tt.input)
|
||||
if got != tt.want {
|
||||
t.Errorf("ToJSONArray() = %v, want %v", got, tt.want)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
178
app/auth/resetpassword.go
Normal file
178
app/auth/resetpassword.go
Normal file
@@ -0,0 +1,178 @@
|
||||
package auth
|
||||
|
||||
import (
|
||||
"database/sql"
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"net/http"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/SecurityBrewery/catalyst/app/auth/password"
|
||||
"github.com/SecurityBrewery/catalyst/app/database/sqlc"
|
||||
"github.com/SecurityBrewery/catalyst/app/mail"
|
||||
"github.com/SecurityBrewery/catalyst/app/settings"
|
||||
)
|
||||
|
||||
func handleResetPasswordMail(queries *sqlc.Queries, mailer *mail.Mailer) func(w http.ResponseWriter, r *http.Request) {
|
||||
return func(w http.ResponseWriter, r *http.Request) {
|
||||
type passwordResetData struct {
|
||||
Email string `json:"email"`
|
||||
}
|
||||
|
||||
b, err := json.Marshal(map[string]any{
|
||||
"message": "Password reset email sent when the user exists",
|
||||
})
|
||||
if err != nil {
|
||||
errorJSON(w, http.StatusInternalServerError, "Failed to create response: "+err.Error())
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
var data passwordResetData
|
||||
if err := json.NewDecoder(r.Body).Decode(&data); err != nil {
|
||||
errorJSON(w, http.StatusBadRequest, "Invalid request, missing email field")
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
user, err := queries.UserByEmail(r.Context(), &data.Email)
|
||||
if err != nil {
|
||||
if errors.Is(err, sql.ErrNoRows) {
|
||||
// Do not reveal whether the user exists or not
|
||||
w.WriteHeader(http.StatusOK)
|
||||
_, _ = w.Write(b)
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
errorJSON(w, http.StatusInternalServerError, "Failed to get user: "+err.Error())
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
settings, err := settings.Load(r.Context(), queries)
|
||||
if err != nil {
|
||||
errorJSON(w, http.StatusInternalServerError, "Failed to load settings: "+err.Error())
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
resetToken, err := createResetToken(&user, settings)
|
||||
if err != nil {
|
||||
errorJSON(w, http.StatusInternalServerError, "Failed to create reset token: "+err.Error())
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
link := settings.Meta.AppURL + "/ui/password-reset?mail=" + data.Email + "&token=" + resetToken
|
||||
|
||||
subject := settings.Meta.ResetPasswordTemplate.Subject
|
||||
subject = strings.ReplaceAll(subject, "{APP_NAME}", settings.Meta.AppName)
|
||||
|
||||
plainTextBody := `Hello,
|
||||
Thank you for joining us at {APP_NAME}.
|
||||
Click on the link below to verify your email address or copy the token into the app:
|
||||
|
||||
{ACTION_URL}
|
||||
|
||||
Thanks, {APP_NAME} team`
|
||||
plainTextBody = strings.ReplaceAll(plainTextBody, "{ACTION_URL}", link)
|
||||
plainTextBody = strings.ReplaceAll(plainTextBody, "{APP_NAME}", settings.Meta.AppName)
|
||||
|
||||
htmlBody := settings.Meta.ResetPasswordTemplate.Body
|
||||
htmlBody = strings.ReplaceAll(htmlBody, "{ACTION_URL}", link)
|
||||
htmlBody = strings.ReplaceAll(htmlBody, "{APP_NAME}", settings.Meta.AppName)
|
||||
|
||||
if err := mailer.Send(r.Context(), data.Email, subject, plainTextBody, htmlBody); err != nil {
|
||||
errorJSON(w, http.StatusInternalServerError, "Failed to send password reset email: "+err.Error())
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
w.WriteHeader(http.StatusOK)
|
||||
_, _ = w.Write(b)
|
||||
}
|
||||
}
|
||||
|
||||
func handlePassword(queries *sqlc.Queries) func(w http.ResponseWriter, r *http.Request) {
|
||||
return func(w http.ResponseWriter, r *http.Request) {
|
||||
type passwordResetData struct {
|
||||
Token string `json:"token"`
|
||||
Email string `json:"email"`
|
||||
Password string `json:"password"`
|
||||
PasswordConfirm string `json:"password_confirm"`
|
||||
}
|
||||
|
||||
var data passwordResetData
|
||||
if err := json.NewDecoder(r.Body).Decode(&data); err != nil {
|
||||
errorJSON(w, http.StatusBadRequest, "Invalid request, missing email or password fields")
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
if data.Password != data.PasswordConfirm {
|
||||
errorJSON(w, http.StatusBadRequest, "Passwords do not match")
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
user, err := queries.UserByEmail(r.Context(), &data.Email)
|
||||
if err != nil {
|
||||
if errors.Is(err, sql.ErrNoRows) {
|
||||
errorJSON(w, http.StatusBadRequest, "Invalid or expired reset token")
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
errorJSON(w, http.StatusInternalServerError, "Failed to get user: "+err.Error())
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
settings, err := settings.Load(r.Context(), queries)
|
||||
if err != nil {
|
||||
errorJSON(w, http.StatusInternalServerError, "Failed to load settings: "+err.Error())
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
if err := verifyResetToken(data.Token, &user, settings.Meta.AppURL, settings.RecordPasswordResetToken.Secret); err != nil {
|
||||
errorJSON(w, http.StatusBadRequest, "Invalid or expired reset token: "+err.Error())
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
passwordHash, tokenKey, err := password.Hash(data.Password)
|
||||
if err != nil {
|
||||
errorJSON(w, http.StatusInternalServerError, "Failed to hash password: "+err.Error())
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
now := time.Now().UTC()
|
||||
|
||||
if _, err := queries.UpdateUser(r.Context(), sqlc.UpdateUserParams{
|
||||
ID: user.ID,
|
||||
PasswordHash: &passwordHash,
|
||||
TokenKey: &tokenKey,
|
||||
LastResetSentAt: &now,
|
||||
}); err != nil {
|
||||
errorJSON(w, http.StatusInternalServerError, "Failed to update password: "+err.Error())
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
b, err := json.Marshal(map[string]any{
|
||||
"message": "Password reset successfully",
|
||||
})
|
||||
if err != nil {
|
||||
errorJSON(w, http.StatusInternalServerError, "Failed to create response: "+err.Error())
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
w.WriteHeader(http.StatusOK)
|
||||
_, _ = w.Write(b)
|
||||
}
|
||||
}
|
||||
94
app/auth/resettoken_test.go
Normal file
94
app/auth/resettoken_test.go
Normal file
@@ -0,0 +1,94 @@
|
||||
package auth
|
||||
|
||||
import (
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
|
||||
"github.com/SecurityBrewery/catalyst/app/database/sqlc"
|
||||
)
|
||||
|
||||
func TestService_createResetToken(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
type args struct {
|
||||
createUser *sqlc.User
|
||||
tokenDuration time.Duration
|
||||
waitDuration time.Duration
|
||||
verifyUser *sqlc.User
|
||||
}
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
args args
|
||||
wantErr assert.ErrorAssertionFunc
|
||||
}{
|
||||
{
|
||||
name: "valid token",
|
||||
args: args{
|
||||
createUser: &sqlc.User{ID: "testuser", Tokenkey: "testtoken"},
|
||||
tokenDuration: time.Hour,
|
||||
waitDuration: 0,
|
||||
verifyUser: &sqlc.User{
|
||||
ID: "testuser",
|
||||
Tokenkey: "testtoken",
|
||||
Updated: mustParse(t, "2006-01-02 15:04:05Z", "2025-06-02 19:18:06.292Z"),
|
||||
},
|
||||
},
|
||||
wantErr: assert.NoError,
|
||||
},
|
||||
{
|
||||
name: "expired token",
|
||||
args: args{
|
||||
createUser: &sqlc.User{ID: "testuser", Tokenkey: "testtoken"},
|
||||
tokenDuration: 0,
|
||||
waitDuration: time.Second,
|
||||
verifyUser: &sqlc.User{
|
||||
ID: "testuser",
|
||||
Tokenkey: "testtoken",
|
||||
Updated: mustParse(t, "2006-01-02 15:04:05Z", "2025-06-02 19:18:06.292Z"),
|
||||
},
|
||||
},
|
||||
wantErr: assert.Error,
|
||||
},
|
||||
{
|
||||
name: "invalid token",
|
||||
args: args{
|
||||
createUser: &sqlc.User{ID: "testuser", Tokenkey: "testtoken"},
|
||||
tokenDuration: time.Hour,
|
||||
waitDuration: 0,
|
||||
verifyUser: &sqlc.User{
|
||||
ID: "invaliduser",
|
||||
Tokenkey: "invalidtoken",
|
||||
Updated: mustParse(t, "2006-01-02 15:04:05Z", "2025-06-02 19:18:06.292Z"),
|
||||
},
|
||||
},
|
||||
wantErr: assert.Error,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
got, err := createResetTokenWithDuration(tt.args.createUser, "", "", tt.args.tokenDuration)
|
||||
require.NoError(t, err, "createResetToken()")
|
||||
|
||||
time.Sleep(tt.args.waitDuration)
|
||||
|
||||
err = verifyResetToken(got, tt.args.verifyUser, "", "")
|
||||
tt.wantErr(t, err, "verifyResetToken()")
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func mustParse(t *testing.T, layout, value string) time.Time {
|
||||
t.Helper()
|
||||
|
||||
parsed, err := time.Parse(layout, value)
|
||||
require.NoError(t, err, "mustParse()")
|
||||
|
||||
return parsed
|
||||
}
|
||||
58
app/auth/server.go
Normal file
58
app/auth/server.go
Normal file
@@ -0,0 +1,58 @@
|
||||
package auth
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"net/http"
|
||||
"strings"
|
||||
|
||||
"github.com/go-chi/chi/v5"
|
||||
|
||||
"github.com/SecurityBrewery/catalyst/app/database/sqlc"
|
||||
"github.com/SecurityBrewery/catalyst/app/mail"
|
||||
)
|
||||
|
||||
func Server(queries *sqlc.Queries, mailer *mail.Mailer) http.Handler {
|
||||
router := chi.NewRouter()
|
||||
|
||||
router.Get("/user", handleUser(queries))
|
||||
router.Post("/local/login", handleLogin(queries))
|
||||
router.Post("/local/reset-password-mail", handleResetPasswordMail(queries, mailer))
|
||||
router.Post("/local/reset-password", handlePassword(queries))
|
||||
|
||||
return router
|
||||
}
|
||||
|
||||
func handleUser(queries *sqlc.Queries) func(w http.ResponseWriter, r *http.Request) {
|
||||
return func(w http.ResponseWriter, r *http.Request) {
|
||||
authorizationHeader := r.Header.Get("Authorization")
|
||||
bearerToken := strings.TrimPrefix(authorizationHeader, bearerPrefix)
|
||||
|
||||
user, _, err := verifyAccessToken(r.Context(), bearerToken, queries)
|
||||
if err != nil {
|
||||
_, _ = w.Write([]byte("null"))
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
permissions, err := queries.ListUserPermissions(r.Context(), user.ID)
|
||||
if err != nil {
|
||||
errorJSON(w, http.StatusInternalServerError, err.Error())
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
b, err := json.Marshal(map[string]any{
|
||||
"user": user,
|
||||
"permissions": permissions,
|
||||
})
|
||||
if err != nil {
|
||||
errorJSON(w, http.StatusInternalServerError, err.Error())
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
r.Header.Set("Content-Type", "application/json")
|
||||
|
||||
_, _ = w.Write(b)
|
||||
}
|
||||
}
|
||||
99
app/auth/server_local.go
Normal file
99
app/auth/server_local.go
Normal file
@@ -0,0 +1,99 @@
|
||||
package auth
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"fmt"
|
||||
"net/http"
|
||||
"time"
|
||||
|
||||
"golang.org/x/crypto/bcrypt"
|
||||
|
||||
"github.com/SecurityBrewery/catalyst/app/database/sqlc"
|
||||
"github.com/SecurityBrewery/catalyst/app/settings"
|
||||
)
|
||||
|
||||
var ErrUserInactive = errors.New("user is inactive")
|
||||
|
||||
func handleLogin(queries *sqlc.Queries) func(w http.ResponseWriter, r *http.Request) {
|
||||
return func(w http.ResponseWriter, r *http.Request) {
|
||||
type loginData struct {
|
||||
Email string `json:"email"`
|
||||
Password string `json:"password"`
|
||||
}
|
||||
|
||||
var data loginData
|
||||
if err := json.NewDecoder(r.Body).Decode(&data); err != nil {
|
||||
unauthorizedJSON(w, "Invalid request")
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
user, err := loginWithMail(r.Context(), data.Email, data.Password, queries)
|
||||
if err != nil {
|
||||
if errors.Is(err, ErrUserInactive) {
|
||||
unauthorizedJSON(w, "User is inactive")
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
unauthorizedJSON(w, "Login failed")
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
permissions, err := queries.ListUserPermissions(r.Context(), user.ID)
|
||||
if err != nil {
|
||||
errorJSON(w, http.StatusInternalServerError, "Failed to get user permissions")
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
settings, err := settings.Load(r.Context(), queries)
|
||||
if err != nil {
|
||||
errorJSON(w, http.StatusInternalServerError, "Failed to load settings")
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
duration := time.Duration(settings.RecordAuthToken.Duration) * time.Second
|
||||
|
||||
token, err := CreateAccessToken(r.Context(), user, permissions, duration, queries)
|
||||
if err != nil {
|
||||
errorJSON(w, http.StatusInternalServerError, "Failed to create login token")
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
w.WriteHeader(http.StatusOK)
|
||||
|
||||
response := map[string]string{
|
||||
"token": token,
|
||||
}
|
||||
|
||||
if err := json.NewEncoder(w).Encode(response); err != nil {
|
||||
errorJSON(w, http.StatusInternalServerError, "Failed to encode response")
|
||||
|
||||
return
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func loginWithMail(ctx context.Context, mail, password string, queries *sqlc.Queries) (*sqlc.User, error) {
|
||||
user, err := queries.UserByEmail(ctx, &mail)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to find user by email %q: %w", mail, err)
|
||||
}
|
||||
|
||||
if !user.Active {
|
||||
return nil, ErrUserInactive
|
||||
}
|
||||
|
||||
if err := bcrypt.CompareHashAndPassword([]byte(user.Passwordhash), []byte(password)); err != nil {
|
||||
return nil, fmt.Errorf("invalid credentials: %w", err)
|
||||
}
|
||||
|
||||
return &user, nil
|
||||
}
|
||||
225
app/auth/token.go
Normal file
225
app/auth/token.go
Normal file
@@ -0,0 +1,225 @@
|
||||
package auth
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"time"
|
||||
|
||||
"github.com/golang-jwt/jwt/v5"
|
||||
|
||||
"github.com/SecurityBrewery/catalyst/app/database/sqlc"
|
||||
"github.com/SecurityBrewery/catalyst/app/settings"
|
||||
)
|
||||
|
||||
const (
|
||||
purposeAccess = "access"
|
||||
purposeReset = "reset"
|
||||
scopeReset = "reset"
|
||||
)
|
||||
|
||||
func CreateAccessToken(ctx context.Context, user *sqlc.User, permissions []string, duration time.Duration, queries *sqlc.Queries) (string, error) {
|
||||
settings, err := settings.Load(ctx, queries)
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("failed to load settings: %w", err)
|
||||
}
|
||||
|
||||
return createToken(user, duration, purposeAccess, permissions, settings.Meta.AppURL, settings.RecordAuthToken.Secret)
|
||||
}
|
||||
|
||||
func createResetToken(user *sqlc.User, settings *settings.Settings) (string, error) {
|
||||
duration := time.Duration(settings.RecordPasswordResetToken.Duration) * time.Second
|
||||
|
||||
return createResetTokenWithDuration(user, settings.Meta.AppURL, settings.RecordPasswordResetToken.Secret, duration)
|
||||
}
|
||||
|
||||
func createResetTokenWithDuration(user *sqlc.User, url, appToken string, duration time.Duration) (string, error) {
|
||||
return createToken(user, duration, purposeReset, []string{scopeReset}, url, appToken)
|
||||
}
|
||||
|
||||
func createToken(user *sqlc.User, duration time.Duration, purpose string, scopes []string, url, appToken string) (string, error) {
|
||||
if scopes == nil {
|
||||
scopes = []string{}
|
||||
}
|
||||
|
||||
claims := jwt.MapClaims{
|
||||
"sub": user.ID,
|
||||
"exp": time.Now().Add(duration).Unix(),
|
||||
"iat": time.Now().Unix(),
|
||||
"iss": url,
|
||||
"purpose": purpose,
|
||||
"scopes": scopes,
|
||||
}
|
||||
|
||||
token := jwt.NewWithClaims(jwt.SigningMethodHS256, claims)
|
||||
|
||||
signingKey := user.Tokenkey + appToken
|
||||
|
||||
return token.SignedString([]byte(signingKey))
|
||||
}
|
||||
|
||||
func verifyToken(tokenStr string, user *sqlc.User, url, appToken string) (jwt.MapClaims, error) { //nolint:cyclop
|
||||
signingKey := user.Tokenkey + appToken
|
||||
|
||||
token, err := jwt.Parse(tokenStr, func(t *jwt.Token) (any, error) {
|
||||
if _, ok := t.Method.(*jwt.SigningMethodHMAC); !ok {
|
||||
return nil, fmt.Errorf("unexpected algorithm: %v", t.Header["alg"])
|
||||
}
|
||||
|
||||
return []byte(signingKey), nil
|
||||
})
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to verify token: %w", err)
|
||||
}
|
||||
|
||||
if !token.Valid {
|
||||
return nil, fmt.Errorf("token invalid")
|
||||
}
|
||||
|
||||
claims, ok := token.Claims.(jwt.MapClaims)
|
||||
if !ok {
|
||||
return nil, fmt.Errorf("invalid token claims")
|
||||
}
|
||||
|
||||
iss, err := claims.GetIssuer()
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to get issuer: %w", err)
|
||||
}
|
||||
|
||||
if iss != url {
|
||||
return nil, fmt.Errorf("token issued by a different server")
|
||||
}
|
||||
|
||||
sub, err := claims.GetSubject()
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to get subject: %w", err)
|
||||
}
|
||||
|
||||
if sub != user.ID {
|
||||
return nil, fmt.Errorf("token belongs to a different user")
|
||||
}
|
||||
|
||||
iat, err := claims.GetExpirationTime()
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to get expiration time: %w", err)
|
||||
}
|
||||
|
||||
if iat.Before(time.Now()) {
|
||||
return nil, fmt.Errorf("token expired")
|
||||
}
|
||||
|
||||
return claims, nil
|
||||
}
|
||||
|
||||
func verifyAccessToken(ctx context.Context, bearerToken string, queries *sqlc.Queries) (*sqlc.User, jwt.MapClaims, error) {
|
||||
token, _, err := jwt.NewParser().ParseUnverified(bearerToken, jwt.MapClaims{})
|
||||
if err != nil {
|
||||
return nil, nil, fmt.Errorf("failed to parse token: %w", err)
|
||||
}
|
||||
|
||||
claims, ok := token.Claims.(jwt.MapClaims)
|
||||
if !ok {
|
||||
return nil, nil, fmt.Errorf("failed to parse token claims")
|
||||
}
|
||||
|
||||
sub, err := claims.GetSubject()
|
||||
if err != nil {
|
||||
return nil, nil, fmt.Errorf("token invalid: %w", err)
|
||||
}
|
||||
|
||||
user, err := queries.GetUser(ctx, sub)
|
||||
if err != nil {
|
||||
return nil, nil, fmt.Errorf("failed to retrieve user for subject %s: %w", sub, err)
|
||||
}
|
||||
|
||||
settings, err := settings.Load(ctx, queries)
|
||||
if err != nil {
|
||||
return nil, nil, fmt.Errorf("failed to load settings: %w", err)
|
||||
}
|
||||
|
||||
claims, err = verifyToken(bearerToken, &user, settings.Meta.AppURL, settings.RecordAuthToken.Secret)
|
||||
if err != nil {
|
||||
return nil, nil, fmt.Errorf("failed to verify token: %w", err)
|
||||
}
|
||||
|
||||
if err := hasPurpose(claims, purposeAccess); err != nil {
|
||||
return nil, nil, fmt.Errorf("failed to check scopes: %w", err)
|
||||
}
|
||||
|
||||
return &user, claims, nil
|
||||
}
|
||||
|
||||
func verifyResetToken(tokenStr string, user *sqlc.User, url, appToken string) error {
|
||||
claims, err := verifyToken(tokenStr, user, url, appToken)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
iat, err := claims.GetIssuedAt()
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to get issued at: %w", err)
|
||||
}
|
||||
|
||||
lastUpdated := user.Updated // TODO: create a last reset at column
|
||||
|
||||
if iat.Before(lastUpdated) {
|
||||
return fmt.Errorf("token already used")
|
||||
}
|
||||
|
||||
if err := hasPurpose(claims, purposeReset); err != nil {
|
||||
return fmt.Errorf("failed to check scopes: %w", err)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func hasPurpose(claim jwt.MapClaims, expectedPurpose string) error {
|
||||
purpose, err := purpose(claim)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to get purposes: %w", err)
|
||||
}
|
||||
|
||||
if purpose != expectedPurpose {
|
||||
return fmt.Errorf("token has wrong purpose: %s, expected: %s", purpose, expectedPurpose)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func purpose(claim jwt.MapClaims) (string, error) {
|
||||
purposeClaim, ok := claim["purpose"]
|
||||
if !ok {
|
||||
return "", fmt.Errorf("no purpose found")
|
||||
}
|
||||
|
||||
purpose, ok := purposeClaim.(string)
|
||||
if !ok {
|
||||
return "", fmt.Errorf("invalid purpose type")
|
||||
}
|
||||
|
||||
return purpose, nil
|
||||
}
|
||||
|
||||
func scopes(claim jwt.MapClaims) ([]string, error) {
|
||||
scopesClaim, ok := claim["scopes"]
|
||||
if !ok {
|
||||
return nil, fmt.Errorf("no scopes found")
|
||||
}
|
||||
|
||||
scopesSlice, ok := scopesClaim.([]any)
|
||||
if !ok {
|
||||
return nil, fmt.Errorf("invalid scopes claim type: %T", scopesClaim)
|
||||
}
|
||||
|
||||
scopes := make([]string, 0, len(scopesSlice))
|
||||
|
||||
for _, scope := range scopesSlice {
|
||||
scopeStr, ok := scope.(string)
|
||||
if !ok {
|
||||
return nil, fmt.Errorf("invalid scope claim element type: %T", scope)
|
||||
}
|
||||
|
||||
scopes = append(scopes, scopeStr)
|
||||
}
|
||||
|
||||
return scopes, nil
|
||||
}
|
||||
46
app/auth/usercontext/usercontext.go
Normal file
46
app/auth/usercontext/usercontext.go
Normal file
@@ -0,0 +1,46 @@
|
||||
package usercontext
|
||||
|
||||
import (
|
||||
"context"
|
||||
"net/http"
|
||||
|
||||
"github.com/SecurityBrewery/catalyst/app/database/sqlc"
|
||||
)
|
||||
|
||||
type userKey struct{}
|
||||
|
||||
func UserRequest(r *http.Request, user *sqlc.User) *http.Request {
|
||||
return r.WithContext(UserContext(r.Context(), user))
|
||||
}
|
||||
|
||||
func UserContext(ctx context.Context, user *sqlc.User) context.Context {
|
||||
return context.WithValue(ctx, userKey{}, user)
|
||||
}
|
||||
|
||||
func UserFromContext(ctx context.Context) (*sqlc.User, bool) {
|
||||
user, ok := ctx.Value(userKey{}).(*sqlc.User)
|
||||
if !ok {
|
||||
return nil, false
|
||||
}
|
||||
|
||||
return user, true
|
||||
}
|
||||
|
||||
type permissionKey struct{}
|
||||
|
||||
func PermissionRequest(r *http.Request, permissions []string) *http.Request {
|
||||
return r.WithContext(PermissionContext(r.Context(), permissions))
|
||||
}
|
||||
|
||||
func PermissionContext(ctx context.Context, permissions []string) context.Context {
|
||||
return context.WithValue(ctx, permissionKey{}, permissions)
|
||||
}
|
||||
|
||||
func PermissionFromContext(ctx context.Context) ([]string, bool) {
|
||||
permissions, ok := ctx.Value(permissionKey{}).([]string)
|
||||
if !ok {
|
||||
return nil, false
|
||||
}
|
||||
|
||||
return permissions, true
|
||||
}
|
||||
116
app/auth/usercontext/usercontext_test.go
Normal file
116
app/auth/usercontext/usercontext_test.go
Normal file
@@ -0,0 +1,116 @@
|
||||
package usercontext
|
||||
|
||||
import (
|
||||
"net/http"
|
||||
"reflect"
|
||||
"testing"
|
||||
|
||||
"github.com/SecurityBrewery/catalyst/app/database/sqlc"
|
||||
)
|
||||
|
||||
func TestPermissionContext(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
user *sqlc.User
|
||||
permissions []string
|
||||
wantPerms []string
|
||||
wantOk bool
|
||||
}{
|
||||
{
|
||||
name: "Set and get permissions",
|
||||
permissions: []string{"ticket:read", "ticket:write"},
|
||||
wantPerms: []string{"ticket:read", "ticket:write"},
|
||||
wantOk: true,
|
||||
},
|
||||
{
|
||||
name: "No permissions set",
|
||||
wantPerms: nil,
|
||||
wantOk: true,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
// Test context functions
|
||||
ctx := PermissionContext(t.Context(), tt.permissions)
|
||||
gotPerms, gotOk := PermissionFromContext(ctx)
|
||||
|
||||
if !reflect.DeepEqual(gotPerms, tt.wantPerms) {
|
||||
t.Errorf("PermissionFromContext() got perms = %v, want %v", gotPerms, tt.wantPerms)
|
||||
}
|
||||
|
||||
if gotOk != tt.wantOk {
|
||||
t.Errorf("PermissionFromContext() got ok = %v, want %v", gotOk, tt.wantOk)
|
||||
}
|
||||
|
||||
// Test request functions
|
||||
req := &http.Request{}
|
||||
req = PermissionRequest(req, tt.permissions)
|
||||
gotPerms, gotOk = PermissionFromContext(req.Context())
|
||||
|
||||
if !reflect.DeepEqual(gotPerms, tt.wantPerms) {
|
||||
t.Errorf("PermissionFromContext() got perms = %v, want %v", gotPerms, tt.wantPerms)
|
||||
}
|
||||
|
||||
if gotOk != tt.wantOk {
|
||||
t.Errorf("PermissionFromContext() got ok = %v, want %v", gotOk, tt.wantOk)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestUserContext(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
user *sqlc.User
|
||||
wantOk bool
|
||||
}{
|
||||
{
|
||||
name: "Set and get user",
|
||||
user: &sqlc.User{ID: "test-user"},
|
||||
wantOk: true,
|
||||
},
|
||||
{
|
||||
name: "No user set",
|
||||
user: nil,
|
||||
wantOk: true,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
// Test context functions
|
||||
ctx := UserContext(t.Context(), tt.user)
|
||||
gotUser, gotOk := UserFromContext(ctx)
|
||||
|
||||
if !reflect.DeepEqual(gotUser, tt.user) {
|
||||
t.Errorf("UserFromContext() got user = %v, want %v", gotUser, tt.user)
|
||||
}
|
||||
|
||||
if gotOk != tt.wantOk {
|
||||
t.Errorf("UserFromContext() got ok = %v, want %v", gotOk, tt.wantOk)
|
||||
}
|
||||
|
||||
// Test request functions
|
||||
req := &http.Request{}
|
||||
req = UserRequest(req, tt.user)
|
||||
gotUser, gotOk = UserFromContext(req.Context())
|
||||
|
||||
if !reflect.DeepEqual(gotUser, tt.user) {
|
||||
t.Errorf("UserFromContext() got user = %v, want %v", gotUser, tt.user)
|
||||
}
|
||||
|
||||
if gotOk != tt.wantOk {
|
||||
t.Errorf("UserFromContext() got ok = %v, want %v", gotOk, tt.wantOk)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
Reference in New Issue
Block a user