diff --git a/salt/manager/tools/sbin/so-detections-overrides-import b/salt/manager/tools/sbin/so-detections-overrides-import new file mode 100755 index 000000000..1f32bf04a --- /dev/null +++ b/salt/manager/tools/sbin/so-detections-overrides-import @@ -0,0 +1,391 @@ +#!/usr/bin/env python3 + +# Copyright Security Onion Solutions LLC and/or licensed to Security Onion Solutions LLC under one +# or more contributor license agreements. Licensed under the Elastic License 2.0 as shown at +# https://securityonion.net/license; you may not use this file except in compliance with the +# Elastic License 2.0. + +# Imports detection overrides (e.g. from so-detections-backup) into the so-detection +# index. Reads . files (NDJSON, one override per line) from a source +# directory, looks up the matching detection by publicId+engine, validates each +# override against the same rules SOC enforces, dedupes against existing overrides +# (operational fields only), and appends new ones. + +import argparse +import ipaddress +import json +import os +import re +import sys +from datetime import datetime + +import requests +from requests.auth import HTTPBasicAuth +import urllib3 + +urllib3.disable_warnings(urllib3.exceptions.InsecureRequestWarning) + +DEFAULT_INDEX = "so-detection" +AUTH_FILE = "/opt/so/conf/elasticsearch/curl.config" +ES_URL = "https://localhost:9200" + +# Engines we know how to handle and the file extension the backup script writes. +ENGINES = { + "suricata": "txt", + "sigma": "yaml", +} + +# Standard Suricata variables that ship with Security Onion. Anything else +# referenced in an override is "custom" and the user needs to make sure it +# exists in SOC Config before the override will function. +BUILTIN_SURICATA_VARS = { + "$HOME_NET", "$EXTERNAL_NET", + "$HTTP_SERVERS", "$DNS_SERVERS", "$SQL_SERVERS", "$SMTP_SERVERS", + "$TELNET_SERVERS", "$AIM_SERVERS", "$DC_SERVERS", "$MODBUS_SERVER", + "$MODBUS_CLIENT", "$ENIP_CLIENT", "$ENIP_SERVER", + "$HTTP_PORTS", "$SHELLCODE_PORTS", "$ORACLE_PORTS", "$SSH_PORTS", + "$FTP_PORTS", "$FILE_DATA_PORTS", +} + +VAR_PATTERN = re.compile(r"\$[A-Z_][A-Z0-9_]*") + +# Canonical valid values, per securityonion-soc/model/detection.go. +SURICATA_OVERRIDE_TYPES = {"suppress", "threshold", "modify"} +SUPPRESS_TRACKS = {"by_src", "by_dst", "by_either"} +THRESHOLD_TRACKS = {"by_src", "by_dst", "by_both"} +THRESHOLD_TYPES = {"limit", "threshold", "both"} + +STALE_WARNING = """\ +WARNING: so-detections-backup does not remove backup files when overrides are +deleted via the Security Onion web UI. As a result, files in the source +directory may represent overrides that were intentionally deleted and should +NOT be re-imported. + +Before continuing, verify that the source directory reflects the overrides you +actually want imported. Remove any files corresponding to overrides you previously deleted. +""" + + +def make_session(auth_file): + with open(auth_file, "r") as f: + for line in f: + if line.startswith("user ="): + creds = line.split("=", 1)[1].strip().replace('"', "") + user, _, password = creds.partition(":") + session = requests.Session() + session.auth = HTTPBasicAuth(user, password) + session.headers.update({"Content-Type": "application/json"}) + session.verify = False + return session + raise RuntimeError(f"Could not find 'user =' line in {auth_file}") + + +def find_detection(session, index, public_id, engine): + query = { + "query": {"bool": {"must": [ + {"term": {"so_detection.publicId": public_id}}, + {"term": {"so_detection.engine": engine}}, + ]}}, + "size": 2, + } + r = session.get(f"{ES_URL}/{index}/_search", json=query) + r.raise_for_status() + hits = r.json().get("hits", {}).get("hits", []) + if not hits: + return None, None, None + if len(hits) > 1: + # Shouldn't happen — publicId is unique per engine — but flag it. + print(f" WARN: {len(hits)} detections matched publicId={public_id} engine={engine}; using first") + hit = hits[0] + existing = hit["_source"].get("so_detection", {}).get("overrides") or [] + return hit["_id"], hit["_index"], existing + + +def update_overrides(session, doc_index, doc_id, overrides): + body = {"doc": {"so_detection": {"overrides": overrides}}} + r = session.post(f"{ES_URL}/{doc_index}/_update/{doc_id}", json=body) + r.raise_for_status() + return r.json() + + +def dedupe_key(override): + """Operational fields only, per Override.Equal() in detection.go. + Excludes timestamps and isEnabled so re-imports don't appear unique.""" + t = override.get("type") + if t == "suppress": + return (t, override.get("track"), override.get("ip")) + if t == "threshold": + return (t, override.get("thresholdType"), override.get("track"), + override.get("count"), override.get("seconds")) + if t == "modify": + return (t, override.get("regex"), override.get("value")) + return (t,) + + +def _validate_suricata_ip(ip): + if not ip: + return "ip cannot be empty" + if ip.startswith("$"): + return None + if ip.startswith("[") and ip.endswith("]"): + for part in ip[1:-1].split(","): + err = _validate_single_ip(part.strip()) + if err: + return f"invalid IP in list: {err}" + return None + return _validate_single_ip(ip) + + +def _validate_single_ip(ip): + try: + if "/" in ip: + ipaddress.ip_network(ip, strict=False) + else: + ipaddress.ip_address(ip) + except ValueError: + return f"invalid IP/CIDR {ip!r}" + return None + + +def validate_override(override, engine): + """Mirror Override.Validate() from securityonion-soc/model/detection.go. + Returns None on success, an error string otherwise.""" + if engine != "suricata": + return None # sigma not yet supported; gated earlier in main() + + t = override.get("type") + if not t: + return "override type is required" + if t not in SURICATA_OVERRIDE_TYPES: + return f"invalid type {t!r}: must be one of {sorted(SURICATA_OVERRIDE_TYPES)}" + + has = {k: override.get(k) is not None for k in + ("regex", "value", "thresholdType", "track", "ip", "count", "seconds", "customFilter")} + + if t == "suppress": + if not has["ip"] or not has["track"]: + return "suppress requires 'ip' and 'track'" + if any(has[k] for k in ("regex", "value", "thresholdType", "count", "seconds", "customFilter")): + return "suppress has unnecessary fields" + if override["track"] not in SUPPRESS_TRACKS: + return f"invalid track {override['track']!r}: must be one of {sorted(SUPPRESS_TRACKS)}" + return _validate_suricata_ip(override["ip"]) + + if t == "threshold": + if not all(has[k] for k in ("thresholdType", "track", "count", "seconds")): + return "threshold requires 'thresholdType', 'track', 'count', 'seconds'" + if any(has[k] for k in ("regex", "value", "customFilter")): + return "threshold has unnecessary fields" + if override["thresholdType"] not in THRESHOLD_TYPES: + return f"invalid thresholdType {override['thresholdType']!r}: must be one of {sorted(THRESHOLD_TYPES)}" + if override["track"] not in THRESHOLD_TRACKS: + return f"invalid track {override['track']!r}: must be one of {sorted(THRESHOLD_TRACKS)}" + if not isinstance(override["count"], int) or override["count"] <= 0: + return f"count must be a positive integer, got {override['count']!r}" + if not isinstance(override["seconds"], int) or override["seconds"] <= 0: + return f"seconds must be a positive integer, got {override['seconds']!r}" + return None + + if t == "modify": + if not has["regex"] or not has["value"]: + return "modify requires 'regex' and 'value'" + if any(has[k] for k in ("thresholdType", "track", "count", "seconds", "customFilter")): + return "modify has unnecessary fields" + try: + re.compile(override["regex"]) + except re.error as e: + return f"invalid regex: {e}" + return None + + +def parse_overrides_file(path): + """Parse a file written by so-detections-backup.py: NDJSON, one override + per line. Returns a list of (override_dict, line_number).""" + overrides = [] + with open(path, "r") as f: + for i, line in enumerate(f, start=1): + line = line.strip() + if not line: + continue + overrides.append((json.loads(line), i)) + return overrides + + +def describe(override): + """Human-readable summary of the operational fields for a given override type.""" + t = override.get("type") + if t == "suppress": + return f"type=suppress track={override.get('track')} ip={override.get('ip')}" + if t == "threshold": + return (f"type=threshold track={override.get('track')} " + f"thresholdType={override.get('thresholdType')} " + f"count={override.get('count')} seconds={override.get('seconds')}") + if t == "modify": + return f"type=modify regex={override.get('regex')!r}" + return f"type={t}" + + +def collect_custom_vars(override): + found = set() + for value in override.values(): + if isinstance(value, str): + for match in VAR_PATTERN.findall(value): + if match not in BUILTIN_SURICATA_VARS: + found.add(match) + return found + + +def parse_args(): + p = argparse.ArgumentParser( + description="Import detection overrides into the so-detection index.", + ) + p.add_argument("--source", "-s", required=True, + help="Source directory containing . override files.") + p.add_argument("--engine", "-e", default="suricata", choices=list(ENGINES.keys()), + help="Detection engine (default: suricata). Sigma not yet supported.") + p.add_argument("--dry-run", "-n", action="store_true", + help="Print what would happen without writing to Elasticsearch.") + p.add_argument("--no-import-note", action="store_true", + help="Do not prepend '[Imported YYYY-MM-DD] ' to the override note.") + p.add_argument("--index", "-i", default=DEFAULT_INDEX, + help=f"Elasticsearch index to update (default: {DEFAULT_INDEX}).") + return p.parse_args() + + +def confirm_proceed(args): + """Show the stale-backup warning. Dry-run prints it and continues. Real + runs require the user typing 'yes' at the prompt.""" + print(STALE_WARNING) + if args.dry_run: + print("(dry-run: no acknowledgement required)\n") + return True + answer = input("Type 'yes' to acknowledge and continue: ").strip().lower() + print() + return answer == "yes" + + +def main(): + args = parse_args() + + if args.engine == "sigma": + print("ERROR: sigma overrides are not yet supported.", file=sys.stderr) + sys.exit(2) + + if not os.path.isdir(args.source): + print(f"ERROR: source directory not found: {args.source}", file=sys.stderr) + sys.exit(1) + + extension = ENGINES[args.engine] + files = sorted(f for f in os.listdir(args.source) if f.endswith(f".{extension}")) + if not files: + print(f"No *.{extension} files found in {args.source}") + sys.exit(0) + + if not confirm_proceed(args): + print("Aborted.") + sys.exit(1) + + session = make_session(AUTH_FILE) + today = datetime.now().strftime("%Y-%m-%d") + note_prefix = "" if args.no_import_note else f"[Imported {today}] " + + counts = {"added": 0, "skipped_dedupe": 0, "skipped_not_found": 0, "invalid": 0, "error": 0} + custom_vars = set() + + mode = "DRY-RUN" if args.dry_run else "IMPORT" + print(f"[{mode}] engine={args.engine} source={args.source} index={args.index}\n") + + for filename in files: + public_id = os.path.splitext(filename)[0] + path = os.path.join(args.source, filename) + print(f"{public_id}:") + + try: + new_overrides = parse_overrides_file(path) + except (json.JSONDecodeError, OSError) as e: + print(f" ERROR: could not parse {filename}: {e}") + counts["error"] += 1 + continue + + if not new_overrides: + print(" SKIP: empty file") + continue + + try: + doc_id, doc_index, existing = find_detection(session, args.index, public_id, args.engine) + except requests.HTTPError as e: + print(f" ERROR: search failed: {e}") + counts["error"] += 1 + continue + + if doc_id is None: + print(f" WARN: no detection found for publicId={public_id} engine={args.engine}; skipping") + counts["skipped_not_found"] += len(new_overrides) + continue + + existing_keys = {dedupe_key(o) for o in existing} + merged = list(existing) + added_this_file = 0 + + for override, line_no in new_overrides: + err = validate_override(override, args.engine) + if err: + print(f" INVALID (line {line_no}): {err}") + counts["invalid"] += 1 + continue + + custom_vars.update(collect_custom_vars(override)) + key = dedupe_key(override) + if key in existing_keys: + print(f" SKIP (line {line_no}): duplicate of existing override [{describe(override)}]") + counts["skipped_dedupe"] += 1 + continue + + if note_prefix: + override = dict(override) + override["note"] = note_prefix + (override.get("note") or "") + + merged.append(override) + existing_keys.add(key) + added_this_file += 1 + print(f" ADD (line {line_no}): {describe(override)}") + + if added_this_file == 0: + continue + + if args.dry_run: + print(f" DRY-RUN: would update {doc_index}/{doc_id} " + f"({len(existing)} existing → {len(merged)} total)") + counts["added"] += added_this_file + continue + + try: + update_overrides(session, doc_index, doc_id, merged) + print(f" UPDATED {doc_index}/{doc_id} ({len(existing)} → {len(merged)})") + counts["added"] += added_this_file + except requests.HTTPError as e: + print(f" ERROR: update failed: {e}") + counts["error"] += 1 + + print() + print("=" * 60) + print(f"Summary ({mode}):") + print(f" Overrides added: {counts['added']}") + print(f" Skipped (already present): {counts['skipped_dedupe']}") + print(f" Skipped (no detection): {counts['skipped_not_found']}") + print(f" Invalid (failed checks): {counts['invalid']}") + print(f" Errors: {counts['error']}") + + if custom_vars and args.engine == "suricata": + print() + print("WARNING: detected custom Suricata variables in imported overrides:") + for v in sorted(custom_vars): + print(f" {v}") + print("If any of these are not already defined in SOC Config (Suricata variables),") + print("you must add them manually before the rules will function correctly.") + + sys.exit(0 if counts["error"] == 0 and counts["invalid"] == 0 else 1) + + +if __name__ == "__main__": + main() diff --git a/salt/manager/tools/sbin/so-detections-overrides-import_test.py b/salt/manager/tools/sbin/so-detections-overrides-import_test.py new file mode 100644 index 000000000..ed74e44cb --- /dev/null +++ b/salt/manager/tools/sbin/so-detections-overrides-import_test.py @@ -0,0 +1,406 @@ +# Copyright Security Onion Solutions LLC and/or licensed to Security Onion Solutions LLC under one +# or more contributor license agreements. Licensed under the Elastic License 2.0 as shown at +# https://securityonion.net/license; you may not use this file except in compliance with the +# Elastic License 2.0. + +import json +import os +import tempfile +import unittest +from importlib.machinery import SourceFileLoader +from io import StringIO +from unittest.mock import MagicMock, patch + +# The script has no .py extension, so importlib.import_module won't find it by +# name. SourceFileLoader loads source Python regardless of extension. +HERE = os.path.dirname(os.path.abspath(__file__)) +SCRIPT = os.path.join(HERE, "so-detections-overrides-import") +soi = SourceFileLoader("so_overrides_import", SCRIPT).load_module() + + +class TestValidateSuppress(unittest.TestCase): + def test_valid(self): + self.assertIsNone(soi.validate_override( + {"type": "suppress", "track": "by_src", "ip": "1.2.3.4"}, "suricata")) + + def test_valid_var(self): + self.assertIsNone(soi.validate_override( + {"type": "suppress", "track": "by_either", "ip": "$HOME_NET"}, "suricata")) + + def test_valid_cidr(self): + self.assertIsNone(soi.validate_override( + {"type": "suppress", "track": "by_dst", "ip": "10.0.0.0/8"}, "suricata")) + + def test_valid_bracket_list(self): + self.assertIsNone(soi.validate_override( + {"type": "suppress", "track": "by_src", "ip": "[1.2.3.4,10.0.0.0/8]"}, "suricata")) + + def test_missing_ip(self): + err = soi.validate_override({"type": "suppress", "track": "by_src"}, "suricata") + self.assertIn("requires", err) + + def test_missing_track(self): + err = soi.validate_override({"type": "suppress", "ip": "1.2.3.4"}, "suricata") + self.assertIn("requires", err) + + def test_invalid_track(self): + err = soi.validate_override( + {"type": "suppress", "track": "by_both", "ip": "1.2.3.4"}, "suricata") + self.assertIn("invalid track", err) + + def test_invalid_ip(self): + err = soi.validate_override( + {"type": "suppress", "track": "by_src", "ip": "not-an-ip"}, "suricata") + self.assertIn("invalid IP", err) + + def test_unnecessary_field(self): + err = soi.validate_override( + {"type": "suppress", "track": "by_src", "ip": "1.2.3.4", "count": 5}, "suricata") + self.assertIn("unnecessary fields", err) + + +class TestValidateThreshold(unittest.TestCase): + def test_valid(self): + self.assertIsNone(soi.validate_override({ + "type": "threshold", "track": "by_src", + "thresholdType": "limit", "count": 10, "seconds": 60, + }, "suricata")) + + def test_valid_by_both(self): + self.assertIsNone(soi.validate_override({ + "type": "threshold", "track": "by_both", + "thresholdType": "both", "count": 1, "seconds": 1, + }, "suricata")) + + def test_track_by_either_invalid(self): + err = soi.validate_override({ + "type": "threshold", "track": "by_either", + "thresholdType": "limit", "count": 10, "seconds": 60, + }, "suricata") + self.assertIn("invalid track", err) + + def test_invalid_threshold_type(self): + err = soi.validate_override({ + "type": "threshold", "track": "by_src", + "thresholdType": "bogus", "count": 10, "seconds": 60, + }, "suricata") + self.assertIn("invalid thresholdType", err) + + def test_zero_count(self): + err = soi.validate_override({ + "type": "threshold", "track": "by_src", + "thresholdType": "limit", "count": 0, "seconds": 60, + }, "suricata") + self.assertIn("count", err) + + def test_negative_seconds(self): + err = soi.validate_override({ + "type": "threshold", "track": "by_src", + "thresholdType": "limit", "count": 10, "seconds": -1, + }, "suricata") + self.assertIn("seconds", err) + + def test_missing_field(self): + err = soi.validate_override({ + "type": "threshold", "track": "by_src", + "thresholdType": "limit", "count": 10, # missing seconds + }, "suricata") + self.assertIn("requires", err) + + def test_unnecessary_field(self): + err = soi.validate_override({ + "type": "threshold", "track": "by_src", + "thresholdType": "limit", "count": 10, "seconds": 60, + "regex": "foo", + }, "suricata") + self.assertIn("unnecessary fields", err) + + +class TestValidateModify(unittest.TestCase): + def test_valid(self): + self.assertIsNone(soi.validate_override( + {"type": "modify", "regex": r"content:\"foo\"", "value": "content:bar"}, "suricata")) + + def test_invalid_regex(self): + err = soi.validate_override( + {"type": "modify", "regex": "(unbalanced", "value": "x"}, "suricata") + self.assertIn("invalid regex", err) + + def test_missing_value(self): + err = soi.validate_override({"type": "modify", "regex": "x"}, "suricata") + self.assertIn("requires", err) + + def test_unnecessary_field(self): + err = soi.validate_override( + {"type": "modify", "regex": "x", "value": "y", "track": "by_src"}, "suricata") + self.assertIn("unnecessary fields", err) + + +class TestValidateMisc(unittest.TestCase): + def test_unknown_type(self): + err = soi.validate_override({"type": "suppresss", "track": "by_src", "ip": "1.2.3.4"}, "suricata") + self.assertIn("invalid type", err) + + def test_missing_type(self): + err = soi.validate_override({"track": "by_src"}, "suricata") + self.assertIn("type is required", err) + + def test_non_suricata_engine_skipped(self): + # validate_override returns None for non-suricata engines (sigma is gated in main). + self.assertIsNone(soi.validate_override({"type": "anything"}, "sigma")) + + +class TestValidateIP(unittest.TestCase): + def test_plain_ipv4(self): + self.assertIsNone(soi._validate_suricata_ip("1.2.3.4")) + + def test_plain_ipv6(self): + self.assertIsNone(soi._validate_suricata_ip("::1")) + + def test_cidr(self): + self.assertIsNone(soi._validate_suricata_ip("10.0.0.0/8")) + + def test_var(self): + self.assertIsNone(soi._validate_suricata_ip("$CONCOURSEWORKERS")) + + def test_bracket_list(self): + self.assertIsNone(soi._validate_suricata_ip("[1.2.3.4, 10.0.0.0/8]")) + + def test_bracket_list_bad_member(self): + err = soi._validate_suricata_ip("[1.2.3.4,nope]") + self.assertIn("invalid IP in list", err) + + def test_empty(self): + self.assertIn("empty", soi._validate_suricata_ip("")) + + def test_invalid(self): + self.assertIn("invalid", soi._validate_suricata_ip("999.999.999.999")) + + +class TestDedupeKey(unittest.TestCase): + def test_suppress(self): + a = {"type": "suppress", "track": "by_src", "ip": "1.2.3.4", "count": 99} + b = {"type": "suppress", "track": "by_src", "ip": "1.2.3.4"} + # count is irrelevant for suppress dedupe + self.assertEqual(soi.dedupe_key(a), soi.dedupe_key(b)) + + def test_suppress_differs_on_ip(self): + a = {"type": "suppress", "track": "by_src", "ip": "1.2.3.4"} + b = {"type": "suppress", "track": "by_src", "ip": "5.6.7.8"} + self.assertNotEqual(soi.dedupe_key(a), soi.dedupe_key(b)) + + def test_threshold(self): + a = {"type": "threshold", "track": "by_src", "thresholdType": "limit", + "count": 10, "seconds": 60, "ip": "ignored"} + b = {"type": "threshold", "track": "by_src", "thresholdType": "limit", + "count": 10, "seconds": 60} + self.assertEqual(soi.dedupe_key(a), soi.dedupe_key(b)) + + def test_threshold_differs_on_count(self): + a = {"type": "threshold", "track": "by_src", "thresholdType": "limit", + "count": 10, "seconds": 60} + b = {"type": "threshold", "track": "by_src", "thresholdType": "limit", + "count": 20, "seconds": 60} + self.assertNotEqual(soi.dedupe_key(a), soi.dedupe_key(b)) + + def test_modify(self): + a = {"type": "modify", "regex": "x", "value": "y"} + b = {"type": "modify", "regex": "x", "value": "y"} + self.assertEqual(soi.dedupe_key(a), soi.dedupe_key(b)) + + +class TestDescribe(unittest.TestCase): + def test_suppress(self): + s = soi.describe({"type": "suppress", "track": "by_src", "ip": "1.2.3.4"}) + self.assertIn("suppress", s) + self.assertIn("by_src", s) + self.assertIn("1.2.3.4", s) + + def test_threshold_includes_count(self): + s = soi.describe({"type": "threshold", "track": "by_src", + "thresholdType": "limit", "count": 10, "seconds": 60}) + self.assertIn("count=10", s) + self.assertIn("seconds=60", s) + + def test_modify(self): + s = soi.describe({"type": "modify", "regex": "foo"}) + self.assertIn("modify", s) + self.assertIn("foo", s) + + +class TestParseOverridesFile(unittest.TestCase): + def _write(self, content): + fd, path = tempfile.mkstemp(suffix=".txt") + os.close(fd) + with open(path, "w") as f: + f.write(content) + self.addCleanup(os.unlink, path) + return path + + def test_single_line(self): + path = self._write('{"type":"suppress","track":"by_src","ip":"1.2.3.4"}') + result = soi.parse_overrides_file(path) + self.assertEqual(len(result), 1) + self.assertEqual(result[0][0]["type"], "suppress") + self.assertEqual(result[0][1], 1) + + def test_ndjson(self): + path = self._write( + '{"type":"suppress","track":"by_src","ip":"1.2.3.4"}\n' + '{"type":"suppress","track":"by_dst","ip":"5.6.7.8"}\n' + ) + result = soi.parse_overrides_file(path) + self.assertEqual(len(result), 2) + self.assertEqual(result[1][1], 2) + + def test_empty(self): + path = self._write("") + self.assertEqual(soi.parse_overrides_file(path), []) + + def test_blank_lines_skipped(self): + path = self._write('\n{"type":"suppress","track":"by_src","ip":"1.2.3.4"}\n\n') + result = soi.parse_overrides_file(path) + self.assertEqual(len(result), 1) + self.assertEqual(result[0][1], 2) # line number reflects original position + + def test_invalid_raises(self): + path = self._write("not json") + with self.assertRaises(json.JSONDecodeError): + soi.parse_overrides_file(path) + + +class TestCollectCustomVars(unittest.TestCase): + def test_finds_custom(self): + v = soi.collect_custom_vars({"ip": "$CONCOURSEWORKERS"}) + self.assertEqual(v, {"$CONCOURSEWORKERS"}) + + def test_filters_builtins(self): + v = soi.collect_custom_vars({"ip": "$HOME_NET"}) + self.assertEqual(v, set()) + + def test_mixed(self): + v = soi.collect_custom_vars({"ip": "[$HOME_NET,$MYNET]"}) + self.assertEqual(v, {"$MYNET"}) + + def test_non_string_fields_ignored(self): + v = soi.collect_custom_vars({"count": 10, "isEnabled": True}) + self.assertEqual(v, set()) + + +class TestMakeSession(unittest.TestCase): + def _write(self, content): + fd, path = tempfile.mkstemp() + os.close(fd) + with open(path, "w") as f: + f.write(content) + self.addCleanup(os.unlink, path) + return path + + def test_valid_auth_file(self): + path = self._write('user = "admin:secret"\n') + session = soi.make_session(path) + self.assertEqual(session.auth.username, "admin") + self.assertEqual(session.auth.password, "secret") + self.assertFalse(session.verify) + + def test_missing_user_line(self): + path = self._write("# no user line here\n") + with self.assertRaises(RuntimeError): + soi.make_session(path) + + +class TestFindDetection(unittest.TestCase): + def _session_with_response(self, payload): + session = MagicMock() + response = MagicMock() + response.json.return_value = payload + response.raise_for_status.return_value = None + session.get.return_value = response + return session + + def test_found(self): + session = self._session_with_response({"hits": {"hits": [{ + "_id": "abc", "_index": "so-detection", + "_source": {"so_detection": {"overrides": [{"type": "suppress"}]}}, + }]}}) + doc_id, idx, existing = soi.find_detection(session, "so-detection", "2049201", "suricata") + self.assertEqual(doc_id, "abc") + self.assertEqual(idx, "so-detection") + self.assertEqual(len(existing), 1) + + def test_not_found(self): + session = self._session_with_response({"hits": {"hits": []}}) + doc_id, idx, existing = soi.find_detection(session, "so-detection", "x", "suricata") + self.assertIsNone(doc_id) + self.assertIsNone(idx) + self.assertIsNone(existing) + + def test_no_overrides_field(self): + session = self._session_with_response({"hits": {"hits": [{ + "_id": "abc", "_index": "so-detection", + "_source": {"so_detection": {}}, + }]}}) + _, _, existing = soi.find_detection(session, "so-detection", "x", "suricata") + self.assertEqual(existing, []) + + def test_multiple_hits_warns(self): + session = self._session_with_response({"hits": {"hits": [ + {"_id": "a", "_index": "i", "_source": {"so_detection": {"overrides": []}}}, + {"_id": "b", "_index": "i", "_source": {"so_detection": {"overrides": []}}}, + ]}}) + with patch("sys.stdout", new=StringIO()) as out: + doc_id, _, _ = soi.find_detection(session, "i", "x", "suricata") + self.assertEqual(doc_id, "a") + self.assertIn("WARN", out.getvalue()) + + +class TestUpdateOverrides(unittest.TestCase): + def test_posts_to_update_endpoint(self): + session = MagicMock() + response = MagicMock() + response.raise_for_status.return_value = None + response.json.return_value = {"result": "updated"} + session.post.return_value = response + + result = soi.update_overrides(session, "so-detection", "abc", [{"type": "suppress"}]) + + self.assertEqual(result, {"result": "updated"}) + url = session.post.call_args[0][0] + self.assertIn("/_update/abc", url) + body = session.post.call_args[1]["json"] + self.assertEqual(body["doc"]["so_detection"]["overrides"], [{"type": "suppress"}]) + + +class TestConfirmProceed(unittest.TestCase): + def test_dry_run_skips_prompt(self): + args = MagicMock(dry_run=True) + with patch("sys.stdout", new=StringIO()): + self.assertTrue(soi.confirm_proceed(args)) + + def test_yes_input(self): + args = MagicMock(dry_run=False) + with patch("sys.stdout", new=StringIO()): + with patch("builtins.input", return_value="yes"): + self.assertTrue(soi.confirm_proceed(args)) + + def test_yes_input_case_insensitive(self): + args = MagicMock(dry_run=False) + with patch("sys.stdout", new=StringIO()): + with patch("builtins.input", return_value="YES"): + self.assertTrue(soi.confirm_proceed(args)) + + def test_no_input_aborts(self): + args = MagicMock(dry_run=False) + with patch("sys.stdout", new=StringIO()): + with patch("builtins.input", return_value="no"): + self.assertFalse(soi.confirm_proceed(args)) + + def test_empty_input_aborts(self): + args = MagicMock(dry_run=False) + with patch("sys.stdout", new=StringIO()): + with patch("builtins.input", return_value=""): + self.assertFalse(soi.confirm_proceed(args)) + + +if __name__ == "__main__": + unittest.main()