mirror of
https://github.com/SecurityBrewery/catalyst.git
synced 2025-12-26 17:03:06 +01:00
refactor: remove pocketbase (#1138)
This commit is contained in:
79
app/router/demomode.go
Normal file
79
app/router/demomode.go
Normal file
@@ -0,0 +1,79 @@
|
||||
package router
|
||||
|
||||
import (
|
||||
"context"
|
||||
"log/slog"
|
||||
"net/http"
|
||||
"slices"
|
||||
"strings"
|
||||
|
||||
"github.com/SecurityBrewery/catalyst/app/database"
|
||||
"github.com/SecurityBrewery/catalyst/app/database/sqlc"
|
||||
)
|
||||
|
||||
func demoMode(queries *sqlc.Queries) func(http.Handler) http.Handler {
|
||||
return func(next http.Handler) http.Handler {
|
||||
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
if isCriticalPath(r) && isCriticalMethod(r) && isDemoMode(r.Context(), queries) {
|
||||
http.Error(w, "Cannot modify reactions or files in demo mode", http.StatusForbidden)
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
next.ServeHTTP(w, r)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func isCriticalPath(r *http.Request) bool {
|
||||
// Define critical paths that should not be accessed in demo mode
|
||||
criticalPaths := []string{
|
||||
"/api/files",
|
||||
"/api/groups",
|
||||
"/api/reactions",
|
||||
"/api/settings",
|
||||
"/api/users",
|
||||
"/api/webhooks",
|
||||
}
|
||||
|
||||
for _, path := range criticalPaths {
|
||||
if strings.Contains(r.URL.Path, path) {
|
||||
return true
|
||||
}
|
||||
}
|
||||
|
||||
return false
|
||||
}
|
||||
|
||||
func isCriticalMethod(r *http.Request) bool {
|
||||
return !slices.Contains([]string{http.MethodHead, http.MethodGet}, r.Method)
|
||||
}
|
||||
|
||||
func isDemoMode(ctx context.Context, queries *sqlc.Queries) bool {
|
||||
var demoMode bool
|
||||
|
||||
if err := database.Paginate(ctx, func(ctx context.Context, offset, limit int64) (nextPage bool, err error) {
|
||||
slog.InfoContext(ctx, "Checking for demo mode", "offset", offset, "limit", limit)
|
||||
|
||||
features, err := queries.ListFeatures(ctx, sqlc.ListFeaturesParams{Offset: offset, Limit: limit})
|
||||
if err != nil {
|
||||
return false, err
|
||||
}
|
||||
|
||||
for _, feature := range features {
|
||||
if feature.Key == "demo" {
|
||||
demoMode = true
|
||||
|
||||
return false, nil // Stop pagination if demo mode is found
|
||||
}
|
||||
}
|
||||
|
||||
return true, nil
|
||||
}); err != nil {
|
||||
slog.ErrorContext(ctx, "Failed to check demo mode", "error", err)
|
||||
|
||||
return false
|
||||
}
|
||||
|
||||
return demoMode
|
||||
}
|
||||
115
app/router/demomode_test.go
Normal file
115
app/router/demomode_test.go
Normal file
@@ -0,0 +1,115 @@
|
||||
package router
|
||||
|
||||
import (
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
|
||||
"github.com/SecurityBrewery/catalyst/app/data"
|
||||
)
|
||||
|
||||
func Test_isCriticalPath(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
tests := []struct {
|
||||
path string
|
||||
want bool
|
||||
}{
|
||||
{"/api/reactions/1", true},
|
||||
{"/api/files/1", true},
|
||||
{"/api/other", false},
|
||||
}
|
||||
for _, tt := range tests {
|
||||
req := httptest.NewRequest(http.MethodGet, tt.path, nil)
|
||||
assert.Equal(t, tt.want, isCriticalPath(req))
|
||||
}
|
||||
}
|
||||
|
||||
func Test_isCriticalMethod(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
tests := []struct {
|
||||
method string
|
||||
want bool
|
||||
}{
|
||||
{http.MethodPost, true},
|
||||
{http.MethodPut, true},
|
||||
{http.MethodGet, false},
|
||||
{http.MethodHead, false},
|
||||
}
|
||||
for _, tt := range tests {
|
||||
req := httptest.NewRequest(tt.method, "/", nil)
|
||||
assert.Equal(t, tt.want, isCriticalMethod(req))
|
||||
}
|
||||
}
|
||||
|
||||
func Test_isDemoMode(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
queries := data.NewTestDB(t, t.TempDir())
|
||||
assert.False(t, isDemoMode(t.Context(), queries))
|
||||
|
||||
_, err := queries.CreateFeature(t.Context(), "demo")
|
||||
require.NoError(t, err)
|
||||
|
||||
assert.True(t, isDemoMode(t.Context(), queries))
|
||||
}
|
||||
|
||||
func Test_demoModeMiddleware(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
queries := data.NewTestDB(t, t.TempDir())
|
||||
mw := demoMode(queries)
|
||||
nextCalled := false
|
||||
|
||||
next := http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) {
|
||||
nextCalled = true
|
||||
|
||||
w.WriteHeader(http.StatusTeapot)
|
||||
})
|
||||
|
||||
// not demo mode
|
||||
rr := httptest.NewRecorder()
|
||||
req := httptest.NewRequest(http.MethodPost, "/api/reactions", nil).WithContext(t.Context())
|
||||
mw(next).ServeHTTP(rr, req)
|
||||
assert.True(t, nextCalled)
|
||||
assert.Equal(t, http.StatusTeapot, rr.Code)
|
||||
|
||||
// enable demo mode
|
||||
_, err := queries.CreateFeature(t.Context(), "demo")
|
||||
require.NoError(t, err)
|
||||
|
||||
nextCalled = false
|
||||
|
||||
rr = httptest.NewRecorder()
|
||||
req = httptest.NewRequest(http.MethodPost, "/api/reactions", nil).WithContext(t.Context())
|
||||
mw(next).ServeHTTP(rr, req)
|
||||
assert.False(t, nextCalled)
|
||||
assert.Equal(t, http.StatusForbidden, rr.Code)
|
||||
|
||||
// non critical path
|
||||
nextCalled = false
|
||||
|
||||
rr = httptest.NewRecorder()
|
||||
req = httptest.NewRequest(http.MethodPost, "/api/other", nil).WithContext(t.Context())
|
||||
mw(next).ServeHTTP(rr, req)
|
||||
assert.True(t, nextCalled)
|
||||
assert.Equal(t, http.StatusTeapot, rr.Code)
|
||||
}
|
||||
|
||||
func Test_handlers(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
queries := data.NewTestDB(t, t.TempDir())
|
||||
|
||||
// healthHandler
|
||||
healthRR := httptest.NewRecorder()
|
||||
|
||||
healthReq := httptest.NewRequest(http.MethodGet, "/health", nil).WithContext(t.Context())
|
||||
healthHandler(queries)(healthRR, healthReq)
|
||||
assert.Equal(t, http.StatusOK, healthRR.Code)
|
||||
assert.Equal(t, "OK", healthRR.Body.String())
|
||||
}
|
||||
37
app/router/http.go
Normal file
37
app/router/http.go
Normal file
@@ -0,0 +1,37 @@
|
||||
package router
|
||||
|
||||
import (
|
||||
"net/http"
|
||||
"net/http/httputil"
|
||||
"net/url"
|
||||
"os"
|
||||
"strings"
|
||||
|
||||
"github.com/SecurityBrewery/catalyst/ui"
|
||||
)
|
||||
|
||||
func staticFiles(w http.ResponseWriter, r *http.Request) {
|
||||
if devServer := os.Getenv("UI_DEVSERVER"); devServer != "" {
|
||||
u, _ := url.Parse(devServer)
|
||||
|
||||
r.Host = r.URL.Host
|
||||
|
||||
httputil.NewSingleHostReverseProxy(u).ServeHTTP(w, r)
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
vueStatic(w, r)
|
||||
}
|
||||
|
||||
func vueStatic(w http.ResponseWriter, r *http.Request) {
|
||||
handler := http.FileServer(http.FS(ui.UI()))
|
||||
|
||||
if strings.HasPrefix(r.URL.Path, "/ui/assets/") {
|
||||
handler = http.StripPrefix("/ui", handler)
|
||||
} else {
|
||||
r.URL.Path = "/"
|
||||
}
|
||||
|
||||
handler.ServeHTTP(w, r)
|
||||
}
|
||||
30
app/router/http_test.go
Normal file
30
app/router/http_test.go
Normal file
@@ -0,0 +1,30 @@
|
||||
package router
|
||||
|
||||
import (
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"testing"
|
||||
)
|
||||
|
||||
func TestStaticFiles_DevServer(t *testing.T) {
|
||||
t.Setenv("UI_DEVSERVER", "http://localhost:1234")
|
||||
|
||||
rec := httptest.NewRecorder()
|
||||
r := httptest.NewRequest(http.MethodGet, "/ui/assets/test.js", nil)
|
||||
|
||||
// This will try to proxy, but since the dev server isn't running, it should not panic
|
||||
// We just want to make sure it doesn't crash
|
||||
staticFiles(rec, r)
|
||||
}
|
||||
|
||||
func TestStaticFiles_VueStatic(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
rec := httptest.NewRecorder()
|
||||
r := httptest.NewRequest(http.MethodGet, "/ui/assets/test.js", nil)
|
||||
staticFiles(rec, r)
|
||||
// Should not panic, and should serve something (even if it's a 404)
|
||||
if rec.Result().StatusCode == 0 {
|
||||
t.Error("expected a status code from vueStatic")
|
||||
}
|
||||
}
|
||||
69
app/router/router.go
Normal file
69
app/router/router.go
Normal file
@@ -0,0 +1,69 @@
|
||||
package router
|
||||
|
||||
import (
|
||||
"log/slog"
|
||||
"net/http"
|
||||
"time"
|
||||
|
||||
"github.com/go-chi/chi/v5"
|
||||
"github.com/go-chi/chi/v5/middleware"
|
||||
"github.com/google/martian/v3/cors"
|
||||
|
||||
"github.com/SecurityBrewery/catalyst/app/auth"
|
||||
"github.com/SecurityBrewery/catalyst/app/database/sqlc"
|
||||
"github.com/SecurityBrewery/catalyst/app/mail"
|
||||
"github.com/SecurityBrewery/catalyst/app/service"
|
||||
"github.com/SecurityBrewery/catalyst/app/upload"
|
||||
)
|
||||
|
||||
func New(service *service.Service, queries *sqlc.Queries, uploader *upload.Uploader, mailer *mail.Mailer) (*chi.Mux, error) {
|
||||
r := chi.NewRouter()
|
||||
|
||||
// middleware for the router
|
||||
r.Use(func(next http.Handler) http.Handler {
|
||||
return http.Handler(cors.NewHandler(next))
|
||||
})
|
||||
r.Use(demoMode(queries))
|
||||
r.Use(middleware.RequestID)
|
||||
r.Use(middleware.RealIP)
|
||||
r.Use(middleware.Logger)
|
||||
r.Use(middleware.Timeout(time.Second * 60))
|
||||
r.Use(middleware.Recoverer)
|
||||
|
||||
// base routes
|
||||
r.Get("/", func(w http.ResponseWriter, r *http.Request) {
|
||||
http.Redirect(w, r, "/ui/", http.StatusFound)
|
||||
})
|
||||
r.Get("/ui/*", staticFiles)
|
||||
r.Get("/health", healthHandler(queries))
|
||||
|
||||
// auth routes
|
||||
r.Mount("/auth", auth.Server(queries, mailer))
|
||||
|
||||
// API routes
|
||||
r.With(auth.Middleware(queries)).Mount("/api", http.StripPrefix("/api", service))
|
||||
|
||||
uploadHandler, err := tusRoutes(queries, uploader)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
r.Mount("/files", http.StripPrefix("/files", uploadHandler))
|
||||
|
||||
return r, nil
|
||||
}
|
||||
|
||||
func healthHandler(queries *sqlc.Queries) func(w http.ResponseWriter, r *http.Request) {
|
||||
return func(w http.ResponseWriter, r *http.Request) {
|
||||
if _, err := queries.ListFeatures(r.Context(), sqlc.ListFeaturesParams{Offset: 0, Limit: 100}); err != nil {
|
||||
slog.ErrorContext(r.Context(), "Failed to get flags", "error", err)
|
||||
http.Error(w, "Internal Server Error", http.StatusInternalServerError)
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
w.WriteHeader(http.StatusOK)
|
||||
|
||||
_, _ = w.Write([]byte("OK"))
|
||||
}
|
||||
}
|
||||
93
app/router/tus.go
Normal file
93
app/router/tus.go
Normal file
@@ -0,0 +1,93 @@
|
||||
package router
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"log/slog"
|
||||
"net/http"
|
||||
"path"
|
||||
"path/filepath"
|
||||
"time"
|
||||
|
||||
"github.com/go-chi/chi/v5"
|
||||
"github.com/tus/tusd/v2/pkg/filelocker"
|
||||
tusd "github.com/tus/tusd/v2/pkg/handler"
|
||||
"github.com/tus/tusd/v2/pkg/rootstore"
|
||||
|
||||
"github.com/SecurityBrewery/catalyst/app/auth"
|
||||
"github.com/SecurityBrewery/catalyst/app/database"
|
||||
"github.com/SecurityBrewery/catalyst/app/database/sqlc"
|
||||
"github.com/SecurityBrewery/catalyst/app/upload"
|
||||
)
|
||||
|
||||
func tusRoutes(queries *sqlc.Queries, u *upload.Uploader) (http.Handler, error) {
|
||||
store := rootstore.New(u.Root)
|
||||
locker := filelocker.New(u.Root.Name())
|
||||
composer := tusd.NewStoreComposer()
|
||||
store.UseIn(composer)
|
||||
locker.UseIn(composer)
|
||||
|
||||
// Create a new HTTP handler for the tusd server by providing a configuration.
|
||||
// The StoreComposer property must be set to allow the handler to function.
|
||||
handler, err := tusd.NewHandler(tusd.Config{
|
||||
BasePath: "/files/",
|
||||
StoreComposer: composer,
|
||||
NotifyCompleteUploads: true,
|
||||
PreUploadCreateCallback: func(hook tusd.HookEvent) (tusd.HTTPResponse, tusd.FileInfoChanges, error) {
|
||||
// This hook is called before an upload is created. You can use it to
|
||||
// modify the upload information, for example to set a custom ID or
|
||||
// storage path.
|
||||
id := database.GenerateID("")
|
||||
|
||||
if hook.Upload.Storage == nil {
|
||||
hook.Upload.Storage = make(map[string]string)
|
||||
}
|
||||
|
||||
filename, ok := hook.Upload.MetaData["filename"]
|
||||
if !ok || filename == "" {
|
||||
filename = id
|
||||
}
|
||||
|
||||
_, filePath := u.Paths(id, filepath.Base(filename))
|
||||
|
||||
hook.Upload.Storage["Path"] = filePath
|
||||
|
||||
return tusd.HTTPResponse{}, tusd.FileInfoChanges{
|
||||
ID: id,
|
||||
Storage: hook.Upload.Storage,
|
||||
}, nil
|
||||
},
|
||||
PreFinishResponseCallback: func(hook tusd.HookEvent) (tusd.HTTPResponse, error) {
|
||||
filename, ok := hook.Upload.MetaData["filename"]
|
||||
if !ok || filename == "" {
|
||||
filename = hook.Upload.ID
|
||||
}
|
||||
|
||||
_, err := queries.InsertFile(hook.Context, sqlc.InsertFileParams{
|
||||
ID: hook.Upload.ID,
|
||||
Name: filename,
|
||||
Blob: path.Base(hook.Upload.Storage["Path"]),
|
||||
Size: float64(hook.Upload.Size),
|
||||
Ticket: hook.HTTPRequest.Header.Get("X-Ticket-ID"),
|
||||
Created: time.Now().UTC(),
|
||||
Updated: time.Now().UTC(),
|
||||
})
|
||||
|
||||
return tusd.HTTPResponse{}, err
|
||||
},
|
||||
})
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to create tusd handler: %w", err)
|
||||
}
|
||||
|
||||
// Start another goroutine for receiving events from the handler whenever
|
||||
// an upload is completed. The event will contains details about the upload
|
||||
// itself and the relevant HTTP request.
|
||||
go func() {
|
||||
for {
|
||||
event := <-handler.CompleteUploads
|
||||
slog.Info("Upload %s finished", "id", event.Upload.ID)
|
||||
}
|
||||
}()
|
||||
|
||||
return chi.Chain(auth.Middleware(queries), auth.ValidateFileScopes).Handler(handler), nil
|
||||
}
|
||||
Reference in New Issue
Block a user