#!/usr/bin/env python3

# Copyright Security Onion Solutions LLC and/or licensed to Security Onion Solutions LLC under one
# or more contributor license agreements. Licensed under the Elastic License 2.0 as shown at
# https://securityonion.net/license; you may not use this file except in compliance with the
# Elastic License 2.0.

"""
so-pillar-import — populate the so_pillar.* schema in so-postgres from the
on-disk Salt pillar tree.

Reads /opt/so/saltstack/local/pillar/, decomposes each .sls file into a
(scope, role|minion_id, pillar_path, data) tuple, and UPSERTs it into
so_pillar.pillar_entry. Idempotent — re-running with no SLS edits produces
no version bumps because the audit trigger only writes a row when data
actually changes.

Bootstrap and mine-driven files are skipped (see EXCLUDE_BASENAMES /
EXCLUDE_PATH_FRAGMENTS below). Files containing Jinja templates ({% or {{) are
also skipped — those stay disk-authoritative and ext_pillar_first: False
means they render before the PG overlay anyway.

All SQL goes through `docker exec so-postgres psql` so no separate DSN
config is required at first-install time. Designed to be called by
salt/postgres/schema_pillar.sls (initial seed) and by salt/manager/tools/
sbin/so-minion (per-minion sync on add/delete).
"""

import argparse
import json
import os
import shlex
import subprocess
import sys
from pathlib import Path

import yaml


# On-disk pillar roots. classify() accepts section files under either root;
# import_all() and import_minion() only read from the local root.
PILLAR_LOCAL_ROOT = Path("/opt/so/saltstack/local/pillar")
PILLAR_DEFAULT_ROOT = Path("/opt/so/saltstack/default/pillar")
# Container, role and database used by docker_psql() for every SQL call.
DOCKER_CONTAINER = "so-postgres"
PG_SUPERUSER = "postgres"
PG_DATABASE = "securityonion"

# Files that must NEVER move to Postgres. These are read by Salt before
# Postgres is reachable, or contain renderer-time computed values (mine, etc.).
# classify() matches these against the file's basename only.
EXCLUDE_BASENAMES = {
    "secrets.sls",
    "auth.sls",          # postgres/auth.sls bootstrap
    "top.sls",
}
# Path fragments to skip — classify() drops any file whose full path contains
# one of these substrings. These are renderer-time computed pillars
# (Salt mine, file_exists guards, etc.) that have to stay on disk.
EXCLUDE_PATH_FRAGMENTS = (
    "/elasticsearch/nodes.sls",
    "/redis/nodes.sls",
    "/kafka/nodes.sls",
    "/hypervisor/nodes.sls",
    "/logstash/nodes.sls",
    "/node_data/ips.sls",
    "/postgres/auth.sls",
    "/elasticsearch/auth.sls",
    "/kibana/secrets.sls",
)


def log(level, msg):
    """Write one "[LEVEL] message" line to stderr."""
    line = "[{}] {}".format(level, msg)
    print(line, file=sys.stderr)


def is_jinja_templated(content_bytes):
    """Return True when the raw bytes contain a Jinja opening marker
    ("{%" or "{{")."""
    jinja_markers = (b"{%", b"{{")
    return any(marker in content_bytes for marker in jinja_markers)


def classify(path):
    """Map a pillar file to (scope, role_name, minion_id, pillar_path).

    Returns None when the file must not be imported: its basename is in
    EXCLUDE_BASENAMES, its full path contains one of EXCLUDE_PATH_FRAGMENTS,
    or it is neither a per-minion file nor a soc_*/adv_* file sitting
    directly inside a section directory of a pillar root.

    role_name is always None for now — the importer leaves role membership
    to the so_pillar.minion trigger and the salt/auth reactor.
    """
    if path.name in EXCLUDE_BASENAMES:
        return None
    full_path = str(path)
    if any(fragment in full_path for fragment in EXCLUDE_PATH_FRAGMENTS):
        return None

    folder = path.parent
    name = path.stem  # filename without .sls

    # .../minions/<id>.sls  or  .../minions/adv_<id>.sls
    if folder.name == "minions":
        # The pillar_path keeps the adv_ prefix; only the minion id drops it.
        minion = name[4:] if name.startswith("adv_") else name
        return ("minion", None, minion, f"minions.{name}")

    # <pillar root>/<section>/<file>.sls
    if folder.parent not in (PILLAR_LOCAL_ROOT, PILLAR_DEFAULT_ROOT):
        return None

    # Only soc_<section>.sls and adv_<section>.sls are SOC-managed pillar
    # surfaces. Other files (e.g. nodes.sls, auth.sls, *.token) are either
    # covered by EXCLUDE_PATH_FRAGMENTS or are bootstrap surfaces left alone
    # for now.
    if name.startswith(("soc_", "adv_")):
        return ("global", None, None, f"{folder.name}.{name}")
    return None


def parse_yaml_file(path):
    """Load a pillar .sls file as a dict.

    Returns:
        {} for an empty or whitespace-only file, or when the YAML document
        is empty; None when the file contains Jinja markers (the caller
        skips it); the parsed mapping when the document is a dict; otherwise
        the parsed value wrapped as {"_raw": value}.
    """
    with open(path, "rb") as handle:
        raw = handle.read()

    if not raw.strip():
        return {}
    if is_jinja_templated(raw):
        return None

    parsed = yaml.safe_load(raw)
    if parsed is None:
        return {}
    # Non-mapping documents (lists, scalars) are wrapped in a dict.
    return parsed if isinstance(parsed, dict) else {"_raw": parsed}


def derive_node_type(minion_id):
    """Return the text after the last underscore of minion_id, or None when
    the id contains no underscore.

    Conventional Security Onion minion ids are <host>_<role>, so that
    trailing token is taken as the canonical role suffix.
    """
    _host, separator, suffix = minion_id.rpartition("_")
    return suffix if separator else None


def docker_psql(sql, *, db=PG_DATABASE, user=PG_SUPERUSER, on_error_stop=True, capture=True):
    """Run sql through `docker exec -i <container> psql`, feeding it on stdin.

    Args:
        sql: SQL text to execute; it is UTF-8 encoded and piped to psql.
        db: database name passed to psql -d.
        user: role name passed to psql -U.
        on_error_stop: when true, pass -v ON_ERROR_STOP=1 to psql.
        capture: when true, capture psql's stdout/stderr; when false the
            child inherits this process's streams and nothing is returned.

    Returns:
        psql's stdout decoded as str (tuples-only, unaligned, quiet output
        because of -tA -q), or "" when capture is false.

    Raises:
        RuntimeError: if the docker/psql process exits non-zero. Captured
            stderr, if any, is forwarded to this process's stderr first.
    """
    args = [
        "docker", "exec", "-i", DOCKER_CONTAINER,
        "psql", "-U", user, "-d", db, "-tA", "-q",
    ]
    if on_error_stop:
        args += ["-v", "ON_ERROR_STOP=1"]
    proc = subprocess.run(
        args, input=sql.encode(),
        capture_output=capture, check=False,
    )
    if proc.returncode != 0:
        # stderr is None when capture is false; the child already wrote its
        # diagnostics straight to our stderr in that case.
        if proc.stderr:
            sys.stderr.write(proc.stderr.decode(errors="replace"))
        raise RuntimeError(f"docker exec psql failed (rc={proc.returncode})")
    # stdout is None when capture is false; there is nothing to return.
    if proc.stdout is None:
        return ""
    return proc.stdout.decode(errors="replace")


def upsert_minion(minion_id, node_type):
    """Insert the so_pillar.minion row for minion_id, or overwrite its
    node_type when the row already exists. A falsy node_type is stored as
    SQL NULL."""
    node_type_sql = pg_str(node_type) if node_type else "NULL"
    statement = " ".join([
        "INSERT INTO so_pillar.minion (minion_id, node_type)",
        f"VALUES ({pg_str(minion_id)}, {node_type_sql})",
        "ON CONFLICT (minion_id) DO UPDATE SET node_type = EXCLUDED.node_type;",
    ])
    docker_psql(statement)


def delete_minion(minion_id):
    """Delete the so_pillar.minion row for minion_id.

    CASCADE removes the dependent pillar_entry + role_member rows.
    """
    id_literal = pg_str(minion_id)
    docker_psql(f"DELETE FROM so_pillar.minion WHERE minion_id = {id_literal};")


def upsert_pillar_entry(scope, role_name, minion_id, pillar_path, data, reason):
    """Upsert one so_pillar.pillar_entry row inside a single transaction.

    The ON CONFLICT target is chosen per scope so it matches that scope's
    partial unique index. The change reason is written both to the row and
    to the transaction-local so_pillar.change_reason setting. The audit
    trigger handles history; the versioning trigger bumps version only when
    data changes.

    Raises:
        ValueError: if scope is not 'global', 'role' or 'minion'.
    """
    payload = json.dumps(data)
    role_sql = pg_str(role_name) if role_name else "NULL"
    minion_sql = pg_str(minion_id) if minion_id else "NULL"
    reason_sql = pg_str(reason)

    conflict_targets = (
        ("global", "(pillar_path) WHERE scope='global'"),
        ("role", "(role_name, pillar_path) WHERE scope='role'"),
        ("minion", "(minion_id, pillar_path) WHERE scope='minion'"),
    )
    for known_scope, target in conflict_targets:
        if scope == known_scope:
            conflict = target
            break
    else:
        raise ValueError(f"unknown scope {scope!r}")

    values = ", ".join([
        pg_str(scope),
        role_sql,
        minion_sql,
        pg_str(pillar_path),
        pg_jsonb(payload),
        reason_sql,
    ])
    insert = (
        "INSERT INTO so_pillar.pillar_entry "
        "(scope, role_name, minion_id, pillar_path, data, change_reason) "
        f"VALUES ({values}) "
        f"ON CONFLICT {conflict} DO UPDATE "
        "SET data = EXCLUDED.data, change_reason = EXCLUDED.change_reason;"
    )
    statements = [
        "BEGIN;",
        f"SELECT set_config('so_pillar.change_reason', {reason_sql}, true);",
        insert,
        "COMMIT;",
    ]
    docker_psql("\n".join(statements) + "\n")


def pg_str(s):
    """Render a value as a single-quoted SQL string literal.

    None becomes the bare keyword NULL. Anything else is converted with
    str() and embedded single quotes are doubled, which is the standard SQL
    escape. Pillar content has already been validated as YAML, so no further
    sanitising is attempted.
    """
    if s is None:
        return "NULL"
    escaped = str(s).replace("'", "''")
    return f"'{escaped}'"


def pg_jsonb(json_str):
    """Return json_str as a quoted SQL literal cast to jsonb."""
    return f"{pg_str(json_str)}::jsonb"


def walk_pillar_root(root, paths):
    """Append every regular *.sls file found recursively under root to the
    paths list. Does nothing when root is not a directory."""
    if root.is_dir():
        paths.extend(
            candidate for candidate in root.rglob("*.sls") if candidate.is_file()
        )


def import_minion(minion_id, node_type, dry_run, reason):
    """Re-import the pillar files of a single minion.

    Upserts the so_pillar.minion row, then imports minions/<minion_id>.sls
    and minions/adv_<minion_id>.sls from the local pillar root when they
    exist. With dry_run set, nothing is written to Postgres — including the
    minion row — and the intended actions are only logged.

    Args:
        minion_id: Salt minion id; must be non-empty.
        node_type: value stored in so_pillar.minion.node_type (may be None).
        dry_run: when true, log what would happen instead of writing.
        reason: change_reason passed through to upsert_pillar_entry().

    Raises:
        ValueError: if minion_id is empty.
    """
    if not minion_id:
        raise ValueError("minion_id required for --scope minion")

    # A dry run must not touch the database, so the minion-row upsert is
    # gated the same way import_all() gates it.
    if dry_run:
        log("DRY", f"would upsert minion row {minion_id} (node_type={node_type})")
    else:
        upsert_minion(minion_id, node_type)
        log("INFO", f"Upserted minion row {minion_id} (node_type={node_type})")

    targets = [
        PILLAR_LOCAL_ROOT / "minions" / f"{minion_id}.sls",
        PILLAR_LOCAL_ROOT / "minions" / f"adv_{minion_id}.sls",
    ]
    for path in targets:
        if not path.exists():
            log("INFO", f"  (no file at {path})")
            continue
        klass = classify(path)
        if not klass:
            log("INFO", f"  skip {path} (excluded)")
            continue
        scope, role, mid, pillar_path = klass
        data = parse_yaml_file(path)
        if data is None:
            # parse_yaml_file() returns None only for Jinja-templated files.
            log("WARN", f"  skip {path} (Jinja-templated; stays disk-only)")
            continue
        if dry_run:
            log("DRY", f"  would upsert {scope}/{pillar_path} = {len(json.dumps(data))} bytes")
            continue
        upsert_pillar_entry(scope, role, mid, pillar_path, data, reason)
        log("INFO", f"  imported {scope}/{pillar_path}")


def import_all(dry_run, reason):
    """Import every eligible file under the local pillar root.

    Files rejected by classify() or detected as Jinja-templated are counted
    as skipped. A so_pillar.minion row is upserted the first time each
    minion id is encountered. With dry_run set, no SQL is issued and the
    intended upserts are only logged; they still count as imported in the
    final summary line.
    """
    candidates = []
    walk_pillar_root(PILLAR_LOCAL_ROOT, candidates)

    n_imported = 0
    n_skipped = 0
    known_minions = set()

    for sls in sorted(candidates):
        verdict = classify(sls)
        if not verdict:
            n_skipped += 1
            continue

        scope, role, minion_id, pillar_path = verdict
        payload = parse_yaml_file(sls)
        if payload is None:
            log("WARN", f"skip {sls} (Jinja-templated; stays disk-only)")
            n_skipped += 1
            continue

        # Make sure the parent minion row exists before its pillar rows;
        # once per minion id is enough.
        if scope == "minion" and minion_id not in known_minions:
            node_type = derive_node_type(minion_id)
            if not dry_run:
                upsert_minion(minion_id, node_type)
            known_minions.add(minion_id)

        if not dry_run:
            upsert_pillar_entry(scope, role, minion_id, pillar_path, payload, reason)
            log("INFO", f"imported {scope}/{pillar_path}")
        else:
            log("DRY", f"would upsert {scope}/{pillar_path} ({len(json.dumps(payload))} bytes)")
        n_imported += 1

    log("INFO", f"done: {n_imported} imported, {n_skipped} skipped")


def main():
    """Parse the command line and run the requested import or delete.

    Returns:
        0 on success, 1 when the operation raised an exception (the message
        is logged), 2 when the requested --scope is not implemented yet.
        Usage errors are reported through argparse, which exits with
        status 2.
    """
    ap = argparse.ArgumentParser(description=__doc__)
    ap.add_argument("--scope", choices=("global", "role", "minion", "all"), default="all")
    ap.add_argument("--minion-id")
    ap.add_argument("--node-type", help="override node_type for --scope minion (default: derived from minion_id)")
    ap.add_argument("--delete", action="store_true",
                    help="With --scope minion, remove the minion row (and its pillar rows via CASCADE)")
    ap.add_argument("--dry-run", action="store_true")
    ap.add_argument("--diff", action="store_true",
                    help="(reserved) print structural diffs vs current DB content")
    ap.add_argument("--yes", action="store_true",
                    help="Skip confirmation prompts (currently unused; reserved)")
    ap.add_argument("--reason", default="so-pillar-import",
                    help="change_reason recorded in pillar_entry_history")
    args = ap.parse_args()

    # --delete only has meaning for a single minion. Because --scope defaults
    # to "all", silently ignoring it would turn an intended delete into a
    # full import, so reject the combination up front.
    if args.delete and args.scope != "minion":
        ap.error("--delete is only valid with --scope minion")

    try:
        if args.scope == "minion":
            if not args.minion_id:
                # ap.error() raises SystemExit, which the handler below
                # (Exception only) deliberately does not catch.
                ap.error("--minion-id required when --scope minion")
            if args.delete:
                if args.dry_run:
                    log("DRY", f"would delete {args.minion_id}")
                else:
                    delete_minion(args.minion_id)
                    log("INFO", f"deleted {args.minion_id}")
            else:
                node_type = args.node_type or derive_node_type(args.minion_id)
                import_minion(args.minion_id, node_type, args.dry_run, args.reason)
        elif args.scope == "all":
            import_all(args.dry_run, args.reason)
        else:
            log("ERROR", f"--scope {args.scope} not yet implemented; use --scope all or --scope minion")
            return 2
    except Exception as e:
        # Top-level boundary: report the failure and turn it into exit code 1.
        log("ERROR", str(e))
        return 1

    return 0


# Script entry point: main()'s return value becomes the process exit status.
if __name__ == "__main__":
    raise SystemExit(main())
