#!/usr/bin/env python3

# Copyright Security Onion Solutions LLC and/or licensed to Security Onion Solutions LLC under one
# or more contributor license agreements. Licensed under the Elastic License 2.0 as shown at
# https://securityonion.net/license; you may not use this file except in compliance with the
# Elastic License 2.0.

# Imports detection overrides (e.g. from so-detections-backup) into the so-detection
# index. Reads <publicId>.<ext> files (NDJSON, one override per line) from a source
# directory, looks up the matching detection by publicId+engine, validates each
# override against the same rules SOC enforces, dedupes against existing overrides
# (operational fields only), and appends new ones.

import argparse
import ipaddress
import json
import os
import re
import sys
from datetime import datetime

import requests
from requests.auth import HTTPBasicAuth
import urllib3

# make_session() turns certificate verification off (session.verify = False),
# so silence the InsecureRequestWarning urllib3 would otherwise emit for those
# unverified HTTPS requests.
urllib3.disable_warnings(urllib3.exceptions.InsecureRequestWarning)

# Default value of the --index command-line option.
DEFAULT_INDEX = "so-detection"
# File that main() hands to make_session(); it is scanned for a 'user =' line.
AUTH_FILE = "/opt/so/conf/elasticsearch/curl.config"
# Base URL that every Elasticsearch request path in this script is appended to.
ES_URL = "https://localhost:9200"

# Engines we know how to handle and the file extension the backup script writes.
# The keys double as the allowed values of the --engine option.
ENGINES = {
    "suricata": "txt",
}

# Standard Suricata variables that ship with Security Onion. Anything else
# referenced in an override is "custom" and the user needs to make sure it
# exists in SOC Config before the override will function.
BUILTIN_SURICATA_VARS = {
    "$HOME_NET", "$EXTERNAL_NET",
    "$HTTP_SERVERS", "$DNS_SERVERS", "$SQL_SERVERS", "$SMTP_SERVERS",
    "$TELNET_SERVERS", "$AIM_SERVERS", "$DC_SERVERS", "$MODBUS_SERVER",
    "$MODBUS_CLIENT", "$ENIP_CLIENT", "$ENIP_SERVER",
    "$HTTP_PORTS", "$SHELLCODE_PORTS", "$ORACLE_PORTS", "$SSH_PORTS",
    "$FTP_PORTS", "$FILE_DATA_PORTS",
}

# Matches a variable reference: '$' followed by an upper-case letter or
# underscore, then any run of upper-case letters, digits or underscores.
VAR_PATTERN = re.compile(r"\$[A-Z_][A-Z0-9_]*")

# Canonical valid values, per securityonion-soc/model/detection.go.
# Note that the allowed 'track' values differ between suppress and threshold.
SURICATA_OVERRIDE_TYPES = {"suppress", "threshold", "modify"}
SUPPRESS_TRACKS = {"by_src", "by_dst", "by_either"}
THRESHOLD_TRACKS = {"by_src", "by_dst", "by_both"}
THRESHOLD_TYPES = {"limit", "threshold", "both"}

# Text printed by confirm_proceed() before any import work starts. The
# backslash after the opening quotes suppresses the leading newline.
STALE_WARNING = """\
WARNING: so-detections-backup does not remove backup files when overrides are
deleted via the Security Onion web UI. As a result, files in the source
directory may represent overrides that were intentionally deleted and should
NOT be re-imported.

Before continuing, verify that the source directory reflects the overrides you
actually want imported. Remove any files corresponding to overrides you previously deleted.
"""


def make_session(auth_file):
    """Build an authenticated requests session from a curl-style config file.

    The file is scanned top to bottom for the first line starting with
    'user ='. Everything after the first '=' is stripped of surrounding
    whitespace and of every double-quote character, then split at the first
    ':' into user name and password.

    Returns a requests.Session using HTTP basic auth, a JSON Content-Type
    header, and certificate verification disabled.

    Raises RuntimeError when no 'user =' line exists; errors from open()
    propagate unchanged.
    """
    credentials = None
    with open(auth_file, "r") as handle:
        for raw_line in handle:
            if not raw_line.startswith("user ="):
                continue
            # Only the first matching line is used.
            credentials = raw_line.split("=", 1)[1].strip().replace('"', "")
            break

    if credentials is None:
        raise RuntimeError(f"Could not find 'user =' line in {auth_file}")

    # partition() splits on the first ':' only, so the password may contain colons.
    username, _, secret = credentials.partition(":")

    session = requests.Session()
    session.auth = HTTPBasicAuth(username, secret)
    session.headers.update({"Content-Type": "application/json"})
    session.verify = False
    return session


def find_detection(session, index, public_id, engine):
    """Look up the detection document for a publicId/engine pair.

    Sends a bool/must term query on so_detection.publicId and
    so_detection.engine to <ES_URL>/<index>/_search. Two hits are requested
    so that an unexpected duplicate can be reported.

    Returns (doc_id, doc_index, existing_overrides) for the first hit, where
    existing_overrides is a list (empty when the document has none), or
    (None, None, None) when nothing matched.

    Raises whatever session.get() / raise_for_status() raise on failure.
    """
    must_clauses = [
        {"term": {"so_detection.publicId": public_id}},
        {"term": {"so_detection.engine": engine}},
    ]
    search_body = {"query": {"bool": {"must": must_clauses}}, "size": 2}

    response = session.get(f"{ES_URL}/{index}/_search", json=search_body)
    response.raise_for_status()

    matches = response.json().get("hits", {}).get("hits", [])
    if not matches:
        return None, None, None

    match_count = len(matches)
    if match_count > 1:
        # Not expected: publicId should be unique within an engine. Report it
        # and carry on with the first hit.
        print(f"  WARN: {match_count} detections matched publicId={public_id} engine={engine}; using first")

    first = matches[0]
    detection = first["_source"].get("so_detection", {})
    # A missing, null or empty 'overrides' field all become an empty list.
    current_overrides = detection.get("overrides") or []
    return first["_id"], first["_index"], current_overrides


def update_overrides(session, doc_index, doc_id, overrides):
    """Write the given overrides list onto one detection document.

    POSTs a partial-document update that sets so_detection.overrides to
    `overrides` at <ES_URL>/<doc_index>/_update/<doc_id>.

    Returns the decoded JSON response. Raises whatever session.post() /
    raise_for_status() raise on failure.
    """
    endpoint = f"{ES_URL}/{doc_index}/_update/{doc_id}"
    payload = {"doc": {"so_detection": {"overrides": overrides}}}
    response = session.post(endpoint, json=payload)
    response.raise_for_status()
    return response.json()


def dedupe_key(override):
    """Return a tuple identifying an override by its operational fields.

    Per Override.Equal() in detection.go, only the fields that define what
    the override does take part; timestamps, notes and isEnabled are left
    out so that re-importing the same override is recognised as a duplicate.

    The tuple starts with the override type followed by that type's fields
    (missing fields appear as None). Returns None for any other type.
    """
    kind = override.get("type")
    if kind == "suppress":
        fields = ("track", "ip")
    elif kind == "threshold":
        fields = ("thresholdType", "track", "count", "seconds")
    elif kind == "modify":
        fields = ("regex", "value")
    else:
        return None
    return (kind, *(override.get(name) for name in fields))


def _validate_suricata_ip(ip):
    """Validate the 'ip' value of a suppress override.

    Accepted forms:
      * a string starting with '$' (a variable reference; not checked further)
      * a bracketed, comma-separated list such as "[10.0.0.1, 10.1.0.0/16]"
      * a single IP address or CIDR network

    Returns None when the value is acceptable, otherwise an error string.
    Never raises for malformed input.
    """
    if not ip:
        return "ip cannot be empty"
    # The value comes straight from parsed JSON, so it may be a number, list
    # or object. Reject those here instead of failing on str methods below.
    if not isinstance(ip, str):
        return f"ip must be a string, got {type(ip).__name__}"
    if ip.startswith("$"):
        return None
    if ip.startswith("[") and ip.endswith("]"):
        for part in ip[1:-1].split(","):
            err = _validate_single_ip(part.strip())
            if err:
                return f"invalid IP in list: {err}"
        return None
    return _validate_single_ip(ip)


def _validate_single_ip(ip):
    """Validate one IP address or CIDR network given as a string.

    Values containing '/' are parsed as networks with strict=False, so host
    bits may be set (e.g. "10.0.0.5/8"); everything else is parsed as a
    single address. Returns None when valid, otherwise an error string.
    """
    # ipaddress also accepts ints and bytes; only textual addresses are
    # meaningful here, and the "/" test below requires a string.
    if not isinstance(ip, str):
        return f"invalid IP/CIDR {ip!r}"
    try:
        if "/" in ip:
            ipaddress.ip_network(ip, strict=False)
        else:
            ipaddress.ip_address(ip)
    except ValueError:
        return f"invalid IP/CIDR {ip!r}"
    return None


def _is_positive_int(value):
    """Return True for a genuine integer greater than zero."""
    # bool is a subclass of int, so JSON true/false would otherwise be
    # accepted as 1/0.
    return isinstance(value, int) and not isinstance(value, bool) and value > 0


def _is_allowed(value, allowed):
    """Return True when value is a string contained in the allowed set."""
    # The isinstance test must come first: checking an unhashable JSON value
    # (array/object) for set membership raises TypeError.
    return isinstance(value, str) and value in allowed


def validate_override(override, engine):
    """Check one parsed override against the rules SOC enforces.

    Intended to mirror Override.Validate() from
    securityonion-soc/model/detection.go. A field counts as present when its
    value is not None. Each override type requires its own fields and rejects
    fields that belong to the other types.

    `override` is whatever json.loads() produced for one line, so it is not
    assumed to be a dict or to hold values of the expected types.
    `engine` is accepted for interface compatibility and is not consulted;
    only the Suricata override types are known.

    Returns None on success, an error string otherwise. Never raises for
    malformed input.
    """
    if not isinstance(override, dict):
        return f"override must be a JSON object, got {type(override).__name__}"

    t = override.get("type")
    if not t:
        return "override type is required"
    if not _is_allowed(t, SURICATA_OVERRIDE_TYPES):
        return f"invalid type {t!r}: must be one of {sorted(SURICATA_OVERRIDE_TYPES)}"

    has = {k: override.get(k) is not None for k in
           ("regex", "value", "thresholdType", "track", "ip", "count", "seconds", "customFilter")}

    if t == "suppress":
        if not has["ip"] or not has["track"]:
            return "suppress requires 'ip' and 'track'"
        if any(has[k] for k in ("regex", "value", "thresholdType", "count", "seconds", "customFilter")):
            return "suppress has unnecessary fields"
        if not _is_allowed(override["track"], SUPPRESS_TRACKS):
            return f"invalid track {override['track']!r}: must be one of {sorted(SUPPRESS_TRACKS)}"
        # Checked here as well so this function does not depend on the IP
        # helper tolerating non-string input.
        if not isinstance(override["ip"], str):
            return f"ip must be a string, got {type(override['ip']).__name__}"
        return _validate_suricata_ip(override["ip"])

    if t == "threshold":
        if not all(has[k] for k in ("thresholdType", "track", "count", "seconds")):
            return "threshold requires 'thresholdType', 'track', 'count', 'seconds'"
        if any(has[k] for k in ("regex", "value", "customFilter")):
            return "threshold has unnecessary fields"
        if not _is_allowed(override["thresholdType"], THRESHOLD_TYPES):
            return f"invalid thresholdType {override['thresholdType']!r}: must be one of {sorted(THRESHOLD_TYPES)}"
        if not _is_allowed(override["track"], THRESHOLD_TRACKS):
            return f"invalid track {override['track']!r}: must be one of {sorted(THRESHOLD_TRACKS)}"
        if not _is_positive_int(override["count"]):
            return f"count must be a positive integer, got {override['count']!r}"
        if not _is_positive_int(override["seconds"]):
            return f"seconds must be a positive integer, got {override['seconds']!r}"
        return None

    # Only "modify" remains: the type was checked against the allowed set above.
    if not has["regex"] or not has["value"]:
        return "modify requires 'regex' and 'value'"
    if any(has[k] for k in ("thresholdType", "track", "count", "seconds", "customFilter")):
        return "modify has unnecessary fields"
    # re.compile() raises TypeError, not re.error, for a non-string pattern.
    if not isinstance(override["regex"], str):
        return f"invalid regex: expected a string, got {type(override['regex']).__name__}"
    # NOTE(review): this checks the pattern with Python's re module. The
    # referenced detection.go is Go code, whose regexp syntax differs, so a
    # pattern accepted here may still be rejected by SOC - confirm.
    try:
        re.compile(override["regex"])
    except re.error as e:
        return f"invalid regex: {e}"
    return None


def parse_overrides_file(path):
    """Parse a file written by so-detections-backup.py: NDJSON, one override
    per line.

    Blank and whitespace-only lines are skipped but still counted, so the
    reported line numbers match the file.

    Returns a list of (parsed_value, line_number) tuples with 1-based line
    numbers. The parsed value is whatever json.loads() yields for the line;
    it is not checked to be an object here.

    Raises OSError when the file cannot be read, json.JSONDecodeError on the
    first line that is not valid JSON, and UnicodeDecodeError when the file
    is not valid UTF-8.
    """
    overrides = []
    # Decode as UTF-8 explicitly: without it open() falls back to the locale
    # encoding, which can mangle or reject non-ASCII text on some hosts.
    with open(path, "r", encoding="utf-8") as f:
        for i, line in enumerate(f, start=1):
            line = line.strip()
            if not line:
                continue
            overrides.append((json.loads(line), i))
    return overrides


def describe(override):
    """Return a one-line, human-readable summary of an override.

    Only fields relevant to the override's type are shown; for 'modify' that
    is the regex alone, rendered with repr(). Missing fields print as None.
    Returns None when the type is not suppress, threshold or modify.
    """
    kind = override.get("type")

    if kind == "modify":
        pattern = override.get("regex")
        return f"type=modify regex={pattern!r}"

    if kind == "suppress":
        track = override.get("track")
        address = override.get("ip")
        return f"type=suppress track={track} ip={address}"

    if kind == "threshold":
        pieces = [
            "type=threshold",
            f"track={override.get('track')}",
            f"thresholdType={override.get('thresholdType')}",
            f"count={override.get('count')}",
            f"seconds={override.get('seconds')}",
        ]
        return " ".join(pieces)

    return None


def collect_custom_vars(override):
    """Return the set of non-built-in variable names an override mentions.

    Every string value in the override is scanned with VAR_PATTERN; matches
    listed in BUILTIN_SURICATA_VARS are dropped. Non-string values are
    ignored.
    """
    return {
        name
        for text in override.values()
        if isinstance(text, str)
        for name in VAR_PATTERN.findall(text)
        if name not in BUILTIN_SURICATA_VARS
    }


def parse_args():
    """Define the command-line interface and parse the process arguments.

    Returns the argparse namespace with source, engine, dry_run,
    no_import_note and index attributes.
    """
    parser = argparse.ArgumentParser(
        description="Import detection overrides into the so-detection index.",
    )

    # (flags, keyword arguments) for each option, in the order shown by --help.
    option_specs = (
        (("--source", "-s"), {
            "required": True,
            "help": "Source directory containing <publicId>.<ext> override files.",
        }),
        (("--engine", "-e"), {
            "default": "suricata",
            "choices": list(ENGINES.keys()),
            "help": "Detection engine (default: suricata).",
        }),
        (("--dry-run", "-n"), {
            "action": "store_true",
            "help": "Print what would happen without writing to Elasticsearch.",
        }),
        (("--no-import-note",), {
            "action": "store_true",
            "help": "Do not prepend '[Imported YYYY-MM-DD] ' to the override note.",
        }),
        (("--index", "-i"), {
            "default": DEFAULT_INDEX,
            "help": f"Elasticsearch index to update (default: {DEFAULT_INDEX}).",
        }),
    )
    for flags, settings in option_specs:
        parser.add_argument(*flags, **settings)

    return parser.parse_args()


def confirm_proceed(args):
    """Show the stale-backup warning and ask for acknowledgement.

    In dry-run mode the warning is printed and True is returned without
    prompting. Otherwise the user must type 'yes' (case-insensitive,
    surrounding whitespace ignored).

    Returns True when the run may continue, False otherwise. A closed or
    exhausted stdin counts as "not acknowledged" rather than an error.
    """
    print(STALE_WARNING)
    if args.dry_run:
        print("(dry-run: no acknowledgement required)\n")
        return True
    try:
        answer = input("Type 'yes' to acknowledge and continue: ").strip().lower()
    except EOFError:
        # input() raises EOFError when stdin is closed or at end of file
        # (for example when run without a terminal). Refuse to proceed
        # instead of ending with a traceback.
        print()
        return False
    print()
    return answer == "yes"


def main():
    """Import override files from the source directory into Elasticsearch.

    Steps:
      1. Parse arguments and list the <publicId>.<ext> files for the engine.
      2. Show the stale-backup warning and require acknowledgement
         (skipped prompt in dry-run mode).
      3. For each file: parse it, find the matching detection, validate each
         override, skip those whose operational fields match an existing
         override, and append the rest. Unless --no-import-note is given,
         the note of each added override is prefixed with the import date.
      4. Write the merged list back (or only report it in dry-run mode).
      5. Print a summary and any custom Suricata variables that were seen.

    A failure on one file is reported and counted, and processing continues
    with the next file.

    Exits with status 0 when no errors and no invalid overrides occurred,
    otherwise 1. Also exits 1 when the source directory is missing, the
    prompt is not acknowledged, or the credentials cannot be loaded, and 0
    when there are no files to process.
    """
    args = parse_args()

    if not os.path.isdir(args.source):
        print(f"ERROR: source directory not found: {args.source}", file=sys.stderr)
        sys.exit(1)

    extension = ENGINES[args.engine]
    files = sorted(f for f in os.listdir(args.source) if f.endswith(f".{extension}"))
    if not files:
        print(f"No *.{extension} files found in {args.source}")
        sys.exit(0)

    if not confirm_proceed(args):
        print("Aborted.")
        sys.exit(1)

    # A missing/unreadable credentials file (OSError) or one without a
    # 'user =' line (RuntimeError) is a setup problem: report it cleanly.
    try:
        session = make_session(AUTH_FILE)
    except (OSError, RuntimeError) as e:
        print(f"ERROR: could not load credentials from {AUTH_FILE}: {e}", file=sys.stderr)
        sys.exit(1)

    today = datetime.now().strftime("%Y-%m-%d")
    note_prefix = "" if args.no_import_note else f"[Imported {today}] "

    counts = {"added": 0, "skipped_dedupe": 0, "skipped_not_found": 0, "invalid": 0, "error": 0}
    custom_vars = set()

    mode = "DRY-RUN" if args.dry_run else "IMPORT"
    print(f"[{mode}] engine={args.engine} source={args.source} index={args.index}\n")

    for filename in files:
        # The file name without its extension is the detection's publicId.
        public_id = os.path.splitext(filename)[0]
        path = os.path.join(args.source, filename)
        print(f"{public_id}:")

        # ValueError covers json.JSONDecodeError and UnicodeDecodeError
        # (both are subclasses), so an undecodable file is reported too.
        try:
            new_overrides = parse_overrides_file(path)
        except (ValueError, OSError) as e:
            print(f"  ERROR: could not parse {filename}: {e}")
            counts["error"] += 1
            continue

        if not new_overrides:
            print("  SKIP: empty file")
            continue

        # RequestException is the base class of HTTPError, ConnectionError
        # and Timeout, so transport failures are handled like HTTP errors.
        # ValueError covers a response body that is not valid JSON.
        try:
            doc_id, doc_index, existing = find_detection(session, args.index, public_id, args.engine)
        except (requests.RequestException, ValueError) as e:
            print(f"  ERROR: search failed: {e}")
            counts["error"] += 1
            continue

        if doc_id is None:
            print(f"  WARN: no detection found for publicId={public_id} engine={args.engine}; skipping")
            counts["skipped_not_found"] += len(new_overrides)
            continue

        existing_keys = {dedupe_key(o) for o in existing}
        merged = list(existing)
        added_this_file = 0

        for override, line_no in new_overrides:
            err = validate_override(override, args.engine)
            if err:
                print(f"  INVALID (line {line_no}): {err}")
                counts["invalid"] += 1
                continue

            # Variables are collected before the duplicate check, so they are
            # reported even for overrides that end up being skipped.
            custom_vars.update(collect_custom_vars(override))
            key = dedupe_key(override)
            if key in existing_keys:
                print(f"  SKIP (line {line_no}): duplicate of existing override [{describe(override)}]")
                counts["skipped_dedupe"] += 1
                continue

            if note_prefix:
                # Copy first so the parsed input is not modified.
                override = dict(override)
                override["note"] = note_prefix + (override.get("note") or "")

            merged.append(override)
            # Remember the key so a repeat later in the same file is skipped.
            existing_keys.add(key)
            added_this_file += 1
            print(f"  ADD (line {line_no}): {describe(override)}")

        if added_this_file == 0:
            continue

        if args.dry_run:
            print(f"  DRY-RUN: would update {doc_index}/{doc_id} "
                  f"({len(existing)} existing → {len(merged)} total)")
            counts["added"] += added_this_file
            continue

        try:
            update_overrides(session, doc_index, doc_id, merged)
        except (requests.RequestException, ValueError) as e:
            print(f"  ERROR: update failed: {e}")
            counts["error"] += 1
        else:
            print(f"  UPDATED {doc_index}/{doc_id} ({len(existing)} → {len(merged)})")
            counts["added"] += added_this_file

    print()
    print("=" * 60)
    print(f"Summary ({mode}):")
    print(f"  Overrides added:           {counts['added']}")
    print(f"  Skipped (already present): {counts['skipped_dedupe']}")
    print(f"  Skipped (no detection):    {counts['skipped_not_found']}")
    print(f"  Invalid (failed checks):   {counts['invalid']}")
    print(f"  Errors:                    {counts['error']}")

    if custom_vars:
        print()
        print("WARNING: detected custom Suricata variables in imported overrides:")
        for v in sorted(custom_vars):
            print(f"  {v}")
        print("If any of these are not already defined in SOC Config (Suricata variables),")
        print("you must add them manually before the rules will function correctly.")

    sys.exit(0 if counts["error"] == 0 and counts["invalid"] == 0 else 1)


# Run the importer only when this file is executed as a script, not when it
# is imported as a module.
if __name__ == "__main__":
    main()
