refactor: remove pocketbase (#1138)

This commit is contained in:
Jonas Plum
2025-09-02 21:58:08 +02:00
committed by GitHub
parent f28c238135
commit eba2615ec0
435 changed files with 42677 additions and 4730 deletions

16
app/auth/errorjson.go Normal file
View File

@@ -0,0 +1,16 @@
package auth
import (
"fmt"
"net/http"
)
func unauthorizedJSON(w http.ResponseWriter, msg string) {
errorJSON(w, http.StatusUnauthorized, msg)
}
func errorJSON(w http.ResponseWriter, status int, msg string) {
w.WriteHeader(status)
w.Header().Set("Content-Type", "application/json; charset=UTF-8")
_, _ = fmt.Fprintf(w, `{"status": %d, "error": %q, "message": %q}`, status, http.StatusText(status), msg)
}

158
app/auth/middleware.go Normal file
View File

@@ -0,0 +1,158 @@
package auth
import (
"context"
"errors"
"fmt"
"log/slog"
"net/http"
"slices"
"strings"
strictnethttp "github.com/oapi-codegen/runtime/strictmiddleware/nethttp"
"github.com/SecurityBrewery/catalyst/app/auth/usercontext"
"github.com/SecurityBrewery/catalyst/app/database/sqlc"
"github.com/SecurityBrewery/catalyst/app/openapi"
)
const bearerPrefix = "Bearer "
func Middleware(queries *sqlc.Queries) func(next http.Handler) http.Handler {
return func(next http.Handler) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
if r.URL.Path == "/api/config" {
next.ServeHTTP(w, r)
return
}
authorizationHeader := r.Header.Get("Authorization")
bearerToken := strings.TrimPrefix(authorizationHeader, bearerPrefix)
user, claims, err := verifyAccessToken(r.Context(), bearerToken, queries)
if err != nil {
slog.ErrorContext(r.Context(), "invalid bearer token", "error", err)
unauthorizedJSON(w, "invalid bearer token")
return
}
scopes, err := scopes(claims)
if err != nil {
slog.ErrorContext(r.Context(), "failed to get scopes from token", "error", err)
unauthorizedJSON(w, "failed to get scopes")
return
}
// Set the user in the context
r = usercontext.UserRequest(r, user)
r = usercontext.PermissionRequest(r, scopes)
next.ServeHTTP(w, r)
})
}
}
func ValidateFileScopes(next http.Handler) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
requiredScopes := []string{"file:read"}
if slices.Contains([]string{http.MethodPost, http.MethodPatch, http.MethodPut, http.MethodDelete}, r.Method) {
requiredScopes = []string{"file:write"}
}
if err := validateScopes(r.Context(), requiredScopes); err != nil {
slog.ErrorContext(r.Context(), "failed to validate scopes", "error", err)
unauthorizedJSON(w, "missing required scopes")
return
}
next.ServeHTTP(w, r)
})
}
func ValidateScopesStrict(next strictnethttp.StrictHTTPHandlerFunc, _ string) strictnethttp.StrictHTTPHandlerFunc {
return func(ctx context.Context, w http.ResponseWriter, r *http.Request, request any) (response any, err error) {
requiredScopes, err := requiredScopes(ctx)
if err != nil {
slog.ErrorContext(ctx, "failed to get required scopes", "error", err)
unauthorizedJSON(w, "failed to get required scopes")
return nil, fmt.Errorf("failed to get required scopes: %w", err)
}
if err := validateScopes(ctx, requiredScopes); err != nil {
slog.ErrorContext(ctx, "failed to validate scopes", "error", err)
unauthorizedJSON(w, "missing required scopes")
return nil, fmt.Errorf("missing required scopes: %w", err)
}
return next(ctx, w, r, request)
}
}
func LogError(next strictnethttp.StrictHTTPHandlerFunc, _ string) strictnethttp.StrictHTTPHandlerFunc {
return func(ctx context.Context, w http.ResponseWriter, r *http.Request, request any) (response any, err error) {
re, err := next(ctx, w, r, request)
if err != nil {
if err.Error() == "context canceled" {
// This is a common error when the request is canceled, e.g., by the client.
// We can ignore this error as it does not indicate a problem with the handler.
return re, nil
}
slog.ErrorContext(ctx, "handler error", "error", err, "method", r.Method, "path", r.URL.Path)
}
return re, err
}
}
func validateScopes(ctx context.Context, requiredScopes []string) error {
if len(requiredScopes) > 0 {
permissions, ok := usercontext.PermissionFromContext(ctx)
if !ok {
return errors.New("missing permissions")
}
if !hasScope(permissions, requiredScopes) {
return fmt.Errorf("missing required scopes: %v", requiredScopes)
}
}
return nil
}
func requiredScopes(ctx context.Context) ([]string, error) {
requiredScopesValue := ctx.Value(openapi.OAuth2Scopes)
if requiredScopesValue == nil {
return nil, nil
}
requiredScopes, ok := requiredScopesValue.([]string)
if !ok {
return nil, fmt.Errorf("invalid required scopes type: %T", requiredScopesValue)
}
return requiredScopes, nil
}
func hasScope(scopes []string, requiredScopes []string) bool {
if slices.Contains(scopes, "admin") {
// If the user has admin scope, they can access everything
return true
}
for _, s := range requiredScopes {
if !slices.Contains(scopes, s) {
return false
}
}
return true
}

188
app/auth/middleware_test.go Normal file
View File

@@ -0,0 +1,188 @@
package auth
import (
"bytes"
"context"
"fmt"
"net/http"
"net/http/httptest"
"testing"
"github.com/stretchr/testify/assert"
"github.com/SecurityBrewery/catalyst/app/auth/usercontext"
"github.com/SecurityBrewery/catalyst/app/openapi"
)
func mockHandler(w http.ResponseWriter, _ *http.Request) {
_, _ = w.Write([]byte(`{"message":"OK"}`))
}
func TestService_ValidateScopes(t *testing.T) {
t.Parallel()
type args struct {
requiredScopes []string
permissions []string
next http.HandlerFunc
}
tests := []struct {
name string
args args
want httptest.ResponseRecorder
}{
{
name: "no scopes",
args: args{
requiredScopes: []string{"user:read"},
permissions: []string{},
next: mockHandler,
},
want: httptest.ResponseRecorder{
Code: http.StatusUnauthorized,
Body: bytes.NewBufferString(`{"error": "Unauthorized", "message": "missing required scopes", "status": 401}`),
},
},
{
name: "insufficient scopes",
args: args{
requiredScopes: []string{"user:write"},
permissions: []string{"user:read"},
next: mockHandler,
},
want: httptest.ResponseRecorder{
Code: http.StatusUnauthorized,
Body: bytes.NewBufferString(`{"error": "Unauthorized", "message": "missing required scopes", "status": 401}`),
},
},
{
name: "sufficient scopes",
args: args{
requiredScopes: []string{"user:read"},
permissions: []string{"user:read", "user:write"},
next: mockHandler,
},
want: httptest.ResponseRecorder{
Code: http.StatusOK,
Body: bytes.NewBufferString(`{"message":"OK"}`),
},
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
t.Parallel()
handler := ValidateScopesStrict(func(_ context.Context, w http.ResponseWriter, r *http.Request, _ any) (response any, err error) {
tt.args.next(w, r)
return w, nil
}, "")
w := httptest.NewRecorder()
r := httptest.NewRequest(http.MethodGet, "/", nil)
//nolint: staticcheck
r = r.WithContext(context.WithValue(r.Context(), openapi.OAuth2Scopes, tt.args.requiredScopes))
r = usercontext.PermissionRequest(r, tt.args.permissions)
if _, err := handler(r.Context(), w, r, r); err != nil {
return
}
assert.Equal(t, tt.want.Code, w.Code, "response code should match expected value")
assert.JSONEq(t, tt.want.Body.String(), w.Body.String(), "response body should match expected value")
})
}
}
func Test_hasScope(t *testing.T) {
t.Parallel()
type args struct {
scopes []string
requiredScopes []string
}
tests := []struct {
name string
args args
want bool
}{
{
name: "no scopes",
args: args{
scopes: []string{},
requiredScopes: []string{"user:read"},
},
want: false,
},
{
name: "missing required scope",
args: args{
scopes: []string{"user:read"},
requiredScopes: []string{"user:write"},
},
},
{
name: "has required scope",
args: args{
scopes: []string{"user:read", "user:write"},
requiredScopes: []string{"user:read"},
},
want: true,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
t.Parallel()
assert.Equalf(t, tt.want, hasScope(tt.args.scopes, tt.args.requiredScopes), "hasScope(%v, %v)", tt.args.scopes, tt.args.requiredScopes)
})
}
}
func Test_requiredScopes(t *testing.T) {
t.Parallel()
type args struct {
r *http.Request
}
tests := []struct {
name string
args args
want []string
wantErr assert.ErrorAssertionFunc
}{
{
name: "no required scopes",
args: args{
r: httptest.NewRequest(http.MethodGet, "/", nil),
},
want: nil,
wantErr: assert.NoError,
},
{
name: "valid required scopes",
args: args{
//nolint: staticcheck
r: httptest.NewRequest(http.MethodGet, "/", nil).WithContext(context.WithValue(t.Context(), openapi.OAuth2Scopes, []string{"user:read", "user:write"})),
},
want: []string{"user:read", "user:write"},
wantErr: assert.NoError,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
t.Parallel()
got, err := requiredScopes(tt.args.r.Context())
if !tt.wantErr(t, err, fmt.Sprintf("requiredScopes(%v)", tt.args.r)) {
return
}
assert.Equalf(t, tt.want, got, "requiredScopes(%v)", tt.args.r)
})
}
}

View File

@@ -0,0 +1,32 @@
package password
import (
"crypto/rand"
"encoding/base64"
"fmt"
"golang.org/x/crypto/bcrypt"
)
func Hash(password string) (hashedPassword, tokenKey string, err error) {
hashedPasswordB, err := bcrypt.GenerateFromPassword([]byte(password), bcrypt.DefaultCost)
if err != nil {
return "", "", fmt.Errorf("failed to hash password: %w", err)
}
tokenKey, err = GenerateTokenKey()
if err != nil {
return "", "", err
}
return string(hashedPasswordB), tokenKey, nil
}
func GenerateTokenKey() (string, error) {
b := make([]byte, 32)
if _, err := rand.Read(b); err != nil {
return "", err
}
return base64.URLEncoding.EncodeToString(b), nil
}

View File

@@ -0,0 +1,67 @@
package password
import (
"strings"
"testing"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"golang.org/x/crypto/bcrypt"
)
func TestHash(t *testing.T) {
t.Parallel()
type args struct {
password string
}
tests := []struct {
name string
args args
wantErr require.ErrorAssertionFunc
}{
{
name: "Hash valid password",
args: args{
password: "securePassword123!",
},
wantErr: require.NoError,
},
{
name: "Long password",
args: args{
password: strings.Repeat("a", 75),
},
wantErr: require.Error,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
t.Parallel()
gotHashedPassword, gotTokenKey, err := Hash(tt.args.password)
tt.wantErr(t, err, "Hash() should not return an error")
if err != nil {
return
}
assert.NotEmpty(t, gotHashedPassword, "Hash() gotHashedPassword should not be empty")
assert.NotEmpty(t, gotTokenKey, "Hash() gotTokenKey should not be empty")
require.NoError(t, bcrypt.CompareHashAndPassword([]byte(gotHashedPassword), []byte(tt.args.password)), "Hash() hashed password does not match original password")
assert.GreaterOrEqual(t, len(gotTokenKey), 43, "Hash() gotTokenKey should be at least 43 characters long")
})
}
}
func TestGenerateTokenKey(t *testing.T) {
t.Parallel()
tokenKey, err := GenerateTokenKey()
require.NoError(t, err)
assert.NotEmpty(t, tokenKey, "GenerateTokenKey() tokenKey should not be empty")
assert.GreaterOrEqual(t, len(tokenKey), 43, "GenerateTokenKey() tokenKey should be at least 43 characters long")
}

73
app/auth/permission.go Normal file
View File

@@ -0,0 +1,73 @@
package auth
import (
"context"
"encoding/json"
"log/slog"
)
var (
TicketReadPermission = "ticket:read"
TicketWritePermission = "ticket:write"
FileReadPermission = "file:read"
FileWritePermission = "file:write"
TypeReadPermission = "type:read"
TypeWritePermission = "type:write"
UserReadPermission = "user:read"
UserWritePermission = "user:write"
GroupReadPermission = "group:read"
GroupWritePermission = "group:write"
ReactionReadPermission = "reaction:read"
ReactionWritePermission = "reaction:write"
WebhookReadPermission = "webhook:read"
WebhookWritePermission = "webhook:write"
SettingsReadPermission = "settings:read"
SettingsWritePermission = "settings:write"
)
func All() []string {
return []string{
TicketReadPermission,
TicketWritePermission,
FileReadPermission,
FileWritePermission,
TypeReadPermission,
TypeWritePermission,
UserReadPermission,
UserWritePermission,
GroupReadPermission,
GroupWritePermission,
ReactionReadPermission,
ReactionWritePermission,
WebhookReadPermission,
WebhookWritePermission,
SettingsReadPermission,
SettingsWritePermission,
}
}
func FromJSONArray(ctx context.Context, permissions string) []string {
var result []string
if err := json.Unmarshal([]byte(permissions), &result); err != nil {
slog.ErrorContext(ctx, "Failed to unmarshal permissions", "error", err)
return nil
}
return result
}
func ToJSONArray(ctx context.Context, permissions []string) string {
if len(permissions) == 0 {
return "[]"
}
data, err := json.Marshal(permissions)
if err != nil {
slog.ErrorContext(ctx, "Failed to marshal permissions", "error", err)
return "[]"
}
return string(data)
}

View File

@@ -0,0 +1,84 @@
package auth
import (
"reflect"
"testing"
)
func TestFromJSONArray(t *testing.T) {
t.Parallel()
tests := []struct {
name string
input string
want []string
shouldError bool
}{
{
name: "Valid JSON array",
input: `["ticket:read", "ticket:write"]`,
want: []string{"ticket:read", "ticket:write"},
shouldError: false,
},
{
name: "Empty array",
input: "[]",
want: []string{},
shouldError: false,
},
{
name: "Invalid JSON",
input: "not json",
want: nil,
shouldError: true,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
t.Parallel()
got := FromJSONArray(t.Context(), tt.input)
if !reflect.DeepEqual(got, tt.want) {
t.Errorf("FromJSONArray() = %v, want %v", got, tt.want)
}
})
}
}
func TestToJSONArray(t *testing.T) {
t.Parallel()
tests := []struct {
name string
input []string
want string
}{
{
name: "Valid permissions array",
input: []string{"ticket:read", "ticket:write"},
want: `["ticket:read","ticket:write"]`,
},
{
name: "Empty array",
input: []string{},
want: "[]",
},
{
name: "Nil array",
input: nil,
want: "[]",
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
t.Parallel()
got := ToJSONArray(t.Context(), tt.input)
if got != tt.want {
t.Errorf("ToJSONArray() = %v, want %v", got, tt.want)
}
})
}
}

178
app/auth/resetpassword.go Normal file
View File

@@ -0,0 +1,178 @@
package auth
import (
"database/sql"
"encoding/json"
"errors"
"net/http"
"strings"
"time"
"github.com/SecurityBrewery/catalyst/app/auth/password"
"github.com/SecurityBrewery/catalyst/app/database/sqlc"
"github.com/SecurityBrewery/catalyst/app/mail"
"github.com/SecurityBrewery/catalyst/app/settings"
)
func handleResetPasswordMail(queries *sqlc.Queries, mailer *mail.Mailer) func(w http.ResponseWriter, r *http.Request) {
return func(w http.ResponseWriter, r *http.Request) {
type passwordResetData struct {
Email string `json:"email"`
}
b, err := json.Marshal(map[string]any{
"message": "Password reset email sent when the user exists",
})
if err != nil {
errorJSON(w, http.StatusInternalServerError, "Failed to create response: "+err.Error())
return
}
var data passwordResetData
if err := json.NewDecoder(r.Body).Decode(&data); err != nil {
errorJSON(w, http.StatusBadRequest, "Invalid request, missing email field")
return
}
user, err := queries.UserByEmail(r.Context(), &data.Email)
if err != nil {
if errors.Is(err, sql.ErrNoRows) {
// Do not reveal whether the user exists or not
w.WriteHeader(http.StatusOK)
_, _ = w.Write(b)
return
}
errorJSON(w, http.StatusInternalServerError, "Failed to get user: "+err.Error())
return
}
settings, err := settings.Load(r.Context(), queries)
if err != nil {
errorJSON(w, http.StatusInternalServerError, "Failed to load settings: "+err.Error())
return
}
resetToken, err := createResetToken(&user, settings)
if err != nil {
errorJSON(w, http.StatusInternalServerError, "Failed to create reset token: "+err.Error())
return
}
link := settings.Meta.AppURL + "/ui/password-reset?mail=" + data.Email + "&token=" + resetToken
subject := settings.Meta.ResetPasswordTemplate.Subject
subject = strings.ReplaceAll(subject, "{APP_NAME}", settings.Meta.AppName)
plainTextBody := `Hello,
Thank you for joining us at {APP_NAME}.
Click on the link below to verify your email address or copy the token into the app:
{ACTION_URL}
Thanks, {APP_NAME} team`
plainTextBody = strings.ReplaceAll(plainTextBody, "{ACTION_URL}", link)
plainTextBody = strings.ReplaceAll(plainTextBody, "{APP_NAME}", settings.Meta.AppName)
htmlBody := settings.Meta.ResetPasswordTemplate.Body
htmlBody = strings.ReplaceAll(htmlBody, "{ACTION_URL}", link)
htmlBody = strings.ReplaceAll(htmlBody, "{APP_NAME}", settings.Meta.AppName)
if err := mailer.Send(r.Context(), data.Email, subject, plainTextBody, htmlBody); err != nil {
errorJSON(w, http.StatusInternalServerError, "Failed to send password reset email: "+err.Error())
return
}
w.WriteHeader(http.StatusOK)
_, _ = w.Write(b)
}
}
func handlePassword(queries *sqlc.Queries) func(w http.ResponseWriter, r *http.Request) {
return func(w http.ResponseWriter, r *http.Request) {
type passwordResetData struct {
Token string `json:"token"`
Email string `json:"email"`
Password string `json:"password"`
PasswordConfirm string `json:"password_confirm"`
}
var data passwordResetData
if err := json.NewDecoder(r.Body).Decode(&data); err != nil {
errorJSON(w, http.StatusBadRequest, "Invalid request, missing email or password fields")
return
}
if data.Password != data.PasswordConfirm {
errorJSON(w, http.StatusBadRequest, "Passwords do not match")
return
}
user, err := queries.UserByEmail(r.Context(), &data.Email)
if err != nil {
if errors.Is(err, sql.ErrNoRows) {
errorJSON(w, http.StatusBadRequest, "Invalid or expired reset token")
return
}
errorJSON(w, http.StatusInternalServerError, "Failed to get user: "+err.Error())
return
}
settings, err := settings.Load(r.Context(), queries)
if err != nil {
errorJSON(w, http.StatusInternalServerError, "Failed to load settings: "+err.Error())
return
}
if err := verifyResetToken(data.Token, &user, settings.Meta.AppURL, settings.RecordPasswordResetToken.Secret); err != nil {
errorJSON(w, http.StatusBadRequest, "Invalid or expired reset token: "+err.Error())
return
}
passwordHash, tokenKey, err := password.Hash(data.Password)
if err != nil {
errorJSON(w, http.StatusInternalServerError, "Failed to hash password: "+err.Error())
return
}
now := time.Now().UTC()
if _, err := queries.UpdateUser(r.Context(), sqlc.UpdateUserParams{
ID: user.ID,
PasswordHash: &passwordHash,
TokenKey: &tokenKey,
LastResetSentAt: &now,
}); err != nil {
errorJSON(w, http.StatusInternalServerError, "Failed to update password: "+err.Error())
return
}
b, err := json.Marshal(map[string]any{
"message": "Password reset successfully",
})
if err != nil {
errorJSON(w, http.StatusInternalServerError, "Failed to create response: "+err.Error())
return
}
w.WriteHeader(http.StatusOK)
_, _ = w.Write(b)
}
}

View File

@@ -0,0 +1,94 @@
package auth
import (
"testing"
"time"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"github.com/SecurityBrewery/catalyst/app/database/sqlc"
)
func TestService_createResetToken(t *testing.T) {
t.Parallel()
type args struct {
createUser *sqlc.User
tokenDuration time.Duration
waitDuration time.Duration
verifyUser *sqlc.User
}
tests := []struct {
name string
args args
wantErr assert.ErrorAssertionFunc
}{
{
name: "valid token",
args: args{
createUser: &sqlc.User{ID: "testuser", Tokenkey: "testtoken"},
tokenDuration: time.Hour,
waitDuration: 0,
verifyUser: &sqlc.User{
ID: "testuser",
Tokenkey: "testtoken",
Updated: mustParse(t, "2006-01-02 15:04:05Z", "2025-06-02 19:18:06.292Z"),
},
},
wantErr: assert.NoError,
},
{
name: "expired token",
args: args{
createUser: &sqlc.User{ID: "testuser", Tokenkey: "testtoken"},
tokenDuration: 0,
waitDuration: time.Second,
verifyUser: &sqlc.User{
ID: "testuser",
Tokenkey: "testtoken",
Updated: mustParse(t, "2006-01-02 15:04:05Z", "2025-06-02 19:18:06.292Z"),
},
},
wantErr: assert.Error,
},
{
name: "invalid token",
args: args{
createUser: &sqlc.User{ID: "testuser", Tokenkey: "testtoken"},
tokenDuration: time.Hour,
waitDuration: 0,
verifyUser: &sqlc.User{
ID: "invaliduser",
Tokenkey: "invalidtoken",
Updated: mustParse(t, "2006-01-02 15:04:05Z", "2025-06-02 19:18:06.292Z"),
},
},
wantErr: assert.Error,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
t.Parallel()
got, err := createResetTokenWithDuration(tt.args.createUser, "", "", tt.args.tokenDuration)
require.NoError(t, err, "createResetToken()")
time.Sleep(tt.args.waitDuration)
err = verifyResetToken(got, tt.args.verifyUser, "", "")
tt.wantErr(t, err, "verifyResetToken()")
})
}
}
func mustParse(t *testing.T, layout, value string) time.Time {
t.Helper()
parsed, err := time.Parse(layout, value)
require.NoError(t, err, "mustParse()")
return parsed
}

58
app/auth/server.go Normal file
View File

@@ -0,0 +1,58 @@
package auth
import (
"encoding/json"
"net/http"
"strings"
"github.com/go-chi/chi/v5"
"github.com/SecurityBrewery/catalyst/app/database/sqlc"
"github.com/SecurityBrewery/catalyst/app/mail"
)
func Server(queries *sqlc.Queries, mailer *mail.Mailer) http.Handler {
router := chi.NewRouter()
router.Get("/user", handleUser(queries))
router.Post("/local/login", handleLogin(queries))
router.Post("/local/reset-password-mail", handleResetPasswordMail(queries, mailer))
router.Post("/local/reset-password", handlePassword(queries))
return router
}
func handleUser(queries *sqlc.Queries) func(w http.ResponseWriter, r *http.Request) {
return func(w http.ResponseWriter, r *http.Request) {
authorizationHeader := r.Header.Get("Authorization")
bearerToken := strings.TrimPrefix(authorizationHeader, bearerPrefix)
user, _, err := verifyAccessToken(r.Context(), bearerToken, queries)
if err != nil {
_, _ = w.Write([]byte("null"))
return
}
permissions, err := queries.ListUserPermissions(r.Context(), user.ID)
if err != nil {
errorJSON(w, http.StatusInternalServerError, err.Error())
return
}
b, err := json.Marshal(map[string]any{
"user": user,
"permissions": permissions,
})
if err != nil {
errorJSON(w, http.StatusInternalServerError, err.Error())
return
}
r.Header.Set("Content-Type", "application/json")
_, _ = w.Write(b)
}
}

99
app/auth/server_local.go Normal file
View File

@@ -0,0 +1,99 @@
package auth
import (
"context"
"encoding/json"
"errors"
"fmt"
"net/http"
"time"
"golang.org/x/crypto/bcrypt"
"github.com/SecurityBrewery/catalyst/app/database/sqlc"
"github.com/SecurityBrewery/catalyst/app/settings"
)
var ErrUserInactive = errors.New("user is inactive")
func handleLogin(queries *sqlc.Queries) func(w http.ResponseWriter, r *http.Request) {
return func(w http.ResponseWriter, r *http.Request) {
type loginData struct {
Email string `json:"email"`
Password string `json:"password"`
}
var data loginData
if err := json.NewDecoder(r.Body).Decode(&data); err != nil {
unauthorizedJSON(w, "Invalid request")
return
}
user, err := loginWithMail(r.Context(), data.Email, data.Password, queries)
if err != nil {
if errors.Is(err, ErrUserInactive) {
unauthorizedJSON(w, "User is inactive")
return
}
unauthorizedJSON(w, "Login failed")
return
}
permissions, err := queries.ListUserPermissions(r.Context(), user.ID)
if err != nil {
errorJSON(w, http.StatusInternalServerError, "Failed to get user permissions")
return
}
settings, err := settings.Load(r.Context(), queries)
if err != nil {
errorJSON(w, http.StatusInternalServerError, "Failed to load settings")
return
}
duration := time.Duration(settings.RecordAuthToken.Duration) * time.Second
token, err := CreateAccessToken(r.Context(), user, permissions, duration, queries)
if err != nil {
errorJSON(w, http.StatusInternalServerError, "Failed to create login token")
return
}
w.Header().Set("Content-Type", "application/json")
w.WriteHeader(http.StatusOK)
response := map[string]string{
"token": token,
}
if err := json.NewEncoder(w).Encode(response); err != nil {
errorJSON(w, http.StatusInternalServerError, "Failed to encode response")
return
}
}
}
func loginWithMail(ctx context.Context, mail, password string, queries *sqlc.Queries) (*sqlc.User, error) {
user, err := queries.UserByEmail(ctx, &mail)
if err != nil {
return nil, fmt.Errorf("failed to find user by email %q: %w", mail, err)
}
if !user.Active {
return nil, ErrUserInactive
}
if err := bcrypt.CompareHashAndPassword([]byte(user.Passwordhash), []byte(password)); err != nil {
return nil, fmt.Errorf("invalid credentials: %w", err)
}
return &user, nil
}

225
app/auth/token.go Normal file
View File

@@ -0,0 +1,225 @@
package auth
import (
"context"
"fmt"
"time"
"github.com/golang-jwt/jwt/v5"
"github.com/SecurityBrewery/catalyst/app/database/sqlc"
"github.com/SecurityBrewery/catalyst/app/settings"
)
const (
purposeAccess = "access"
purposeReset = "reset"
scopeReset = "reset"
)
func CreateAccessToken(ctx context.Context, user *sqlc.User, permissions []string, duration time.Duration, queries *sqlc.Queries) (string, error) {
settings, err := settings.Load(ctx, queries)
if err != nil {
return "", fmt.Errorf("failed to load settings: %w", err)
}
return createToken(user, duration, purposeAccess, permissions, settings.Meta.AppURL, settings.RecordAuthToken.Secret)
}
func createResetToken(user *sqlc.User, settings *settings.Settings) (string, error) {
duration := time.Duration(settings.RecordPasswordResetToken.Duration) * time.Second
return createResetTokenWithDuration(user, settings.Meta.AppURL, settings.RecordPasswordResetToken.Secret, duration)
}
func createResetTokenWithDuration(user *sqlc.User, url, appToken string, duration time.Duration) (string, error) {
return createToken(user, duration, purposeReset, []string{scopeReset}, url, appToken)
}
func createToken(user *sqlc.User, duration time.Duration, purpose string, scopes []string, url, appToken string) (string, error) {
if scopes == nil {
scopes = []string{}
}
claims := jwt.MapClaims{
"sub": user.ID,
"exp": time.Now().Add(duration).Unix(),
"iat": time.Now().Unix(),
"iss": url,
"purpose": purpose,
"scopes": scopes,
}
token := jwt.NewWithClaims(jwt.SigningMethodHS256, claims)
signingKey := user.Tokenkey + appToken
return token.SignedString([]byte(signingKey))
}
func verifyToken(tokenStr string, user *sqlc.User, url, appToken string) (jwt.MapClaims, error) { //nolint:cyclop
signingKey := user.Tokenkey + appToken
token, err := jwt.Parse(tokenStr, func(t *jwt.Token) (any, error) {
if _, ok := t.Method.(*jwt.SigningMethodHMAC); !ok {
return nil, fmt.Errorf("unexpected algorithm: %v", t.Header["alg"])
}
return []byte(signingKey), nil
})
if err != nil {
return nil, fmt.Errorf("failed to verify token: %w", err)
}
if !token.Valid {
return nil, fmt.Errorf("token invalid")
}
claims, ok := token.Claims.(jwt.MapClaims)
if !ok {
return nil, fmt.Errorf("invalid token claims")
}
iss, err := claims.GetIssuer()
if err != nil {
return nil, fmt.Errorf("failed to get issuer: %w", err)
}
if iss != url {
return nil, fmt.Errorf("token issued by a different server")
}
sub, err := claims.GetSubject()
if err != nil {
return nil, fmt.Errorf("failed to get subject: %w", err)
}
if sub != user.ID {
return nil, fmt.Errorf("token belongs to a different user")
}
iat, err := claims.GetExpirationTime()
if err != nil {
return nil, fmt.Errorf("failed to get expiration time: %w", err)
}
if iat.Before(time.Now()) {
return nil, fmt.Errorf("token expired")
}
return claims, nil
}
func verifyAccessToken(ctx context.Context, bearerToken string, queries *sqlc.Queries) (*sqlc.User, jwt.MapClaims, error) {
token, _, err := jwt.NewParser().ParseUnverified(bearerToken, jwt.MapClaims{})
if err != nil {
return nil, nil, fmt.Errorf("failed to parse token: %w", err)
}
claims, ok := token.Claims.(jwt.MapClaims)
if !ok {
return nil, nil, fmt.Errorf("failed to parse token claims")
}
sub, err := claims.GetSubject()
if err != nil {
return nil, nil, fmt.Errorf("token invalid: %w", err)
}
user, err := queries.GetUser(ctx, sub)
if err != nil {
return nil, nil, fmt.Errorf("failed to retrieve user for subject %s: %w", sub, err)
}
settings, err := settings.Load(ctx, queries)
if err != nil {
return nil, nil, fmt.Errorf("failed to load settings: %w", err)
}
claims, err = verifyToken(bearerToken, &user, settings.Meta.AppURL, settings.RecordAuthToken.Secret)
if err != nil {
return nil, nil, fmt.Errorf("failed to verify token: %w", err)
}
if err := hasPurpose(claims, purposeAccess); err != nil {
return nil, nil, fmt.Errorf("failed to check scopes: %w", err)
}
return &user, claims, nil
}
func verifyResetToken(tokenStr string, user *sqlc.User, url, appToken string) error {
claims, err := verifyToken(tokenStr, user, url, appToken)
if err != nil {
return err
}
iat, err := claims.GetIssuedAt()
if err != nil {
return fmt.Errorf("failed to get issued at: %w", err)
}
lastUpdated := user.Updated // TODO: create a last reset at column
if iat.Before(lastUpdated) {
return fmt.Errorf("token already used")
}
if err := hasPurpose(claims, purposeReset); err != nil {
return fmt.Errorf("failed to check scopes: %w", err)
}
return nil
}
func hasPurpose(claim jwt.MapClaims, expectedPurpose string) error {
purpose, err := purpose(claim)
if err != nil {
return fmt.Errorf("failed to get purposes: %w", err)
}
if purpose != expectedPurpose {
return fmt.Errorf("token has wrong purpose: %s, expected: %s", purpose, expectedPurpose)
}
return nil
}
func purpose(claim jwt.MapClaims) (string, error) {
purposeClaim, ok := claim["purpose"]
if !ok {
return "", fmt.Errorf("no purpose found")
}
purpose, ok := purposeClaim.(string)
if !ok {
return "", fmt.Errorf("invalid purpose type")
}
return purpose, nil
}
func scopes(claim jwt.MapClaims) ([]string, error) {
scopesClaim, ok := claim["scopes"]
if !ok {
return nil, fmt.Errorf("no scopes found")
}
scopesSlice, ok := scopesClaim.([]any)
if !ok {
return nil, fmt.Errorf("invalid scopes claim type: %T", scopesClaim)
}
scopes := make([]string, 0, len(scopesSlice))
for _, scope := range scopesSlice {
scopeStr, ok := scope.(string)
if !ok {
return nil, fmt.Errorf("invalid scope claim element type: %T", scope)
}
scopes = append(scopes, scopeStr)
}
return scopes, nil
}

View File

@@ -0,0 +1,46 @@
package usercontext
import (
"context"
"net/http"
"github.com/SecurityBrewery/catalyst/app/database/sqlc"
)
type userKey struct{}
func UserRequest(r *http.Request, user *sqlc.User) *http.Request {
return r.WithContext(UserContext(r.Context(), user))
}
func UserContext(ctx context.Context, user *sqlc.User) context.Context {
return context.WithValue(ctx, userKey{}, user)
}
func UserFromContext(ctx context.Context) (*sqlc.User, bool) {
user, ok := ctx.Value(userKey{}).(*sqlc.User)
if !ok {
return nil, false
}
return user, true
}
type permissionKey struct{}
func PermissionRequest(r *http.Request, permissions []string) *http.Request {
return r.WithContext(PermissionContext(r.Context(), permissions))
}
func PermissionContext(ctx context.Context, permissions []string) context.Context {
return context.WithValue(ctx, permissionKey{}, permissions)
}
func PermissionFromContext(ctx context.Context) ([]string, bool) {
permissions, ok := ctx.Value(permissionKey{}).([]string)
if !ok {
return nil, false
}
return permissions, true
}

View File

@@ -0,0 +1,116 @@
package usercontext
import (
"net/http"
"reflect"
"testing"
"github.com/SecurityBrewery/catalyst/app/database/sqlc"
)
func TestPermissionContext(t *testing.T) {
t.Parallel()
tests := []struct {
name string
user *sqlc.User
permissions []string
wantPerms []string
wantOk bool
}{
{
name: "Set and get permissions",
permissions: []string{"ticket:read", "ticket:write"},
wantPerms: []string{"ticket:read", "ticket:write"},
wantOk: true,
},
{
name: "No permissions set",
wantPerms: nil,
wantOk: true,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
t.Parallel()
// Test context functions
ctx := PermissionContext(t.Context(), tt.permissions)
gotPerms, gotOk := PermissionFromContext(ctx)
if !reflect.DeepEqual(gotPerms, tt.wantPerms) {
t.Errorf("PermissionFromContext() got perms = %v, want %v", gotPerms, tt.wantPerms)
}
if gotOk != tt.wantOk {
t.Errorf("PermissionFromContext() got ok = %v, want %v", gotOk, tt.wantOk)
}
// Test request functions
req := &http.Request{}
req = PermissionRequest(req, tt.permissions)
gotPerms, gotOk = PermissionFromContext(req.Context())
if !reflect.DeepEqual(gotPerms, tt.wantPerms) {
t.Errorf("PermissionFromContext() got perms = %v, want %v", gotPerms, tt.wantPerms)
}
if gotOk != tt.wantOk {
t.Errorf("PermissionFromContext() got ok = %v, want %v", gotOk, tt.wantOk)
}
})
}
}
func TestUserContext(t *testing.T) {
t.Parallel()
tests := []struct {
name string
user *sqlc.User
wantOk bool
}{
{
name: "Set and get user",
user: &sqlc.User{ID: "test-user"},
wantOk: true,
},
{
name: "No user set",
user: nil,
wantOk: true,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
t.Parallel()
// Test context functions
ctx := UserContext(t.Context(), tt.user)
gotUser, gotOk := UserFromContext(ctx)
if !reflect.DeepEqual(gotUser, tt.user) {
t.Errorf("UserFromContext() got user = %v, want %v", gotUser, tt.user)
}
if gotOk != tt.wantOk {
t.Errorf("UserFromContext() got ok = %v, want %v", gotOk, tt.wantOk)
}
// Test request functions
req := &http.Request{}
req = UserRequest(req, tt.user)
gotUser, gotOk = UserFromContext(req.Context())
if !reflect.DeepEqual(gotUser, tt.user) {
t.Errorf("UserFromContext() got user = %v, want %v", gotUser, tt.user)
}
if gotOk != tt.wantOk {
t.Errorf("UserFromContext() got ok = %v, want %v", gotOk, tt.wantOk)
}
})
}
}