diff --git a/salt/manager/tools/sbin/so-detections-overrides-import b/salt/manager/tools/sbin/so-detections-overrides-import index 1f32bf04a..e1cad3ac0 100755 --- a/salt/manager/tools/sbin/so-detections-overrides-import +++ b/salt/manager/tools/sbin/so-detections-overrides-import @@ -32,7 +32,6 @@ ES_URL = "https://localhost:9200" # Engines we know how to handle and the file extension the backup script writes. ENGINES = { "suricata": "txt", - "sigma": "yaml", } # Standard Suricata variables that ship with Security Onion. Anything else @@ -119,7 +118,6 @@ def dedupe_key(override): override.get("count"), override.get("seconds")) if t == "modify": return (t, override.get("regex"), override.get("value")) - return (t,) def _validate_suricata_ip(ip): @@ -150,9 +148,6 @@ def _validate_single_ip(ip): def validate_override(override, engine): """Mirror Override.Validate() from securityonion-soc/model/detection.go. Returns None on success, an error string otherwise.""" - if engine != "suricata": - return None # sigma not yet supported; gated earlier in main() - t = override.get("type") if not t: return "override type is required" @@ -222,7 +217,6 @@ def describe(override): f"count={override.get('count')} seconds={override.get('seconds')}") if t == "modify": return f"type=modify regex={override.get('regex')!r}" - return f"type={t}" def collect_custom_vars(override): @@ -242,7 +236,7 @@ def parse_args(): p.add_argument("--source", "-s", required=True, help="Source directory containing . override files.") p.add_argument("--engine", "-e", default="suricata", choices=list(ENGINES.keys()), - help="Detection engine (default: suricata). Sigma not yet supported.") + help="Detection engine (default: suricata).") p.add_argument("--dry-run", "-n", action="store_true", help="Print what would happen without writing to Elasticsearch.") p.add_argument("--no-import-note", action="store_true", @@ -267,10 +261,6 @@ def confirm_proceed(args): def main(): args = parse_args() - if args.engine == "sigma": - print("ERROR: sigma overrides are not yet supported.", file=sys.stderr) - sys.exit(2) - if not os.path.isdir(args.source): print(f"ERROR: source directory not found: {args.source}", file=sys.stderr) sys.exit(1) @@ -376,7 +366,7 @@ def main(): print(f" Invalid (failed checks): {counts['invalid']}") print(f" Errors: {counts['error']}") - if custom_vars and args.engine == "suricata": + if custom_vars: print() print("WARNING: detected custom Suricata variables in imported overrides:") for v in sorted(custom_vars): diff --git a/salt/manager/tools/sbin/so-detections-overrides-import_test.py b/salt/manager/tools/sbin/so-detections-overrides-import_test.py index ed74e44cb..5f5361ea4 100644 --- a/salt/manager/tools/sbin/so-detections-overrides-import_test.py +++ b/salt/manager/tools/sbin/so-detections-overrides-import_test.py @@ -3,19 +3,28 @@ # https://securityonion.net/license; you may not use this file except in compliance with the # Elastic License 2.0. +import importlib.util import json import os +import shutil +import sys import tempfile import unittest from importlib.machinery import SourceFileLoader from io import StringIO from unittest.mock import MagicMock, patch -# The script has no .py extension, so importlib.import_module won't find it by -# name. SourceFileLoader loads source Python regardless of extension. +import requests + +# The script has no .py extension; spec_from_file_location can't auto-detect a +# loader, so we hand it a SourceFileLoader explicitly. (load_module() is +# deprecated in 3.14 and slated for removal in 3.15.) HERE = os.path.dirname(os.path.abspath(__file__)) SCRIPT = os.path.join(HERE, "so-detections-overrides-import") -soi = SourceFileLoader("so_overrides_import", SCRIPT).load_module() +_loader = SourceFileLoader("so_overrides_import", SCRIPT) +_spec = importlib.util.spec_from_loader("so_overrides_import", _loader) +soi = importlib.util.module_from_spec(_spec) +_loader.exec_module(soi) class TestValidateSuppress(unittest.TestCase): @@ -145,10 +154,6 @@ class TestValidateMisc(unittest.TestCase): err = soi.validate_override({"track": "by_src"}, "suricata") self.assertIn("type is required", err) - def test_non_suricata_engine_skipped(self): - # validate_override returns None for non-suricata engines (sigma is gated in main). - self.assertIsNone(soi.validate_override({"type": "anything"}, "sigma")) - class TestValidateIP(unittest.TestCase): def test_plain_ipv4(self): @@ -402,5 +407,182 @@ class TestConfirmProceed(unittest.TestCase): self.assertFalse(soi.confirm_proceed(args)) +class TestParseArgs(unittest.TestCase): + def test_defaults(self): + with patch.object(sys, "argv", ["cmd", "--source", "/some/path"]): + args = soi.parse_args() + self.assertEqual(args.source, "/some/path") + self.assertEqual(args.engine, "suricata") + self.assertFalse(args.dry_run) + self.assertFalse(args.no_import_note) + self.assertEqual(args.index, soi.DEFAULT_INDEX) + + def test_all_options(self): + argv = ["cmd", "-s", "/x", "-e", "suricata", "-n", + "--no-import-note", "-i", "alt-index"] + with patch.object(sys, "argv", argv): + args = soi.parse_args() + self.assertEqual(args.source, "/x") + self.assertTrue(args.dry_run) + self.assertTrue(args.no_import_note) + self.assertEqual(args.index, "alt-index") + + +class TestMain(unittest.TestCase): + def setUp(self): + self.tmpdir = tempfile.mkdtemp() + self.addCleanup(shutil.rmtree, self.tmpdir, ignore_errors=True) + # Stub make_session so tests don't need /opt/so/conf/elasticsearch/curl.config. + p = patch.object(soi, "make_session", return_value=MagicMock()) + p.start() + self.addCleanup(p.stop) + + def _write_file(self, public_id, overrides, ext="txt"): + """Write an NDJSON override file. Entries may be dicts or raw strings (for malformed input).""" + path = os.path.join(self.tmpdir, f"{public_id}.{ext}") + with open(path, "w") as f: + for o in overrides: + f.write(o if isinstance(o, str) else json.dumps(o)) + f.write("\n") + return path + + def _run_main(self, *extra_argv, input_response="yes"): + """Run main() with stdout/stderr captured and input mocked. Returns (stdout, stderr, exit_code).""" + argv = ["cmd", "--source", self.tmpdir, *extra_argv] + out, err = StringIO(), StringIO() + with patch.object(sys, "argv", argv), \ + patch("sys.stdout", new=out), \ + patch("sys.stderr", new=err), \ + patch("builtins.input", return_value=input_response): + with self.assertRaises(SystemExit) as cm: + soi.main() + return out.getvalue(), err.getvalue(), cm.exception.code + + def test_source_dir_missing(self): + argv = ["cmd", "--source", "/no/such/path/here"] + err = StringIO() + with patch.object(sys, "argv", argv), patch("sys.stderr", new=err): + with self.assertRaises(SystemExit) as cm: + soi.main() + self.assertEqual(cm.exception.code, 1) + self.assertIn("source directory not found", err.getvalue()) + + def test_no_files_found(self): + out, _, code = self._run_main() + self.assertEqual(code, 0) + self.assertIn("No *.txt files found", out) + + def test_user_aborts(self): + self._write_file("1001", [{"type": "suppress", "track": "by_src", "ip": "1.2.3.4"}]) + out, _, code = self._run_main(input_response="no") + self.assertEqual(code, 1) + self.assertIn("Aborted", out) + + def test_parse_error_increments_error(self): + # Malformed JSON line — parse_overrides_file raises JSONDecodeError. + self._write_file("1002", ["not json"]) + out, _, code = self._run_main("--dry-run") + self.assertEqual(code, 1) # invalid+error → non-zero + self.assertIn("could not parse", out) + self.assertIn("Errors: 1", out) + + def test_empty_file_skipped(self): + # Blank lines only — parse_overrides_file returns []; main reports "empty file" and continues. + path = os.path.join(self.tmpdir, "1003.txt") + with open(path, "w") as f: + f.write("\n\n") + out, _, code = self._run_main("--dry-run") + self.assertEqual(code, 0) + self.assertIn("empty file", out) + + @patch.object(soi, "find_detection") + def test_search_http_error(self, mock_find): + mock_find.side_effect = requests.HTTPError("boom") + self._write_file("1004", [{"type": "suppress", "track": "by_src", "ip": "1.2.3.4"}]) + out, _, code = self._run_main("--dry-run") + self.assertEqual(code, 1) + self.assertIn("search failed", out) + + @patch.object(soi, "find_detection") + def test_no_detection_found(self, mock_find): + mock_find.return_value = (None, None, None) + self._write_file("1005", [{"type": "suppress", "track": "by_src", "ip": "1.2.3.4"}]) + out, _, code = self._run_main("--dry-run") + self.assertEqual(code, 0) + self.assertIn("no detection found", out) + self.assertIn("Skipped (no detection): 1", out) + + @patch.object(soi, "find_detection") + def test_all_duplicates_no_update(self, mock_find): + existing = [{"type": "suppress", "track": "by_src", "ip": "1.2.3.4"}] + mock_find.return_value = ("doc1", "so-detection", existing) + self._write_file("1006", [{"type": "suppress", "track": "by_src", "ip": "1.2.3.4"}]) + out, _, code = self._run_main("--dry-run") + self.assertEqual(code, 0) + self.assertIn("SKIP", out) + self.assertNotIn("DRY-RUN: would update", out) # added_this_file == 0 branch + + @patch.object(soi, "update_overrides") + @patch.object(soi, "find_detection") + def test_happy_path_full(self, mock_find, mock_update): + # Exercises: ADD, dedupe SKIP, INVALID, note prefix, UPDATE, custom-vars warning, exit=1 (invalid present) + existing = [{"type": "suppress", "track": "by_src", "ip": "9.9.9.9"}] + mock_find.return_value = ("doc1", "so-detection", existing) + mock_update.return_value = {"result": "updated"} + self._write_file("1007", [ + {"type": "suppress", "track": "by_src", "ip": "1.2.3.4"}, # ADD + {"type": "suppress", "track": "by_src", "ip": "9.9.9.9"}, # SKIP (dupe of existing) + {"type": "suppress", "track": "bogus", "ip": "1.2.3.4"}, # INVALID + {"type": "suppress", "track": "by_src", "ip": "$CONCOURSEWORKERS"}, # ADD + custom var + ]) + out, _, code = self._run_main() + self.assertEqual(code, 1) # one invalid -> non-zero + + mock_update.assert_called_once() + merged = mock_update.call_args[0][3] + self.assertEqual(len(merged), 3) # 1 existing + 2 new + new_notes = [o.get("note", "") for o in merged if o.get("ip") in ("1.2.3.4", "$CONCOURSEWORKERS")] + self.assertTrue(all(n.startswith("[Imported ") for n in new_notes)) + + self.assertIn("ADD", out) + self.assertIn("SKIP", out) + self.assertIn("INVALID", out) + self.assertIn("UPDATED", out) + self.assertIn("$CONCOURSEWORKERS", out) + + @patch.object(soi, "update_overrides") + @patch.object(soi, "find_detection") + def test_no_import_note_preserves_note(self, mock_find, mock_update): + mock_find.return_value = ("doc1", "so-detection", []) + mock_update.return_value = {"result": "updated"} + self._write_file("1008", [ + {"type": "suppress", "track": "by_src", "ip": "1.2.3.4", "note": "original"}, + ]) + _, _, code = self._run_main("--no-import-note") + self.assertEqual(code, 0) + merged = mock_update.call_args[0][3] + self.assertEqual(merged[0]["note"], "original") # no prefix applied + + @patch.object(soi, "find_detection") + def test_dry_run_skips_update(self, mock_find): + mock_find.return_value = ("doc1", "so-detection", []) + self._write_file("1009", [{"type": "suppress", "track": "by_src", "ip": "1.2.3.4"}]) + with patch.object(soi, "update_overrides") as mock_update: + out, _, code = self._run_main("--dry-run") + self.assertEqual(code, 0) + mock_update.assert_not_called() + self.assertIn("DRY-RUN: would update", out) + + @patch.object(soi, "update_overrides") + @patch.object(soi, "find_detection") + def test_update_http_error(self, mock_find, mock_update): + mock_find.return_value = ("doc1", "so-detection", []) + mock_update.side_effect = requests.HTTPError("nope") + self._write_file("1010", [{"type": "suppress", "track": "by_src", "ip": "1.2.3.4"}]) + out, _, code = self._run_main() + self.assertEqual(code, 1) + self.assertIn("update failed", out) + + if __name__ == "__main__": unittest.main()