mirror of
https://github.com/SecurityBrewery/catalyst.git
synced 2025-12-07 07:42:45 +01:00
refactor: remove pocketbase (#1138)
This commit is contained in:
158
app/auth/middleware.go
Normal file
158
app/auth/middleware.go
Normal file
@@ -0,0 +1,158 @@
|
||||
package auth
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"fmt"
|
||||
"log/slog"
|
||||
"net/http"
|
||||
"slices"
|
||||
"strings"
|
||||
|
||||
strictnethttp "github.com/oapi-codegen/runtime/strictmiddleware/nethttp"
|
||||
|
||||
"github.com/SecurityBrewery/catalyst/app/auth/usercontext"
|
||||
"github.com/SecurityBrewery/catalyst/app/database/sqlc"
|
||||
"github.com/SecurityBrewery/catalyst/app/openapi"
|
||||
)
|
||||
|
||||
const bearerPrefix = "Bearer "
|
||||
|
||||
func Middleware(queries *sqlc.Queries) func(next http.Handler) http.Handler {
|
||||
return func(next http.Handler) http.Handler {
|
||||
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
if r.URL.Path == "/api/config" {
|
||||
next.ServeHTTP(w, r)
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
authorizationHeader := r.Header.Get("Authorization")
|
||||
bearerToken := strings.TrimPrefix(authorizationHeader, bearerPrefix)
|
||||
|
||||
user, claims, err := verifyAccessToken(r.Context(), bearerToken, queries)
|
||||
if err != nil {
|
||||
slog.ErrorContext(r.Context(), "invalid bearer token", "error", err)
|
||||
|
||||
unauthorizedJSON(w, "invalid bearer token")
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
scopes, err := scopes(claims)
|
||||
if err != nil {
|
||||
slog.ErrorContext(r.Context(), "failed to get scopes from token", "error", err)
|
||||
|
||||
unauthorizedJSON(w, "failed to get scopes")
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
// Set the user in the context
|
||||
r = usercontext.UserRequest(r, user)
|
||||
r = usercontext.PermissionRequest(r, scopes)
|
||||
|
||||
next.ServeHTTP(w, r)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func ValidateFileScopes(next http.Handler) http.Handler {
|
||||
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
requiredScopes := []string{"file:read"}
|
||||
if slices.Contains([]string{http.MethodPost, http.MethodPatch, http.MethodPut, http.MethodDelete}, r.Method) {
|
||||
requiredScopes = []string{"file:write"}
|
||||
}
|
||||
|
||||
if err := validateScopes(r.Context(), requiredScopes); err != nil {
|
||||
slog.ErrorContext(r.Context(), "failed to validate scopes", "error", err)
|
||||
unauthorizedJSON(w, "missing required scopes")
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
next.ServeHTTP(w, r)
|
||||
})
|
||||
}
|
||||
|
||||
func ValidateScopesStrict(next strictnethttp.StrictHTTPHandlerFunc, _ string) strictnethttp.StrictHTTPHandlerFunc {
|
||||
return func(ctx context.Context, w http.ResponseWriter, r *http.Request, request any) (response any, err error) {
|
||||
requiredScopes, err := requiredScopes(ctx)
|
||||
if err != nil {
|
||||
slog.ErrorContext(ctx, "failed to get required scopes", "error", err)
|
||||
unauthorizedJSON(w, "failed to get required scopes")
|
||||
|
||||
return nil, fmt.Errorf("failed to get required scopes: %w", err)
|
||||
}
|
||||
|
||||
if err := validateScopes(ctx, requiredScopes); err != nil {
|
||||
slog.ErrorContext(ctx, "failed to validate scopes", "error", err)
|
||||
unauthorizedJSON(w, "missing required scopes")
|
||||
|
||||
return nil, fmt.Errorf("missing required scopes: %w", err)
|
||||
}
|
||||
|
||||
return next(ctx, w, r, request)
|
||||
}
|
||||
}
|
||||
|
||||
func LogError(next strictnethttp.StrictHTTPHandlerFunc, _ string) strictnethttp.StrictHTTPHandlerFunc {
|
||||
return func(ctx context.Context, w http.ResponseWriter, r *http.Request, request any) (response any, err error) {
|
||||
re, err := next(ctx, w, r, request)
|
||||
if err != nil {
|
||||
if err.Error() == "context canceled" {
|
||||
// This is a common error when the request is canceled, e.g., by the client.
|
||||
// We can ignore this error as it does not indicate a problem with the handler.
|
||||
return re, nil
|
||||
}
|
||||
|
||||
slog.ErrorContext(ctx, "handler error", "error", err, "method", r.Method, "path", r.URL.Path)
|
||||
}
|
||||
|
||||
return re, err
|
||||
}
|
||||
}
|
||||
|
||||
func validateScopes(ctx context.Context, requiredScopes []string) error {
|
||||
if len(requiredScopes) > 0 {
|
||||
permissions, ok := usercontext.PermissionFromContext(ctx)
|
||||
if !ok {
|
||||
return errors.New("missing permissions")
|
||||
}
|
||||
|
||||
if !hasScope(permissions, requiredScopes) {
|
||||
return fmt.Errorf("missing required scopes: %v", requiredScopes)
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func requiredScopes(ctx context.Context) ([]string, error) {
|
||||
requiredScopesValue := ctx.Value(openapi.OAuth2Scopes)
|
||||
if requiredScopesValue == nil {
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
requiredScopes, ok := requiredScopesValue.([]string)
|
||||
if !ok {
|
||||
return nil, fmt.Errorf("invalid required scopes type: %T", requiredScopesValue)
|
||||
}
|
||||
|
||||
return requiredScopes, nil
|
||||
}
|
||||
|
||||
func hasScope(scopes []string, requiredScopes []string) bool {
|
||||
if slices.Contains(scopes, "admin") {
|
||||
// If the user has admin scope, they can access everything
|
||||
return true
|
||||
}
|
||||
|
||||
for _, s := range requiredScopes {
|
||||
if !slices.Contains(scopes, s) {
|
||||
return false
|
||||
}
|
||||
}
|
||||
|
||||
return true
|
||||
}
|
||||
Reference in New Issue
Block a user