mirror of
https://github.com/SecurityBrewery/catalyst.git
synced 2025-12-06 07:12:46 +01:00
113 lines
2.6 KiB
Go
113 lines
2.6 KiB
Go
package database
|
|
|
|
import (
|
|
"context"
|
|
"crypto/rand"
|
|
"database/sql"
|
|
"fmt"
|
|
"log/slog"
|
|
"os"
|
|
"path/filepath"
|
|
"strings"
|
|
"testing"
|
|
"time"
|
|
|
|
_ "github.com/mattn/go-sqlite3" // import sqlite driver
|
|
"github.com/stretchr/testify/require"
|
|
|
|
"github.com/SecurityBrewery/catalyst/app/database/sqlc"
|
|
)
|
|
|
|
const sqliteDriver = "sqlite3"
|
|
|
|
func DB(ctx context.Context, dir string) (*sqlc.Queries, func(), error) {
|
|
filename := filepath.Join(dir, "data.db")
|
|
|
|
slog.InfoContext(ctx, "Connecting to database", "path", filename)
|
|
|
|
// see https://briandouglas.ie/sqlite-defaults/ for more details
|
|
pragmas := []string{
|
|
// Enable WAL mode for better concurrency
|
|
"journal_mode=WAL",
|
|
// Enable synchronous mode for better data integrity
|
|
"synchronous=NORMAL",
|
|
// Set busy timeout to 5 seconds
|
|
"busy_timeout=5000",
|
|
// Set cache size to 20MB
|
|
"cache_size=-20000",
|
|
// Enable foreign key checks
|
|
"foreign_keys=ON",
|
|
// Enable incremental vacuuming
|
|
"auto_vacuum=INCREMENTAL",
|
|
// Set temp store to memory
|
|
"temp_store=MEMORY",
|
|
// Set mmap size to 2GB
|
|
"mmap_size=2147483648",
|
|
// Set page size to 8192
|
|
"page_size=8192",
|
|
}
|
|
|
|
_ = os.MkdirAll(filepath.Dir(filename), 0o755)
|
|
|
|
write, err := sql.Open(sqliteDriver, fmt.Sprintf("file:%s", filename))
|
|
if err != nil {
|
|
return nil, nil, fmt.Errorf("failed to open database: %w", err)
|
|
}
|
|
|
|
write.SetMaxOpenConns(1)
|
|
write.SetConnMaxIdleTime(time.Minute)
|
|
|
|
for _, pragma := range pragmas {
|
|
if _, err := write.ExecContext(ctx, fmt.Sprintf("PRAGMA %s", pragma)); err != nil {
|
|
return nil, nil, fmt.Errorf("failed to set pragma %s: %w", pragma, err)
|
|
}
|
|
}
|
|
|
|
read, err := sql.Open(sqliteDriver, fmt.Sprintf("file:%s?mode=ro", filename))
|
|
if err != nil {
|
|
return nil, nil, fmt.Errorf("failed to open database: %w", err)
|
|
}
|
|
|
|
read.SetMaxOpenConns(100)
|
|
read.SetConnMaxIdleTime(time.Minute)
|
|
|
|
queries := sqlc.New(read, write)
|
|
|
|
return queries, func() {
|
|
if err := read.Close(); err != nil {
|
|
slog.Error("failed to close read connection", "error", err)
|
|
}
|
|
|
|
if err := write.Close(); err != nil {
|
|
slog.Error("failed to close write connection", "error", err)
|
|
}
|
|
}, nil
|
|
}
|
|
|
|
func TestDB(t *testing.T, dir string) *sqlc.Queries {
|
|
queries, cleanup, err := DB(t.Context(), filepath.Join(dir, "data.db"))
|
|
require.NoError(t, err)
|
|
t.Cleanup(cleanup)
|
|
|
|
return queries
|
|
}
|
|
|
|
func GenerateID(prefix string) string {
|
|
return strings.ToLower(prefix) + randomstring(12)
|
|
}
|
|
|
|
const base32alphabet = "0123456789ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz"
|
|
|
|
func randomstring(l int) string {
|
|
rand.Text()
|
|
|
|
src := make([]byte, l)
|
|
_, _ = rand.Read(src)
|
|
|
|
for i := range src {
|
|
src[i] = base32alphabet[int(src[i])%len(base32alphabet)]
|
|
}
|
|
|
|
return string(src)
|
|
}
|