mirror of
https://github.com/SecurityBrewery/catalyst.git
synced 2025-12-07 07:42:45 +01:00
48
auth.go
48
auth.go
@@ -2,15 +2,16 @@ package catalyst
|
||||
|
||||
import (
|
||||
"context"
|
||||
"crypto/rand"
|
||||
"crypto/sha256"
|
||||
"encoding/base64"
|
||||
"errors"
|
||||
"fmt"
|
||||
"math/rand"
|
||||
"net/http"
|
||||
"strings"
|
||||
|
||||
"github.com/coreos/go-oidc/v3/oidc"
|
||||
"golang.org/x/exp/slices"
|
||||
"golang.org/x/oauth2"
|
||||
|
||||
"github.com/SecurityBrewery/catalyst/database"
|
||||
@@ -43,6 +44,7 @@ func (c *AuthConfig) Verifier(ctx context.Context) (*oidc.IDTokenVerifier, error
|
||||
return nil, err
|
||||
}
|
||||
}
|
||||
|
||||
return c.provider.Verifier(&oidc.Config{SkipClientIDCheck: true}), nil
|
||||
}
|
||||
|
||||
@@ -81,12 +83,14 @@ func bearerAuth(db *database.Database, authHeader string, iss string, config *Au
|
||||
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
if !strings.HasPrefix(authHeader, "Bearer ") {
|
||||
api.JSONErrorStatus(w, http.StatusUnauthorized, errors.New("no bearer token"))
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
claims, apiError := verifyClaims(r, config, authHeader[7:])
|
||||
if apiError != nil {
|
||||
api.JSONErrorStatus(w, apiError.Status, apiError.Internal)
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
@@ -100,6 +104,7 @@ func bearerAuth(db *database.Database, authHeader string, iss string, config *Au
|
||||
r, err := setContextClaims(r, db, claims, config)
|
||||
if err != nil {
|
||||
api.JSONErrorStatus(w, http.StatusInternalServerError, fmt.Errorf("could not load user: %w", err))
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
@@ -116,6 +121,7 @@ func keyAuth(db *database.Database, keyHeader string) func(next http.Handler) ht
|
||||
key, err := db.UserByHash(r.Context(), h)
|
||||
if err != nil {
|
||||
api.JSONErrorStatus(w, http.StatusInternalServerError, fmt.Errorf("could not verify private token: %w", err))
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
@@ -132,16 +138,19 @@ func sessionAuth(db *database.Database, config *AuthConfig) func(next http.Handl
|
||||
claims, noCookie, err := claimsCookie(r)
|
||||
if err != nil {
|
||||
api.JSONError(w, err)
|
||||
|
||||
return
|
||||
}
|
||||
if noCookie {
|
||||
redirectToLogin(w, r, config.OAuth2)
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
r, err = setContextClaims(r, db, claims, config)
|
||||
if err != nil {
|
||||
api.JSONErrorStatus(w, http.StatusInternalServerError, fmt.Errorf("could not load user: %w", err))
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
@@ -150,7 +159,7 @@ func sessionAuth(db *database.Database, config *AuthConfig) func(next http.Handl
|
||||
}
|
||||
}
|
||||
|
||||
func setContextClaims(r *http.Request, db *database.Database, claims map[string]interface{}, config *AuthConfig) (*http.Request, error) {
|
||||
func setContextClaims(r *http.Request, db *database.Database, claims map[string]any, config *AuthConfig) (*http.Request, error) {
|
||||
newUser, newSetting, err := mapUserAndSettings(claims, config)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
@@ -182,7 +191,7 @@ func setContextUser(r *http.Request, user *model.UserResponse, hooks *hooks.Hook
|
||||
return busdb.SetContext(r, user)
|
||||
}
|
||||
|
||||
func mapUserAndSettings(claims map[string]interface{}, config *AuthConfig) (*model.UserForm, *model.UserData, error) {
|
||||
func mapUserAndSettings(claims map[string]any, config *AuthConfig) (*model.UserForm, *model.UserData, error) {
|
||||
// handle Bearer tokens
|
||||
// if typ, ok := claims["typ"]; ok && typ == "Bearer" {
|
||||
// return &model.User{
|
||||
@@ -208,8 +217,8 @@ func mapUserAndSettings(claims map[string]interface{}, config *AuthConfig) (*mod
|
||||
name = ""
|
||||
}
|
||||
|
||||
var roles = role.Strings(config.AuthDefaultRoles)
|
||||
if contains(config.AuthAdminUsers, username) {
|
||||
roles := role.Strings(config.AuthDefaultRoles)
|
||||
if slices.Contains(config.AuthAdminUsers, username) {
|
||||
roles = append(roles, role.Admin)
|
||||
}
|
||||
|
||||
@@ -223,20 +232,12 @@ func mapUserAndSettings(claims map[string]interface{}, config *AuthConfig) (*mod
|
||||
}, nil
|
||||
}
|
||||
|
||||
func contains(l []string, s string) bool {
|
||||
for _, e := range l {
|
||||
if e == s {
|
||||
return true
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
func getString(m map[string]interface{}, key string) (string, error) {
|
||||
func getString(m map[string]any, key string) (string, error) {
|
||||
if v, ok := m[key]; ok {
|
||||
if s, ok := v.(string); ok {
|
||||
return s, nil
|
||||
}
|
||||
|
||||
return "", fmt.Errorf("mapping of %s failed, wrong type (%T)", key, v)
|
||||
}
|
||||
|
||||
@@ -247,12 +248,14 @@ func redirectToLogin(w http.ResponseWriter, r *http.Request, oauth2Config *oauth
|
||||
state, err := state()
|
||||
if err != nil {
|
||||
api.JSONErrorStatus(w, http.StatusInternalServerError, errors.New("generating state failed"))
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
setStateCookie(w, state)
|
||||
|
||||
http.Redirect(w, r, oauth2Config.AuthCodeURL(state), http.StatusFound)
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
@@ -262,11 +265,13 @@ func AuthorizeBlockedUser() func(http.Handler) http.Handler {
|
||||
user, ok := busdb.UserFromContext(r.Context())
|
||||
if !ok {
|
||||
api.JSONErrorStatus(w, http.StatusInternalServerError, errors.New("no user in context"))
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
if user.Blocked {
|
||||
api.JSONErrorStatus(w, http.StatusForbidden, errors.New("user is blocked"))
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
@@ -281,11 +286,13 @@ func AuthorizeRole(roles []string) func(http.Handler) http.Handler {
|
||||
user, ok := busdb.UserFromContext(r.Context())
|
||||
if !ok {
|
||||
api.JSONErrorStatus(w, http.StatusInternalServerError, errors.New("no user in context"))
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
if !role.UserHasRoles(user, role.FromStrings(roles)) {
|
||||
api.JSONErrorStatus(w, http.StatusForbidden, fmt.Errorf("missing role %s has %s", roles, user.Roles))
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
@@ -299,17 +306,20 @@ func callback(config *AuthConfig) http.HandlerFunc {
|
||||
state, err := stateCookie(r)
|
||||
if err != nil || state == "" {
|
||||
api.JSONErrorStatus(w, http.StatusInternalServerError, errors.New("state missing"))
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
if state != r.URL.Query().Get("state") {
|
||||
api.JSONErrorStatus(w, http.StatusInternalServerError, errors.New("state mismatch"))
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
oauth2Token, err := config.OAuth2.Exchange(r.Context(), r.URL.Query().Get("code"))
|
||||
if err != nil {
|
||||
api.JSONErrorStatus(w, http.StatusInternalServerError, fmt.Errorf("oauth2 exchange failed: %w", err))
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
@@ -317,12 +327,14 @@ func callback(config *AuthConfig) http.HandlerFunc {
|
||||
rawIDToken, ok := oauth2Token.Extra("id_token").(string)
|
||||
if !ok {
|
||||
api.JSONErrorStatus(w, http.StatusInternalServerError, errors.New("missing id token"))
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
claims, apiError := verifyClaims(r, config, rawIDToken)
|
||||
if apiError != nil {
|
||||
api.JSONErrorStatus(w, apiError.Status, apiError.Internal)
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
@@ -337,10 +349,11 @@ func state() (string, error) {
|
||||
if _, err := rand.Read(rnd); err != nil {
|
||||
return "", err
|
||||
}
|
||||
|
||||
return base64.URLEncoding.EncodeToString(rnd), nil
|
||||
}
|
||||
|
||||
func verifyClaims(r *http.Request, config *AuthConfig, rawIDToken string) (map[string]interface{}, *api.HTTPError) {
|
||||
func verifyClaims(r *http.Request, config *AuthConfig, rawIDToken string) (map[string]any, *api.HTTPError) {
|
||||
verifier, err := config.Verifier(r.Context())
|
||||
if err != nil {
|
||||
return nil, &api.HTTPError{Status: http.StatusUnauthorized, Internal: fmt.Errorf("could not verify: %w", err)}
|
||||
@@ -350,9 +363,10 @@ func verifyClaims(r *http.Request, config *AuthConfig, rawIDToken string) (map[s
|
||||
return nil, &api.HTTPError{Status: http.StatusInternalServerError, Internal: fmt.Errorf("could not verify bearer token: %w", err)}
|
||||
}
|
||||
|
||||
var claims map[string]interface{}
|
||||
var claims map[string]any
|
||||
if err := authToken.Claims(&claims); err != nil {
|
||||
return nil, &api.HTTPError{Status: http.StatusInternalServerError, Internal: fmt.Errorf("failed to parse claims: %w", err)}
|
||||
}
|
||||
|
||||
return claims, nil
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user