#!/opt/saltstack/salt/bin/python3

# Copyright Security Onion Solutions LLC and/or licensed to Security Onion Solutions LLC under one
# or more contributor license agreements. Licensed under the Elastic License 2.0 as shown at
# https://securityonion.net/license; you may not use this file except in compliance with the
# Elastic License 2.0.

"""
so-push-drainer
===============

Scheduled drainer for the active-push feature. Runs on the manager every
drain_interval seconds (default 15) via a salt schedule in salt/schedule.sls.

For each intent file under /opt/so/state/push_pending/*.json whose last_touch
is older than debounce_seconds, this script:
  * concatenates the actions lists from every ready intent
  * dedupes by (state or __highstate__, tgt, tgt_type)
  * dispatches a single `salt-run state.orchestrate orch.push_batch --async`
    with the deduped actions list passed as pillar kwargs
  * deletes the contributed intent files on successful dispatch

Reactor sls files (push_suricata, push_strelka, push_pillar) write intents
but never dispatch directly; this drainer is the single dispatch point. The
full design is described in the developer plan goofy-marinating-hummingbird.md
(originally kept outside this repository).
"""

import fcntl
import glob
import json
import logging
import logging.handlers
import os
import subprocess
import sys
import time

import salt.client

# Directory scanned for pending intent files (*.json) written by the reactors.
PENDING_DIR = '/opt/so/state/push_pending'
# Exclusive flock() target held by main() while it reads, dispatches and deletes
# intents. The dot-prefixed name keeps it out of the '*.json' glob.
# NOTE(review): presumably the reactors take the same lock when writing -- confirm.
LOCK_FILE = os.path.join(PENDING_DIR, '.lock')
# Rotating log file configured by _make_logger().
LOG_FILE = '/opt/so/log/salt/so-push-drainer.log'

# Dedupe-key stand-in for the state name of highstate actions (see _dedupe_actions).
HIGHSTATE_SENTINEL = '__highstate__'


def _make_logger():
    """Return the 'so-push-drainer' logger at INFO level.

    The first call attaches a RotatingFileHandler writing to LOG_FILE (5 MiB
    per file, 3 backups) and creates the log directory if it is missing.
    Later calls find a handler already attached and reuse the logger as-is.
    """
    drainer_log = logging.getLogger('so-push-drainer')
    drainer_log.setLevel(logging.INFO)
    if drainer_log.handlers:
        # Already configured by an earlier call; avoid duplicate handlers.
        return drainer_log

    os.makedirs(os.path.dirname(LOG_FILE), exist_ok=True)
    file_handler = logging.handlers.RotatingFileHandler(
        LOG_FILE,
        maxBytes=5 * 1024 * 1024,
        backupCount=3,
    )
    line_format = logging.Formatter('%(asctime)s | %(levelname)s | %(message)s')
    file_handler.setFormatter(line_format)
    drainer_log.addHandler(file_handler)
    return drainer_log


def _load_push_cfg():
    """Fetch the ``global:push`` pillar subtree through salt's Caller client.

    Returns the subtree when it is a dict and an empty dict otherwise.
    Exceptions raised by the salt client propagate to the caller.
    """
    subtree = salt.client.Caller().cmd('pillar.get', 'global:push', {})
    if isinstance(subtree, dict):
        return subtree
    return {}


def _read_intent(path, log):
    try:
        with open(path, 'r') as f:
            return json.load(f)
    except (IOError, ValueError) as exc:
        log.warning('cannot read intent %s: %s', path, exc)
        return None
    except Exception:
        log.exception('unexpected error reading %s', path)
        return None


def _dedupe_actions(actions):
    seen = set()
    deduped = []
    for action in actions:
        if not isinstance(action, dict):
            continue
        state_key = HIGHSTATE_SENTINEL if action.get('highstate') else action.get('state')
        tgt = action.get('tgt')
        tgt_type = action.get('tgt_type', 'compound')
        if not state_key or not tgt:
            continue
        key = (state_key, tgt, tgt_type)
        if key in seen:
            continue
        seen.add(key)
        deduped.append(action)
    return deduped


def _dispatch(actions, log):
    pillar_arg = json.dumps({'actions': actions})
    cmd = [
        'salt-run',
        'state.orchestrate',
        'orch.push_batch',
        'pillar={}'.format(pillar_arg),
        '--async',
    ]
    log.info('dispatching: %s', ' '.join(cmd[:3]) + ' pillar=<{} actions>'.format(len(actions)))
    try:
        result = subprocess.run(
            cmd, check=True, capture_output=True, text=True, timeout=60,
        )
    except subprocess.CalledProcessError as exc:
        log.error('dispatch failed (rc=%s): stdout=%s stderr=%s',
                  exc.returncode, exc.stdout, exc.stderr)
        return False
    except subprocess.TimeoutExpired:
        log.error('dispatch timed out after 60s')
        return False
    except Exception:
        log.exception('dispatch raised')
        return False
    log.info('dispatch accepted: %s', (result.stdout or '').strip())
    return True


def _unlink_quietly(path):
    """Best-effort os.unlink(); an OSError (already gone, permissions) is ignored."""
    try:
        os.unlink(path)
    except OSError:
        pass


def _as_float(value, default):
    """Return float(value), or *default* when value is not a number or numeric string."""
    try:
        return float(value)
    except (TypeError, ValueError):
        return default


def _list_field(intent, key, path, log):
    """Return intent[key] when it is a list, else [] (warning if a truthy non-list)."""
    value = intent.get(key)
    if isinstance(value, list):
        return value
    if value:
        log.warning('intent %s: ignoring non-list %s field of type %s',
                    path, key, type(value).__name__)
    return []


def _drain_locked(debounce_seconds, log):
    """Run one drain pass over PENDING_DIR. The caller must already hold LOCK_FILE.

    Returns 0 when there was nothing to do or the batch was dispatched. Returns
    1 when dispatch failed; the intent files are then left for the next pass.
    """
    # glob's '*' does not match dot-files, so the '.lock' file is never listed.
    intent_files = sorted(glob.glob(os.path.join(PENDING_DIR, '*.json')))
    if not intent_files:
        return 0

    now = time.time()
    ready = []
    skipped = 0
    for path in intent_files:
        intent = _read_intent(path, log)
        if not isinstance(intent, dict):
            # Unreadable or non-object intents can never be drained; drop them
            # so they are not re-read on every pass.
            _unlink_quietly(path)
            continue

        raw_last_touch = intent.get('last_touch', 0)
        last_touch = _as_float(raw_last_touch, None)
        if last_touch is None:
            # Without this, `now - last_touch` raises TypeError and every pass
            # aborts on the same file. Treat it like a missing timestamp.
            log.warning('intent %s has non-numeric last_touch %r; treating as ready',
                        path, raw_last_touch)
            last_touch = 0.0
        if now - last_touch < debounce_seconds:
            skipped += 1
            continue
        ready.append((path, intent))

    if not ready:
        if skipped:
            log.debug('no ready intents (%d still in debounce window)', skipped)
        return 0

    combined_actions = []
    all_paths = []
    oldest_first_touch = now
    for path, intent in ready:
        # Non-list fields are ignored rather than crashing .extend().
        combined_actions.extend(_list_field(intent, 'actions', path, log))
        all_paths.extend(_list_field(intent, 'paths', path, log))
        first_touch = _as_float(intent.get('first_touch', now), now)
        if first_touch < oldest_first_touch:
            oldest_first_touch = first_touch

    deduped = _dedupe_actions(combined_actions)
    if not deduped:
        log.warning('%d intent(s) had no usable actions; clearing', len(ready))
        for path, _ in ready:
            _unlink_quietly(path)
        return 0

    log.info(
        'draining %d intent(s): %d action(s) after dedupe (raw=%d), '
        'debounce_duration=%.1fs, paths=%s',
        len(ready), len(deduped), len(combined_actions),
        now - oldest_first_touch, all_paths[:20],
    )

    if not _dispatch(deduped, log):
        log.warning('dispatch failed; leaving intent files in place for retry')
        return 1

    # Only remove the intents once salt-run has accepted the batch.
    for path, _ in ready:
        try:
            os.unlink(path)
        except OSError:
            log.exception('failed to remove drained intent %s', path)
    return 0


def main():
    """Perform one drain pass and return the process exit status.

    Returns 0 when PENDING_DIR is absent, push is disabled, there is nothing
    ready, or the batch was dispatched. Returns 1 when the global:push pillar
    cannot be read or dispatch fails.
    """
    log = _make_logger()

    if not os.path.isdir(PENDING_DIR):
        # Nothing to do; reactors create the dir on first use.
        return 0

    try:
        push = _load_push_cfg()
    except Exception:
        log.exception('failed to read global:push pillar; aborting drain pass')
        return 1

    if not push.get('enabled', True):
        log.debug('push disabled; exiting')
        return 0

    raw_debounce = push.get('debounce_seconds', 30)
    try:
        debounce_seconds = int(raw_debounce)
    except (TypeError, ValueError, OverflowError):
        # A bad pillar value would otherwise crash with an unlogged traceback.
        log.warning('invalid global:push:debounce_seconds %r; using 30s', raw_debounce)
        debounce_seconds = 30

    os.makedirs(PENDING_DIR, exist_ok=True)
    lock_fd = os.open(LOCK_FILE, os.O_CREAT | os.O_RDWR, 0o644)
    try:
        # Blocks until any other holder of the lock releases it.
        fcntl.flock(lock_fd, fcntl.LOCK_EX)
        return _drain_locked(debounce_seconds, log)
    finally:
        try:
            fcntl.flock(lock_fd, fcntl.LOCK_UN)
        finally:
            os.close(lock_fd)


if __name__ == '__main__':
    # main() returns 0 or 1; pass it through as the process exit status.
    sys.exit(main())
