Initial commit

This commit is contained in:
Josh Brower
2026-05-12 09:55:06 -04:00
parent 006ac31109
commit 306b0af4d0
2 changed files with 797 additions and 0 deletions
+391
View File
@@ -0,0 +1,391 @@
#!/usr/bin/env python3
# Copyright Security Onion Solutions LLC and/or licensed to Security Onion Solutions LLC under one
# or more contributor license agreements. Licensed under the Elastic License 2.0 as shown at
# https://securityonion.net/license; you may not use this file except in compliance with the
# Elastic License 2.0.
# Imports detection overrides (e.g. from so-detections-backup) into the so-detection
# index. Reads <publicId>.<ext> files (NDJSON, one override per line) from a source
# directory, looks up the matching detection by publicId+engine, validates each
# override against the same rules SOC enforces, dedupes against existing overrides
# (operational fields only), and appends new ones.
import argparse
import ipaddress
import json
import os
import re
import sys
from datetime import datetime
import requests
from requests.auth import HTTPBasicAuth
import urllib3
urllib3.disable_warnings(urllib3.exceptions.InsecureRequestWarning)
DEFAULT_INDEX = "so-detection"
AUTH_FILE = "/opt/so/conf/elasticsearch/curl.config"
ES_URL = "https://localhost:9200"
# Engines we know how to handle and the file extension the backup script writes.
ENGINES = {
"suricata": "txt",
"sigma": "yaml",
}
# Standard Suricata variables that ship with Security Onion. Anything else
# referenced in an override is "custom" and the user needs to make sure it
# exists in SOC Config before the override will function.
BUILTIN_SURICATA_VARS = {
"$HOME_NET", "$EXTERNAL_NET",
"$HTTP_SERVERS", "$DNS_SERVERS", "$SQL_SERVERS", "$SMTP_SERVERS",
"$TELNET_SERVERS", "$AIM_SERVERS", "$DC_SERVERS", "$MODBUS_SERVER",
"$MODBUS_CLIENT", "$ENIP_CLIENT", "$ENIP_SERVER",
"$HTTP_PORTS", "$SHELLCODE_PORTS", "$ORACLE_PORTS", "$SSH_PORTS",
"$FTP_PORTS", "$FILE_DATA_PORTS",
}
VAR_PATTERN = re.compile(r"\$[A-Z_][A-Z0-9_]*")
# Canonical valid values, per securityonion-soc/model/detection.go.
SURICATA_OVERRIDE_TYPES = {"suppress", "threshold", "modify"}
SUPPRESS_TRACKS = {"by_src", "by_dst", "by_either"}
THRESHOLD_TRACKS = {"by_src", "by_dst", "by_both"}
THRESHOLD_TYPES = {"limit", "threshold", "both"}
STALE_WARNING = """\
WARNING: so-detections-backup does not remove backup files when overrides are
deleted via the Security Onion web UI. As a result, files in the source
directory may represent overrides that were intentionally deleted and should
NOT be re-imported.
Before continuing, verify that the source directory reflects the overrides you
actually want imported. Remove any files corresponding to overrides you previously deleted.
"""
def make_session(auth_file):
with open(auth_file, "r") as f:
for line in f:
if line.startswith("user ="):
creds = line.split("=", 1)[1].strip().replace('"', "")
user, _, password = creds.partition(":")
session = requests.Session()
session.auth = HTTPBasicAuth(user, password)
session.headers.update({"Content-Type": "application/json"})
session.verify = False
return session
raise RuntimeError(f"Could not find 'user =' line in {auth_file}")
def find_detection(session, index, public_id, engine):
query = {
"query": {"bool": {"must": [
{"term": {"so_detection.publicId": public_id}},
{"term": {"so_detection.engine": engine}},
]}},
"size": 2,
}
r = session.get(f"{ES_URL}/{index}/_search", json=query)
r.raise_for_status()
hits = r.json().get("hits", {}).get("hits", [])
if not hits:
return None, None, None
if len(hits) > 1:
# Shouldn't happen — publicId is unique per engine — but flag it.
print(f" WARN: {len(hits)} detections matched publicId={public_id} engine={engine}; using first")
hit = hits[0]
existing = hit["_source"].get("so_detection", {}).get("overrides") or []
return hit["_id"], hit["_index"], existing
def update_overrides(session, doc_index, doc_id, overrides):
body = {"doc": {"so_detection": {"overrides": overrides}}}
r = session.post(f"{ES_URL}/{doc_index}/_update/{doc_id}", json=body)
r.raise_for_status()
return r.json()
def dedupe_key(override):
"""Operational fields only, per Override.Equal() in detection.go.
Excludes timestamps and isEnabled so re-imports don't appear unique."""
t = override.get("type")
if t == "suppress":
return (t, override.get("track"), override.get("ip"))
if t == "threshold":
return (t, override.get("thresholdType"), override.get("track"),
override.get("count"), override.get("seconds"))
if t == "modify":
return (t, override.get("regex"), override.get("value"))
return (t,)
def _validate_suricata_ip(ip):
if not ip:
return "ip cannot be empty"
if ip.startswith("$"):
return None
if ip.startswith("[") and ip.endswith("]"):
for part in ip[1:-1].split(","):
err = _validate_single_ip(part.strip())
if err:
return f"invalid IP in list: {err}"
return None
return _validate_single_ip(ip)
def _validate_single_ip(ip):
try:
if "/" in ip:
ipaddress.ip_network(ip, strict=False)
else:
ipaddress.ip_address(ip)
except ValueError:
return f"invalid IP/CIDR {ip!r}"
return None
def validate_override(override, engine):
"""Mirror Override.Validate() from securityonion-soc/model/detection.go.
Returns None on success, an error string otherwise."""
if engine != "suricata":
return None # sigma not yet supported; gated earlier in main()
t = override.get("type")
if not t:
return "override type is required"
if t not in SURICATA_OVERRIDE_TYPES:
return f"invalid type {t!r}: must be one of {sorted(SURICATA_OVERRIDE_TYPES)}"
has = {k: override.get(k) is not None for k in
("regex", "value", "thresholdType", "track", "ip", "count", "seconds", "customFilter")}
if t == "suppress":
if not has["ip"] or not has["track"]:
return "suppress requires 'ip' and 'track'"
if any(has[k] for k in ("regex", "value", "thresholdType", "count", "seconds", "customFilter")):
return "suppress has unnecessary fields"
if override["track"] not in SUPPRESS_TRACKS:
return f"invalid track {override['track']!r}: must be one of {sorted(SUPPRESS_TRACKS)}"
return _validate_suricata_ip(override["ip"])
if t == "threshold":
if not all(has[k] for k in ("thresholdType", "track", "count", "seconds")):
return "threshold requires 'thresholdType', 'track', 'count', 'seconds'"
if any(has[k] for k in ("regex", "value", "customFilter")):
return "threshold has unnecessary fields"
if override["thresholdType"] not in THRESHOLD_TYPES:
return f"invalid thresholdType {override['thresholdType']!r}: must be one of {sorted(THRESHOLD_TYPES)}"
if override["track"] not in THRESHOLD_TRACKS:
return f"invalid track {override['track']!r}: must be one of {sorted(THRESHOLD_TRACKS)}"
if not isinstance(override["count"], int) or override["count"] <= 0:
return f"count must be a positive integer, got {override['count']!r}"
if not isinstance(override["seconds"], int) or override["seconds"] <= 0:
return f"seconds must be a positive integer, got {override['seconds']!r}"
return None
if t == "modify":
if not has["regex"] or not has["value"]:
return "modify requires 'regex' and 'value'"
if any(has[k] for k in ("thresholdType", "track", "count", "seconds", "customFilter")):
return "modify has unnecessary fields"
try:
re.compile(override["regex"])
except re.error as e:
return f"invalid regex: {e}"
return None
def parse_overrides_file(path):
"""Parse a file written by so-detections-backup.py: NDJSON, one override
per line. Returns a list of (override_dict, line_number)."""
overrides = []
with open(path, "r") as f:
for i, line in enumerate(f, start=1):
line = line.strip()
if not line:
continue
overrides.append((json.loads(line), i))
return overrides
def describe(override):
"""Human-readable summary of the operational fields for a given override type."""
t = override.get("type")
if t == "suppress":
return f"type=suppress track={override.get('track')} ip={override.get('ip')}"
if t == "threshold":
return (f"type=threshold track={override.get('track')} "
f"thresholdType={override.get('thresholdType')} "
f"count={override.get('count')} seconds={override.get('seconds')}")
if t == "modify":
return f"type=modify regex={override.get('regex')!r}"
return f"type={t}"
def collect_custom_vars(override):
found = set()
for value in override.values():
if isinstance(value, str):
for match in VAR_PATTERN.findall(value):
if match not in BUILTIN_SURICATA_VARS:
found.add(match)
return found
def parse_args():
p = argparse.ArgumentParser(
description="Import detection overrides into the so-detection index.",
)
p.add_argument("--source", "-s", required=True,
help="Source directory containing <publicId>.<ext> override files.")
p.add_argument("--engine", "-e", default="suricata", choices=list(ENGINES.keys()),
help="Detection engine (default: suricata). Sigma not yet supported.")
p.add_argument("--dry-run", "-n", action="store_true",
help="Print what would happen without writing to Elasticsearch.")
p.add_argument("--no-import-note", action="store_true",
help="Do not prepend '[Imported YYYY-MM-DD] ' to the override note.")
p.add_argument("--index", "-i", default=DEFAULT_INDEX,
help=f"Elasticsearch index to update (default: {DEFAULT_INDEX}).")
return p.parse_args()
def confirm_proceed(args):
"""Show the stale-backup warning. Dry-run prints it and continues. Real
runs require the user typing 'yes' at the prompt."""
print(STALE_WARNING)
if args.dry_run:
print("(dry-run: no acknowledgement required)\n")
return True
answer = input("Type 'yes' to acknowledge and continue: ").strip().lower()
print()
return answer == "yes"
def main():
args = parse_args()
if args.engine == "sigma":
print("ERROR: sigma overrides are not yet supported.", file=sys.stderr)
sys.exit(2)
if not os.path.isdir(args.source):
print(f"ERROR: source directory not found: {args.source}", file=sys.stderr)
sys.exit(1)
extension = ENGINES[args.engine]
files = sorted(f for f in os.listdir(args.source) if f.endswith(f".{extension}"))
if not files:
print(f"No *.{extension} files found in {args.source}")
sys.exit(0)
if not confirm_proceed(args):
print("Aborted.")
sys.exit(1)
session = make_session(AUTH_FILE)
today = datetime.now().strftime("%Y-%m-%d")
note_prefix = "" if args.no_import_note else f"[Imported {today}] "
counts = {"added": 0, "skipped_dedupe": 0, "skipped_not_found": 0, "invalid": 0, "error": 0}
custom_vars = set()
mode = "DRY-RUN" if args.dry_run else "IMPORT"
print(f"[{mode}] engine={args.engine} source={args.source} index={args.index}\n")
for filename in files:
public_id = os.path.splitext(filename)[0]
path = os.path.join(args.source, filename)
print(f"{public_id}:")
try:
new_overrides = parse_overrides_file(path)
except (json.JSONDecodeError, OSError) as e:
print(f" ERROR: could not parse {filename}: {e}")
counts["error"] += 1
continue
if not new_overrides:
print(" SKIP: empty file")
continue
try:
doc_id, doc_index, existing = find_detection(session, args.index, public_id, args.engine)
except requests.HTTPError as e:
print(f" ERROR: search failed: {e}")
counts["error"] += 1
continue
if doc_id is None:
print(f" WARN: no detection found for publicId={public_id} engine={args.engine}; skipping")
counts["skipped_not_found"] += len(new_overrides)
continue
existing_keys = {dedupe_key(o) for o in existing}
merged = list(existing)
added_this_file = 0
for override, line_no in new_overrides:
err = validate_override(override, args.engine)
if err:
print(f" INVALID (line {line_no}): {err}")
counts["invalid"] += 1
continue
custom_vars.update(collect_custom_vars(override))
key = dedupe_key(override)
if key in existing_keys:
print(f" SKIP (line {line_no}): duplicate of existing override [{describe(override)}]")
counts["skipped_dedupe"] += 1
continue
if note_prefix:
override = dict(override)
override["note"] = note_prefix + (override.get("note") or "")
merged.append(override)
existing_keys.add(key)
added_this_file += 1
print(f" ADD (line {line_no}): {describe(override)}")
if added_this_file == 0:
continue
if args.dry_run:
print(f" DRY-RUN: would update {doc_index}/{doc_id} "
f"({len(existing)} existing → {len(merged)} total)")
counts["added"] += added_this_file
continue
try:
update_overrides(session, doc_index, doc_id, merged)
print(f" UPDATED {doc_index}/{doc_id} ({len(existing)} → {len(merged)})")
counts["added"] += added_this_file
except requests.HTTPError as e:
print(f" ERROR: update failed: {e}")
counts["error"] += 1
print()
print("=" * 60)
print(f"Summary ({mode}):")
print(f" Overrides added: {counts['added']}")
print(f" Skipped (already present): {counts['skipped_dedupe']}")
print(f" Skipped (no detection): {counts['skipped_not_found']}")
print(f" Invalid (failed checks): {counts['invalid']}")
print(f" Errors: {counts['error']}")
if custom_vars and args.engine == "suricata":
print()
print("WARNING: detected custom Suricata variables in imported overrides:")
for v in sorted(custom_vars):
print(f" {v}")
print("If any of these are not already defined in SOC Config (Suricata variables),")
print("you must add them manually before the rules will function correctly.")
sys.exit(0 if counts["error"] == 0 and counts["invalid"] == 0 else 1)
if __name__ == "__main__":
main()
@@ -0,0 +1,406 @@
# Copyright Security Onion Solutions LLC and/or licensed to Security Onion Solutions LLC under one
# or more contributor license agreements. Licensed under the Elastic License 2.0 as shown at
# https://securityonion.net/license; you may not use this file except in compliance with the
# Elastic License 2.0.
import json
import os
import tempfile
import unittest
from importlib.machinery import SourceFileLoader
from io import StringIO
from unittest.mock import MagicMock, patch
# The script has no .py extension, so importlib.import_module won't find it by
# name. SourceFileLoader loads source Python regardless of extension.
HERE = os.path.dirname(os.path.abspath(__file__))
SCRIPT = os.path.join(HERE, "so-detections-overrides-import")
soi = SourceFileLoader("so_overrides_import", SCRIPT).load_module()
class TestValidateSuppress(unittest.TestCase):
def test_valid(self):
self.assertIsNone(soi.validate_override(
{"type": "suppress", "track": "by_src", "ip": "1.2.3.4"}, "suricata"))
def test_valid_var(self):
self.assertIsNone(soi.validate_override(
{"type": "suppress", "track": "by_either", "ip": "$HOME_NET"}, "suricata"))
def test_valid_cidr(self):
self.assertIsNone(soi.validate_override(
{"type": "suppress", "track": "by_dst", "ip": "10.0.0.0/8"}, "suricata"))
def test_valid_bracket_list(self):
self.assertIsNone(soi.validate_override(
{"type": "suppress", "track": "by_src", "ip": "[1.2.3.4,10.0.0.0/8]"}, "suricata"))
def test_missing_ip(self):
err = soi.validate_override({"type": "suppress", "track": "by_src"}, "suricata")
self.assertIn("requires", err)
def test_missing_track(self):
err = soi.validate_override({"type": "suppress", "ip": "1.2.3.4"}, "suricata")
self.assertIn("requires", err)
def test_invalid_track(self):
err = soi.validate_override(
{"type": "suppress", "track": "by_both", "ip": "1.2.3.4"}, "suricata")
self.assertIn("invalid track", err)
def test_invalid_ip(self):
err = soi.validate_override(
{"type": "suppress", "track": "by_src", "ip": "not-an-ip"}, "suricata")
self.assertIn("invalid IP", err)
def test_unnecessary_field(self):
err = soi.validate_override(
{"type": "suppress", "track": "by_src", "ip": "1.2.3.4", "count": 5}, "suricata")
self.assertIn("unnecessary fields", err)
class TestValidateThreshold(unittest.TestCase):
def test_valid(self):
self.assertIsNone(soi.validate_override({
"type": "threshold", "track": "by_src",
"thresholdType": "limit", "count": 10, "seconds": 60,
}, "suricata"))
def test_valid_by_both(self):
self.assertIsNone(soi.validate_override({
"type": "threshold", "track": "by_both",
"thresholdType": "both", "count": 1, "seconds": 1,
}, "suricata"))
def test_track_by_either_invalid(self):
err = soi.validate_override({
"type": "threshold", "track": "by_either",
"thresholdType": "limit", "count": 10, "seconds": 60,
}, "suricata")
self.assertIn("invalid track", err)
def test_invalid_threshold_type(self):
err = soi.validate_override({
"type": "threshold", "track": "by_src",
"thresholdType": "bogus", "count": 10, "seconds": 60,
}, "suricata")
self.assertIn("invalid thresholdType", err)
def test_zero_count(self):
err = soi.validate_override({
"type": "threshold", "track": "by_src",
"thresholdType": "limit", "count": 0, "seconds": 60,
}, "suricata")
self.assertIn("count", err)
def test_negative_seconds(self):
err = soi.validate_override({
"type": "threshold", "track": "by_src",
"thresholdType": "limit", "count": 10, "seconds": -1,
}, "suricata")
self.assertIn("seconds", err)
def test_missing_field(self):
err = soi.validate_override({
"type": "threshold", "track": "by_src",
"thresholdType": "limit", "count": 10, # missing seconds
}, "suricata")
self.assertIn("requires", err)
def test_unnecessary_field(self):
err = soi.validate_override({
"type": "threshold", "track": "by_src",
"thresholdType": "limit", "count": 10, "seconds": 60,
"regex": "foo",
}, "suricata")
self.assertIn("unnecessary fields", err)
class TestValidateModify(unittest.TestCase):
def test_valid(self):
self.assertIsNone(soi.validate_override(
{"type": "modify", "regex": r"content:\"foo\"", "value": "content:bar"}, "suricata"))
def test_invalid_regex(self):
err = soi.validate_override(
{"type": "modify", "regex": "(unbalanced", "value": "x"}, "suricata")
self.assertIn("invalid regex", err)
def test_missing_value(self):
err = soi.validate_override({"type": "modify", "regex": "x"}, "suricata")
self.assertIn("requires", err)
def test_unnecessary_field(self):
err = soi.validate_override(
{"type": "modify", "regex": "x", "value": "y", "track": "by_src"}, "suricata")
self.assertIn("unnecessary fields", err)
class TestValidateMisc(unittest.TestCase):
def test_unknown_type(self):
err = soi.validate_override({"type": "suppresss", "track": "by_src", "ip": "1.2.3.4"}, "suricata")
self.assertIn("invalid type", err)
def test_missing_type(self):
err = soi.validate_override({"track": "by_src"}, "suricata")
self.assertIn("type is required", err)
def test_non_suricata_engine_skipped(self):
# validate_override returns None for non-suricata engines (sigma is gated in main).
self.assertIsNone(soi.validate_override({"type": "anything"}, "sigma"))
class TestValidateIP(unittest.TestCase):
def test_plain_ipv4(self):
self.assertIsNone(soi._validate_suricata_ip("1.2.3.4"))
def test_plain_ipv6(self):
self.assertIsNone(soi._validate_suricata_ip("::1"))
def test_cidr(self):
self.assertIsNone(soi._validate_suricata_ip("10.0.0.0/8"))
def test_var(self):
self.assertIsNone(soi._validate_suricata_ip("$CONCOURSEWORKERS"))
def test_bracket_list(self):
self.assertIsNone(soi._validate_suricata_ip("[1.2.3.4, 10.0.0.0/8]"))
def test_bracket_list_bad_member(self):
err = soi._validate_suricata_ip("[1.2.3.4,nope]")
self.assertIn("invalid IP in list", err)
def test_empty(self):
self.assertIn("empty", soi._validate_suricata_ip(""))
def test_invalid(self):
self.assertIn("invalid", soi._validate_suricata_ip("999.999.999.999"))
class TestDedupeKey(unittest.TestCase):
def test_suppress(self):
a = {"type": "suppress", "track": "by_src", "ip": "1.2.3.4", "count": 99}
b = {"type": "suppress", "track": "by_src", "ip": "1.2.3.4"}
# count is irrelevant for suppress dedupe
self.assertEqual(soi.dedupe_key(a), soi.dedupe_key(b))
def test_suppress_differs_on_ip(self):
a = {"type": "suppress", "track": "by_src", "ip": "1.2.3.4"}
b = {"type": "suppress", "track": "by_src", "ip": "5.6.7.8"}
self.assertNotEqual(soi.dedupe_key(a), soi.dedupe_key(b))
def test_threshold(self):
a = {"type": "threshold", "track": "by_src", "thresholdType": "limit",
"count": 10, "seconds": 60, "ip": "ignored"}
b = {"type": "threshold", "track": "by_src", "thresholdType": "limit",
"count": 10, "seconds": 60}
self.assertEqual(soi.dedupe_key(a), soi.dedupe_key(b))
def test_threshold_differs_on_count(self):
a = {"type": "threshold", "track": "by_src", "thresholdType": "limit",
"count": 10, "seconds": 60}
b = {"type": "threshold", "track": "by_src", "thresholdType": "limit",
"count": 20, "seconds": 60}
self.assertNotEqual(soi.dedupe_key(a), soi.dedupe_key(b))
def test_modify(self):
a = {"type": "modify", "regex": "x", "value": "y"}
b = {"type": "modify", "regex": "x", "value": "y"}
self.assertEqual(soi.dedupe_key(a), soi.dedupe_key(b))
class TestDescribe(unittest.TestCase):
def test_suppress(self):
s = soi.describe({"type": "suppress", "track": "by_src", "ip": "1.2.3.4"})
self.assertIn("suppress", s)
self.assertIn("by_src", s)
self.assertIn("1.2.3.4", s)
def test_threshold_includes_count(self):
s = soi.describe({"type": "threshold", "track": "by_src",
"thresholdType": "limit", "count": 10, "seconds": 60})
self.assertIn("count=10", s)
self.assertIn("seconds=60", s)
def test_modify(self):
s = soi.describe({"type": "modify", "regex": "foo"})
self.assertIn("modify", s)
self.assertIn("foo", s)
class TestParseOverridesFile(unittest.TestCase):
def _write(self, content):
fd, path = tempfile.mkstemp(suffix=".txt")
os.close(fd)
with open(path, "w") as f:
f.write(content)
self.addCleanup(os.unlink, path)
return path
def test_single_line(self):
path = self._write('{"type":"suppress","track":"by_src","ip":"1.2.3.4"}')
result = soi.parse_overrides_file(path)
self.assertEqual(len(result), 1)
self.assertEqual(result[0][0]["type"], "suppress")
self.assertEqual(result[0][1], 1)
def test_ndjson(self):
path = self._write(
'{"type":"suppress","track":"by_src","ip":"1.2.3.4"}\n'
'{"type":"suppress","track":"by_dst","ip":"5.6.7.8"}\n'
)
result = soi.parse_overrides_file(path)
self.assertEqual(len(result), 2)
self.assertEqual(result[1][1], 2)
def test_empty(self):
path = self._write("")
self.assertEqual(soi.parse_overrides_file(path), [])
def test_blank_lines_skipped(self):
path = self._write('\n{"type":"suppress","track":"by_src","ip":"1.2.3.4"}\n\n')
result = soi.parse_overrides_file(path)
self.assertEqual(len(result), 1)
self.assertEqual(result[0][1], 2) # line number reflects original position
def test_invalid_raises(self):
path = self._write("not json")
with self.assertRaises(json.JSONDecodeError):
soi.parse_overrides_file(path)
class TestCollectCustomVars(unittest.TestCase):
def test_finds_custom(self):
v = soi.collect_custom_vars({"ip": "$CONCOURSEWORKERS"})
self.assertEqual(v, {"$CONCOURSEWORKERS"})
def test_filters_builtins(self):
v = soi.collect_custom_vars({"ip": "$HOME_NET"})
self.assertEqual(v, set())
def test_mixed(self):
v = soi.collect_custom_vars({"ip": "[$HOME_NET,$MYNET]"})
self.assertEqual(v, {"$MYNET"})
def test_non_string_fields_ignored(self):
v = soi.collect_custom_vars({"count": 10, "isEnabled": True})
self.assertEqual(v, set())
class TestMakeSession(unittest.TestCase):
def _write(self, content):
fd, path = tempfile.mkstemp()
os.close(fd)
with open(path, "w") as f:
f.write(content)
self.addCleanup(os.unlink, path)
return path
def test_valid_auth_file(self):
path = self._write('user = "admin:secret"\n')
session = soi.make_session(path)
self.assertEqual(session.auth.username, "admin")
self.assertEqual(session.auth.password, "secret")
self.assertFalse(session.verify)
def test_missing_user_line(self):
path = self._write("# no user line here\n")
with self.assertRaises(RuntimeError):
soi.make_session(path)
class TestFindDetection(unittest.TestCase):
def _session_with_response(self, payload):
session = MagicMock()
response = MagicMock()
response.json.return_value = payload
response.raise_for_status.return_value = None
session.get.return_value = response
return session
def test_found(self):
session = self._session_with_response({"hits": {"hits": [{
"_id": "abc", "_index": "so-detection",
"_source": {"so_detection": {"overrides": [{"type": "suppress"}]}},
}]}})
doc_id, idx, existing = soi.find_detection(session, "so-detection", "2049201", "suricata")
self.assertEqual(doc_id, "abc")
self.assertEqual(idx, "so-detection")
self.assertEqual(len(existing), 1)
def test_not_found(self):
session = self._session_with_response({"hits": {"hits": []}})
doc_id, idx, existing = soi.find_detection(session, "so-detection", "x", "suricata")
self.assertIsNone(doc_id)
self.assertIsNone(idx)
self.assertIsNone(existing)
def test_no_overrides_field(self):
session = self._session_with_response({"hits": {"hits": [{
"_id": "abc", "_index": "so-detection",
"_source": {"so_detection": {}},
}]}})
_, _, existing = soi.find_detection(session, "so-detection", "x", "suricata")
self.assertEqual(existing, [])
def test_multiple_hits_warns(self):
session = self._session_with_response({"hits": {"hits": [
{"_id": "a", "_index": "i", "_source": {"so_detection": {"overrides": []}}},
{"_id": "b", "_index": "i", "_source": {"so_detection": {"overrides": []}}},
]}})
with patch("sys.stdout", new=StringIO()) as out:
doc_id, _, _ = soi.find_detection(session, "i", "x", "suricata")
self.assertEqual(doc_id, "a")
self.assertIn("WARN", out.getvalue())
class TestUpdateOverrides(unittest.TestCase):
def test_posts_to_update_endpoint(self):
session = MagicMock()
response = MagicMock()
response.raise_for_status.return_value = None
response.json.return_value = {"result": "updated"}
session.post.return_value = response
result = soi.update_overrides(session, "so-detection", "abc", [{"type": "suppress"}])
self.assertEqual(result, {"result": "updated"})
url = session.post.call_args[0][0]
self.assertIn("/_update/abc", url)
body = session.post.call_args[1]["json"]
self.assertEqual(body["doc"]["so_detection"]["overrides"], [{"type": "suppress"}])
class TestConfirmProceed(unittest.TestCase):
def test_dry_run_skips_prompt(self):
args = MagicMock(dry_run=True)
with patch("sys.stdout", new=StringIO()):
self.assertTrue(soi.confirm_proceed(args))
def test_yes_input(self):
args = MagicMock(dry_run=False)
with patch("sys.stdout", new=StringIO()):
with patch("builtins.input", return_value="yes"):
self.assertTrue(soi.confirm_proceed(args))
def test_yes_input_case_insensitive(self):
args = MagicMock(dry_run=False)
with patch("sys.stdout", new=StringIO()):
with patch("builtins.input", return_value="YES"):
self.assertTrue(soi.confirm_proceed(args))
def test_no_input_aborts(self):
args = MagicMock(dry_run=False)
with patch("sys.stdout", new=StringIO()):
with patch("builtins.input", return_value="no"):
self.assertFalse(soi.confirm_proceed(args))
def test_empty_input_aborts(self):
args = MagicMock(dry_run=False)
with patch("sys.stdout", new=StringIO()):
with patch("builtins.input", return_value=""):
self.assertFalse(soi.confirm_proceed(args))
if __name__ == "__main__":
unittest.main()