mirror of
https://github.com/Security-Onion-Solutions/securityonion.git
synced 2026-05-18 01:01:42 +02:00
Initial commit
This commit is contained in:
+391
@@ -0,0 +1,391 @@
|
||||
#!/usr/bin/env python3
|
||||
|
||||
# Copyright Security Onion Solutions LLC and/or licensed to Security Onion Solutions LLC under one
|
||||
# or more contributor license agreements. Licensed under the Elastic License 2.0 as shown at
|
||||
# https://securityonion.net/license; you may not use this file except in compliance with the
|
||||
# Elastic License 2.0.
|
||||
|
||||
# Imports detection overrides (e.g. from so-detections-backup) into the so-detection
|
||||
# index. Reads <publicId>.<ext> files (NDJSON, one override per line) from a source
|
||||
# directory, looks up the matching detection by publicId+engine, validates each
|
||||
# override against the same rules SOC enforces, dedupes against existing overrides
|
||||
# (operational fields only), and appends new ones.
|
||||
|
||||
import argparse
|
||||
import ipaddress
|
||||
import json
|
||||
import os
|
||||
import re
|
||||
import sys
|
||||
from datetime import datetime
|
||||
|
||||
import requests
|
||||
from requests.auth import HTTPBasicAuth
|
||||
import urllib3
|
||||
|
||||
urllib3.disable_warnings(urllib3.exceptions.InsecureRequestWarning)
|
||||
|
||||
DEFAULT_INDEX = "so-detection"
|
||||
AUTH_FILE = "/opt/so/conf/elasticsearch/curl.config"
|
||||
ES_URL = "https://localhost:9200"
|
||||
|
||||
# Engines we know how to handle and the file extension the backup script writes.
|
||||
ENGINES = {
|
||||
"suricata": "txt",
|
||||
"sigma": "yaml",
|
||||
}
|
||||
|
||||
# Standard Suricata variables that ship with Security Onion. Anything else
|
||||
# referenced in an override is "custom" and the user needs to make sure it
|
||||
# exists in SOC Config before the override will function.
|
||||
BUILTIN_SURICATA_VARS = {
|
||||
"$HOME_NET", "$EXTERNAL_NET",
|
||||
"$HTTP_SERVERS", "$DNS_SERVERS", "$SQL_SERVERS", "$SMTP_SERVERS",
|
||||
"$TELNET_SERVERS", "$AIM_SERVERS", "$DC_SERVERS", "$MODBUS_SERVER",
|
||||
"$MODBUS_CLIENT", "$ENIP_CLIENT", "$ENIP_SERVER",
|
||||
"$HTTP_PORTS", "$SHELLCODE_PORTS", "$ORACLE_PORTS", "$SSH_PORTS",
|
||||
"$FTP_PORTS", "$FILE_DATA_PORTS",
|
||||
}
|
||||
|
||||
VAR_PATTERN = re.compile(r"\$[A-Z_][A-Z0-9_]*")
|
||||
|
||||
# Canonical valid values, per securityonion-soc/model/detection.go.
|
||||
SURICATA_OVERRIDE_TYPES = {"suppress", "threshold", "modify"}
|
||||
SUPPRESS_TRACKS = {"by_src", "by_dst", "by_either"}
|
||||
THRESHOLD_TRACKS = {"by_src", "by_dst", "by_both"}
|
||||
THRESHOLD_TYPES = {"limit", "threshold", "both"}
|
||||
|
||||
STALE_WARNING = """\
|
||||
WARNING: so-detections-backup does not remove backup files when overrides are
|
||||
deleted via the Security Onion web UI. As a result, files in the source
|
||||
directory may represent overrides that were intentionally deleted and should
|
||||
NOT be re-imported.
|
||||
|
||||
Before continuing, verify that the source directory reflects the overrides you
|
||||
actually want imported. Remove any files corresponding to overrides you previously deleted.
|
||||
"""
|
||||
|
||||
|
||||
def make_session(auth_file):
|
||||
with open(auth_file, "r") as f:
|
||||
for line in f:
|
||||
if line.startswith("user ="):
|
||||
creds = line.split("=", 1)[1].strip().replace('"', "")
|
||||
user, _, password = creds.partition(":")
|
||||
session = requests.Session()
|
||||
session.auth = HTTPBasicAuth(user, password)
|
||||
session.headers.update({"Content-Type": "application/json"})
|
||||
session.verify = False
|
||||
return session
|
||||
raise RuntimeError(f"Could not find 'user =' line in {auth_file}")
|
||||
|
||||
|
||||
def find_detection(session, index, public_id, engine):
|
||||
query = {
|
||||
"query": {"bool": {"must": [
|
||||
{"term": {"so_detection.publicId": public_id}},
|
||||
{"term": {"so_detection.engine": engine}},
|
||||
]}},
|
||||
"size": 2,
|
||||
}
|
||||
r = session.get(f"{ES_URL}/{index}/_search", json=query)
|
||||
r.raise_for_status()
|
||||
hits = r.json().get("hits", {}).get("hits", [])
|
||||
if not hits:
|
||||
return None, None, None
|
||||
if len(hits) > 1:
|
||||
# Shouldn't happen — publicId is unique per engine — but flag it.
|
||||
print(f" WARN: {len(hits)} detections matched publicId={public_id} engine={engine}; using first")
|
||||
hit = hits[0]
|
||||
existing = hit["_source"].get("so_detection", {}).get("overrides") or []
|
||||
return hit["_id"], hit["_index"], existing
|
||||
|
||||
|
||||
def update_overrides(session, doc_index, doc_id, overrides):
|
||||
body = {"doc": {"so_detection": {"overrides": overrides}}}
|
||||
r = session.post(f"{ES_URL}/{doc_index}/_update/{doc_id}", json=body)
|
||||
r.raise_for_status()
|
||||
return r.json()
|
||||
|
||||
|
||||
def dedupe_key(override):
|
||||
"""Operational fields only, per Override.Equal() in detection.go.
|
||||
Excludes timestamps and isEnabled so re-imports don't appear unique."""
|
||||
t = override.get("type")
|
||||
if t == "suppress":
|
||||
return (t, override.get("track"), override.get("ip"))
|
||||
if t == "threshold":
|
||||
return (t, override.get("thresholdType"), override.get("track"),
|
||||
override.get("count"), override.get("seconds"))
|
||||
if t == "modify":
|
||||
return (t, override.get("regex"), override.get("value"))
|
||||
return (t,)
|
||||
|
||||
|
||||
def _validate_suricata_ip(ip):
|
||||
if not ip:
|
||||
return "ip cannot be empty"
|
||||
if ip.startswith("$"):
|
||||
return None
|
||||
if ip.startswith("[") and ip.endswith("]"):
|
||||
for part in ip[1:-1].split(","):
|
||||
err = _validate_single_ip(part.strip())
|
||||
if err:
|
||||
return f"invalid IP in list: {err}"
|
||||
return None
|
||||
return _validate_single_ip(ip)
|
||||
|
||||
|
||||
def _validate_single_ip(ip):
|
||||
try:
|
||||
if "/" in ip:
|
||||
ipaddress.ip_network(ip, strict=False)
|
||||
else:
|
||||
ipaddress.ip_address(ip)
|
||||
except ValueError:
|
||||
return f"invalid IP/CIDR {ip!r}"
|
||||
return None
|
||||
|
||||
|
||||
def validate_override(override, engine):
|
||||
"""Mirror Override.Validate() from securityonion-soc/model/detection.go.
|
||||
Returns None on success, an error string otherwise."""
|
||||
if engine != "suricata":
|
||||
return None # sigma not yet supported; gated earlier in main()
|
||||
|
||||
t = override.get("type")
|
||||
if not t:
|
||||
return "override type is required"
|
||||
if t not in SURICATA_OVERRIDE_TYPES:
|
||||
return f"invalid type {t!r}: must be one of {sorted(SURICATA_OVERRIDE_TYPES)}"
|
||||
|
||||
has = {k: override.get(k) is not None for k in
|
||||
("regex", "value", "thresholdType", "track", "ip", "count", "seconds", "customFilter")}
|
||||
|
||||
if t == "suppress":
|
||||
if not has["ip"] or not has["track"]:
|
||||
return "suppress requires 'ip' and 'track'"
|
||||
if any(has[k] for k in ("regex", "value", "thresholdType", "count", "seconds", "customFilter")):
|
||||
return "suppress has unnecessary fields"
|
||||
if override["track"] not in SUPPRESS_TRACKS:
|
||||
return f"invalid track {override['track']!r}: must be one of {sorted(SUPPRESS_TRACKS)}"
|
||||
return _validate_suricata_ip(override["ip"])
|
||||
|
||||
if t == "threshold":
|
||||
if not all(has[k] for k in ("thresholdType", "track", "count", "seconds")):
|
||||
return "threshold requires 'thresholdType', 'track', 'count', 'seconds'"
|
||||
if any(has[k] for k in ("regex", "value", "customFilter")):
|
||||
return "threshold has unnecessary fields"
|
||||
if override["thresholdType"] not in THRESHOLD_TYPES:
|
||||
return f"invalid thresholdType {override['thresholdType']!r}: must be one of {sorted(THRESHOLD_TYPES)}"
|
||||
if override["track"] not in THRESHOLD_TRACKS:
|
||||
return f"invalid track {override['track']!r}: must be one of {sorted(THRESHOLD_TRACKS)}"
|
||||
if not isinstance(override["count"], int) or override["count"] <= 0:
|
||||
return f"count must be a positive integer, got {override['count']!r}"
|
||||
if not isinstance(override["seconds"], int) or override["seconds"] <= 0:
|
||||
return f"seconds must be a positive integer, got {override['seconds']!r}"
|
||||
return None
|
||||
|
||||
if t == "modify":
|
||||
if not has["regex"] or not has["value"]:
|
||||
return "modify requires 'regex' and 'value'"
|
||||
if any(has[k] for k in ("thresholdType", "track", "count", "seconds", "customFilter")):
|
||||
return "modify has unnecessary fields"
|
||||
try:
|
||||
re.compile(override["regex"])
|
||||
except re.error as e:
|
||||
return f"invalid regex: {e}"
|
||||
return None
|
||||
|
||||
|
||||
def parse_overrides_file(path):
|
||||
"""Parse a file written by so-detections-backup.py: NDJSON, one override
|
||||
per line. Returns a list of (override_dict, line_number)."""
|
||||
overrides = []
|
||||
with open(path, "r") as f:
|
||||
for i, line in enumerate(f, start=1):
|
||||
line = line.strip()
|
||||
if not line:
|
||||
continue
|
||||
overrides.append((json.loads(line), i))
|
||||
return overrides
|
||||
|
||||
|
||||
def describe(override):
|
||||
"""Human-readable summary of the operational fields for a given override type."""
|
||||
t = override.get("type")
|
||||
if t == "suppress":
|
||||
return f"type=suppress track={override.get('track')} ip={override.get('ip')}"
|
||||
if t == "threshold":
|
||||
return (f"type=threshold track={override.get('track')} "
|
||||
f"thresholdType={override.get('thresholdType')} "
|
||||
f"count={override.get('count')} seconds={override.get('seconds')}")
|
||||
if t == "modify":
|
||||
return f"type=modify regex={override.get('regex')!r}"
|
||||
return f"type={t}"
|
||||
|
||||
|
||||
def collect_custom_vars(override):
|
||||
found = set()
|
||||
for value in override.values():
|
||||
if isinstance(value, str):
|
||||
for match in VAR_PATTERN.findall(value):
|
||||
if match not in BUILTIN_SURICATA_VARS:
|
||||
found.add(match)
|
||||
return found
|
||||
|
||||
|
||||
def parse_args():
|
||||
p = argparse.ArgumentParser(
|
||||
description="Import detection overrides into the so-detection index.",
|
||||
)
|
||||
p.add_argument("--source", "-s", required=True,
|
||||
help="Source directory containing <publicId>.<ext> override files.")
|
||||
p.add_argument("--engine", "-e", default="suricata", choices=list(ENGINES.keys()),
|
||||
help="Detection engine (default: suricata). Sigma not yet supported.")
|
||||
p.add_argument("--dry-run", "-n", action="store_true",
|
||||
help="Print what would happen without writing to Elasticsearch.")
|
||||
p.add_argument("--no-import-note", action="store_true",
|
||||
help="Do not prepend '[Imported YYYY-MM-DD] ' to the override note.")
|
||||
p.add_argument("--index", "-i", default=DEFAULT_INDEX,
|
||||
help=f"Elasticsearch index to update (default: {DEFAULT_INDEX}).")
|
||||
return p.parse_args()
|
||||
|
||||
|
||||
def confirm_proceed(args):
|
||||
"""Show the stale-backup warning. Dry-run prints it and continues. Real
|
||||
runs require the user typing 'yes' at the prompt."""
|
||||
print(STALE_WARNING)
|
||||
if args.dry_run:
|
||||
print("(dry-run: no acknowledgement required)\n")
|
||||
return True
|
||||
answer = input("Type 'yes' to acknowledge and continue: ").strip().lower()
|
||||
print()
|
||||
return answer == "yes"
|
||||
|
||||
|
||||
def main():
|
||||
args = parse_args()
|
||||
|
||||
if args.engine == "sigma":
|
||||
print("ERROR: sigma overrides are not yet supported.", file=sys.stderr)
|
||||
sys.exit(2)
|
||||
|
||||
if not os.path.isdir(args.source):
|
||||
print(f"ERROR: source directory not found: {args.source}", file=sys.stderr)
|
||||
sys.exit(1)
|
||||
|
||||
extension = ENGINES[args.engine]
|
||||
files = sorted(f for f in os.listdir(args.source) if f.endswith(f".{extension}"))
|
||||
if not files:
|
||||
print(f"No *.{extension} files found in {args.source}")
|
||||
sys.exit(0)
|
||||
|
||||
if not confirm_proceed(args):
|
||||
print("Aborted.")
|
||||
sys.exit(1)
|
||||
|
||||
session = make_session(AUTH_FILE)
|
||||
today = datetime.now().strftime("%Y-%m-%d")
|
||||
note_prefix = "" if args.no_import_note else f"[Imported {today}] "
|
||||
|
||||
counts = {"added": 0, "skipped_dedupe": 0, "skipped_not_found": 0, "invalid": 0, "error": 0}
|
||||
custom_vars = set()
|
||||
|
||||
mode = "DRY-RUN" if args.dry_run else "IMPORT"
|
||||
print(f"[{mode}] engine={args.engine} source={args.source} index={args.index}\n")
|
||||
|
||||
for filename in files:
|
||||
public_id = os.path.splitext(filename)[0]
|
||||
path = os.path.join(args.source, filename)
|
||||
print(f"{public_id}:")
|
||||
|
||||
try:
|
||||
new_overrides = parse_overrides_file(path)
|
||||
except (json.JSONDecodeError, OSError) as e:
|
||||
print(f" ERROR: could not parse {filename}: {e}")
|
||||
counts["error"] += 1
|
||||
continue
|
||||
|
||||
if not new_overrides:
|
||||
print(" SKIP: empty file")
|
||||
continue
|
||||
|
||||
try:
|
||||
doc_id, doc_index, existing = find_detection(session, args.index, public_id, args.engine)
|
||||
except requests.HTTPError as e:
|
||||
print(f" ERROR: search failed: {e}")
|
||||
counts["error"] += 1
|
||||
continue
|
||||
|
||||
if doc_id is None:
|
||||
print(f" WARN: no detection found for publicId={public_id} engine={args.engine}; skipping")
|
||||
counts["skipped_not_found"] += len(new_overrides)
|
||||
continue
|
||||
|
||||
existing_keys = {dedupe_key(o) for o in existing}
|
||||
merged = list(existing)
|
||||
added_this_file = 0
|
||||
|
||||
for override, line_no in new_overrides:
|
||||
err = validate_override(override, args.engine)
|
||||
if err:
|
||||
print(f" INVALID (line {line_no}): {err}")
|
||||
counts["invalid"] += 1
|
||||
continue
|
||||
|
||||
custom_vars.update(collect_custom_vars(override))
|
||||
key = dedupe_key(override)
|
||||
if key in existing_keys:
|
||||
print(f" SKIP (line {line_no}): duplicate of existing override [{describe(override)}]")
|
||||
counts["skipped_dedupe"] += 1
|
||||
continue
|
||||
|
||||
if note_prefix:
|
||||
override = dict(override)
|
||||
override["note"] = note_prefix + (override.get("note") or "")
|
||||
|
||||
merged.append(override)
|
||||
existing_keys.add(key)
|
||||
added_this_file += 1
|
||||
print(f" ADD (line {line_no}): {describe(override)}")
|
||||
|
||||
if added_this_file == 0:
|
||||
continue
|
||||
|
||||
if args.dry_run:
|
||||
print(f" DRY-RUN: would update {doc_index}/{doc_id} "
|
||||
f"({len(existing)} existing → {len(merged)} total)")
|
||||
counts["added"] += added_this_file
|
||||
continue
|
||||
|
||||
try:
|
||||
update_overrides(session, doc_index, doc_id, merged)
|
||||
print(f" UPDATED {doc_index}/{doc_id} ({len(existing)} → {len(merged)})")
|
||||
counts["added"] += added_this_file
|
||||
except requests.HTTPError as e:
|
||||
print(f" ERROR: update failed: {e}")
|
||||
counts["error"] += 1
|
||||
|
||||
print()
|
||||
print("=" * 60)
|
||||
print(f"Summary ({mode}):")
|
||||
print(f" Overrides added: {counts['added']}")
|
||||
print(f" Skipped (already present): {counts['skipped_dedupe']}")
|
||||
print(f" Skipped (no detection): {counts['skipped_not_found']}")
|
||||
print(f" Invalid (failed checks): {counts['invalid']}")
|
||||
print(f" Errors: {counts['error']}")
|
||||
|
||||
if custom_vars and args.engine == "suricata":
|
||||
print()
|
||||
print("WARNING: detected custom Suricata variables in imported overrides:")
|
||||
for v in sorted(custom_vars):
|
||||
print(f" {v}")
|
||||
print("If any of these are not already defined in SOC Config (Suricata variables),")
|
||||
print("you must add them manually before the rules will function correctly.")
|
||||
|
||||
sys.exit(0 if counts["error"] == 0 and counts["invalid"] == 0 else 1)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
@@ -0,0 +1,406 @@
|
||||
# Copyright Security Onion Solutions LLC and/or licensed to Security Onion Solutions LLC under one
|
||||
# or more contributor license agreements. Licensed under the Elastic License 2.0 as shown at
|
||||
# https://securityonion.net/license; you may not use this file except in compliance with the
|
||||
# Elastic License 2.0.
|
||||
|
||||
import json
|
||||
import os
|
||||
import tempfile
|
||||
import unittest
|
||||
from importlib.machinery import SourceFileLoader
|
||||
from io import StringIO
|
||||
from unittest.mock import MagicMock, patch
|
||||
|
||||
# The script has no .py extension, so importlib.import_module won't find it by
|
||||
# name. SourceFileLoader loads source Python regardless of extension.
|
||||
HERE = os.path.dirname(os.path.abspath(__file__))
|
||||
SCRIPT = os.path.join(HERE, "so-detections-overrides-import")
|
||||
soi = SourceFileLoader("so_overrides_import", SCRIPT).load_module()
|
||||
|
||||
|
||||
class TestValidateSuppress(unittest.TestCase):
|
||||
def test_valid(self):
|
||||
self.assertIsNone(soi.validate_override(
|
||||
{"type": "suppress", "track": "by_src", "ip": "1.2.3.4"}, "suricata"))
|
||||
|
||||
def test_valid_var(self):
|
||||
self.assertIsNone(soi.validate_override(
|
||||
{"type": "suppress", "track": "by_either", "ip": "$HOME_NET"}, "suricata"))
|
||||
|
||||
def test_valid_cidr(self):
|
||||
self.assertIsNone(soi.validate_override(
|
||||
{"type": "suppress", "track": "by_dst", "ip": "10.0.0.0/8"}, "suricata"))
|
||||
|
||||
def test_valid_bracket_list(self):
|
||||
self.assertIsNone(soi.validate_override(
|
||||
{"type": "suppress", "track": "by_src", "ip": "[1.2.3.4,10.0.0.0/8]"}, "suricata"))
|
||||
|
||||
def test_missing_ip(self):
|
||||
err = soi.validate_override({"type": "suppress", "track": "by_src"}, "suricata")
|
||||
self.assertIn("requires", err)
|
||||
|
||||
def test_missing_track(self):
|
||||
err = soi.validate_override({"type": "suppress", "ip": "1.2.3.4"}, "suricata")
|
||||
self.assertIn("requires", err)
|
||||
|
||||
def test_invalid_track(self):
|
||||
err = soi.validate_override(
|
||||
{"type": "suppress", "track": "by_both", "ip": "1.2.3.4"}, "suricata")
|
||||
self.assertIn("invalid track", err)
|
||||
|
||||
def test_invalid_ip(self):
|
||||
err = soi.validate_override(
|
||||
{"type": "suppress", "track": "by_src", "ip": "not-an-ip"}, "suricata")
|
||||
self.assertIn("invalid IP", err)
|
||||
|
||||
def test_unnecessary_field(self):
|
||||
err = soi.validate_override(
|
||||
{"type": "suppress", "track": "by_src", "ip": "1.2.3.4", "count": 5}, "suricata")
|
||||
self.assertIn("unnecessary fields", err)
|
||||
|
||||
|
||||
class TestValidateThreshold(unittest.TestCase):
|
||||
def test_valid(self):
|
||||
self.assertIsNone(soi.validate_override({
|
||||
"type": "threshold", "track": "by_src",
|
||||
"thresholdType": "limit", "count": 10, "seconds": 60,
|
||||
}, "suricata"))
|
||||
|
||||
def test_valid_by_both(self):
|
||||
self.assertIsNone(soi.validate_override({
|
||||
"type": "threshold", "track": "by_both",
|
||||
"thresholdType": "both", "count": 1, "seconds": 1,
|
||||
}, "suricata"))
|
||||
|
||||
def test_track_by_either_invalid(self):
|
||||
err = soi.validate_override({
|
||||
"type": "threshold", "track": "by_either",
|
||||
"thresholdType": "limit", "count": 10, "seconds": 60,
|
||||
}, "suricata")
|
||||
self.assertIn("invalid track", err)
|
||||
|
||||
def test_invalid_threshold_type(self):
|
||||
err = soi.validate_override({
|
||||
"type": "threshold", "track": "by_src",
|
||||
"thresholdType": "bogus", "count": 10, "seconds": 60,
|
||||
}, "suricata")
|
||||
self.assertIn("invalid thresholdType", err)
|
||||
|
||||
def test_zero_count(self):
|
||||
err = soi.validate_override({
|
||||
"type": "threshold", "track": "by_src",
|
||||
"thresholdType": "limit", "count": 0, "seconds": 60,
|
||||
}, "suricata")
|
||||
self.assertIn("count", err)
|
||||
|
||||
def test_negative_seconds(self):
|
||||
err = soi.validate_override({
|
||||
"type": "threshold", "track": "by_src",
|
||||
"thresholdType": "limit", "count": 10, "seconds": -1,
|
||||
}, "suricata")
|
||||
self.assertIn("seconds", err)
|
||||
|
||||
def test_missing_field(self):
|
||||
err = soi.validate_override({
|
||||
"type": "threshold", "track": "by_src",
|
||||
"thresholdType": "limit", "count": 10, # missing seconds
|
||||
}, "suricata")
|
||||
self.assertIn("requires", err)
|
||||
|
||||
def test_unnecessary_field(self):
|
||||
err = soi.validate_override({
|
||||
"type": "threshold", "track": "by_src",
|
||||
"thresholdType": "limit", "count": 10, "seconds": 60,
|
||||
"regex": "foo",
|
||||
}, "suricata")
|
||||
self.assertIn("unnecessary fields", err)
|
||||
|
||||
|
||||
class TestValidateModify(unittest.TestCase):
|
||||
def test_valid(self):
|
||||
self.assertIsNone(soi.validate_override(
|
||||
{"type": "modify", "regex": r"content:\"foo\"", "value": "content:bar"}, "suricata"))
|
||||
|
||||
def test_invalid_regex(self):
|
||||
err = soi.validate_override(
|
||||
{"type": "modify", "regex": "(unbalanced", "value": "x"}, "suricata")
|
||||
self.assertIn("invalid regex", err)
|
||||
|
||||
def test_missing_value(self):
|
||||
err = soi.validate_override({"type": "modify", "regex": "x"}, "suricata")
|
||||
self.assertIn("requires", err)
|
||||
|
||||
def test_unnecessary_field(self):
|
||||
err = soi.validate_override(
|
||||
{"type": "modify", "regex": "x", "value": "y", "track": "by_src"}, "suricata")
|
||||
self.assertIn("unnecessary fields", err)
|
||||
|
||||
|
||||
class TestValidateMisc(unittest.TestCase):
|
||||
def test_unknown_type(self):
|
||||
err = soi.validate_override({"type": "suppresss", "track": "by_src", "ip": "1.2.3.4"}, "suricata")
|
||||
self.assertIn("invalid type", err)
|
||||
|
||||
def test_missing_type(self):
|
||||
err = soi.validate_override({"track": "by_src"}, "suricata")
|
||||
self.assertIn("type is required", err)
|
||||
|
||||
def test_non_suricata_engine_skipped(self):
|
||||
# validate_override returns None for non-suricata engines (sigma is gated in main).
|
||||
self.assertIsNone(soi.validate_override({"type": "anything"}, "sigma"))
|
||||
|
||||
|
||||
class TestValidateIP(unittest.TestCase):
|
||||
def test_plain_ipv4(self):
|
||||
self.assertIsNone(soi._validate_suricata_ip("1.2.3.4"))
|
||||
|
||||
def test_plain_ipv6(self):
|
||||
self.assertIsNone(soi._validate_suricata_ip("::1"))
|
||||
|
||||
def test_cidr(self):
|
||||
self.assertIsNone(soi._validate_suricata_ip("10.0.0.0/8"))
|
||||
|
||||
def test_var(self):
|
||||
self.assertIsNone(soi._validate_suricata_ip("$CONCOURSEWORKERS"))
|
||||
|
||||
def test_bracket_list(self):
|
||||
self.assertIsNone(soi._validate_suricata_ip("[1.2.3.4, 10.0.0.0/8]"))
|
||||
|
||||
def test_bracket_list_bad_member(self):
|
||||
err = soi._validate_suricata_ip("[1.2.3.4,nope]")
|
||||
self.assertIn("invalid IP in list", err)
|
||||
|
||||
def test_empty(self):
|
||||
self.assertIn("empty", soi._validate_suricata_ip(""))
|
||||
|
||||
def test_invalid(self):
|
||||
self.assertIn("invalid", soi._validate_suricata_ip("999.999.999.999"))
|
||||
|
||||
|
||||
class TestDedupeKey(unittest.TestCase):
|
||||
def test_suppress(self):
|
||||
a = {"type": "suppress", "track": "by_src", "ip": "1.2.3.4", "count": 99}
|
||||
b = {"type": "suppress", "track": "by_src", "ip": "1.2.3.4"}
|
||||
# count is irrelevant for suppress dedupe
|
||||
self.assertEqual(soi.dedupe_key(a), soi.dedupe_key(b))
|
||||
|
||||
def test_suppress_differs_on_ip(self):
|
||||
a = {"type": "suppress", "track": "by_src", "ip": "1.2.3.4"}
|
||||
b = {"type": "suppress", "track": "by_src", "ip": "5.6.7.8"}
|
||||
self.assertNotEqual(soi.dedupe_key(a), soi.dedupe_key(b))
|
||||
|
||||
def test_threshold(self):
|
||||
a = {"type": "threshold", "track": "by_src", "thresholdType": "limit",
|
||||
"count": 10, "seconds": 60, "ip": "ignored"}
|
||||
b = {"type": "threshold", "track": "by_src", "thresholdType": "limit",
|
||||
"count": 10, "seconds": 60}
|
||||
self.assertEqual(soi.dedupe_key(a), soi.dedupe_key(b))
|
||||
|
||||
def test_threshold_differs_on_count(self):
|
||||
a = {"type": "threshold", "track": "by_src", "thresholdType": "limit",
|
||||
"count": 10, "seconds": 60}
|
||||
b = {"type": "threshold", "track": "by_src", "thresholdType": "limit",
|
||||
"count": 20, "seconds": 60}
|
||||
self.assertNotEqual(soi.dedupe_key(a), soi.dedupe_key(b))
|
||||
|
||||
def test_modify(self):
|
||||
a = {"type": "modify", "regex": "x", "value": "y"}
|
||||
b = {"type": "modify", "regex": "x", "value": "y"}
|
||||
self.assertEqual(soi.dedupe_key(a), soi.dedupe_key(b))
|
||||
|
||||
|
||||
class TestDescribe(unittest.TestCase):
|
||||
def test_suppress(self):
|
||||
s = soi.describe({"type": "suppress", "track": "by_src", "ip": "1.2.3.4"})
|
||||
self.assertIn("suppress", s)
|
||||
self.assertIn("by_src", s)
|
||||
self.assertIn("1.2.3.4", s)
|
||||
|
||||
def test_threshold_includes_count(self):
|
||||
s = soi.describe({"type": "threshold", "track": "by_src",
|
||||
"thresholdType": "limit", "count": 10, "seconds": 60})
|
||||
self.assertIn("count=10", s)
|
||||
self.assertIn("seconds=60", s)
|
||||
|
||||
def test_modify(self):
|
||||
s = soi.describe({"type": "modify", "regex": "foo"})
|
||||
self.assertIn("modify", s)
|
||||
self.assertIn("foo", s)
|
||||
|
||||
|
||||
class TestParseOverridesFile(unittest.TestCase):
|
||||
def _write(self, content):
|
||||
fd, path = tempfile.mkstemp(suffix=".txt")
|
||||
os.close(fd)
|
||||
with open(path, "w") as f:
|
||||
f.write(content)
|
||||
self.addCleanup(os.unlink, path)
|
||||
return path
|
||||
|
||||
def test_single_line(self):
|
||||
path = self._write('{"type":"suppress","track":"by_src","ip":"1.2.3.4"}')
|
||||
result = soi.parse_overrides_file(path)
|
||||
self.assertEqual(len(result), 1)
|
||||
self.assertEqual(result[0][0]["type"], "suppress")
|
||||
self.assertEqual(result[0][1], 1)
|
||||
|
||||
def test_ndjson(self):
|
||||
path = self._write(
|
||||
'{"type":"suppress","track":"by_src","ip":"1.2.3.4"}\n'
|
||||
'{"type":"suppress","track":"by_dst","ip":"5.6.7.8"}\n'
|
||||
)
|
||||
result = soi.parse_overrides_file(path)
|
||||
self.assertEqual(len(result), 2)
|
||||
self.assertEqual(result[1][1], 2)
|
||||
|
||||
def test_empty(self):
|
||||
path = self._write("")
|
||||
self.assertEqual(soi.parse_overrides_file(path), [])
|
||||
|
||||
def test_blank_lines_skipped(self):
|
||||
path = self._write('\n{"type":"suppress","track":"by_src","ip":"1.2.3.4"}\n\n')
|
||||
result = soi.parse_overrides_file(path)
|
||||
self.assertEqual(len(result), 1)
|
||||
self.assertEqual(result[0][1], 2) # line number reflects original position
|
||||
|
||||
def test_invalid_raises(self):
|
||||
path = self._write("not json")
|
||||
with self.assertRaises(json.JSONDecodeError):
|
||||
soi.parse_overrides_file(path)
|
||||
|
||||
|
||||
class TestCollectCustomVars(unittest.TestCase):
|
||||
def test_finds_custom(self):
|
||||
v = soi.collect_custom_vars({"ip": "$CONCOURSEWORKERS"})
|
||||
self.assertEqual(v, {"$CONCOURSEWORKERS"})
|
||||
|
||||
def test_filters_builtins(self):
|
||||
v = soi.collect_custom_vars({"ip": "$HOME_NET"})
|
||||
self.assertEqual(v, set())
|
||||
|
||||
def test_mixed(self):
|
||||
v = soi.collect_custom_vars({"ip": "[$HOME_NET,$MYNET]"})
|
||||
self.assertEqual(v, {"$MYNET"})
|
||||
|
||||
def test_non_string_fields_ignored(self):
|
||||
v = soi.collect_custom_vars({"count": 10, "isEnabled": True})
|
||||
self.assertEqual(v, set())
|
||||
|
||||
|
||||
class TestMakeSession(unittest.TestCase):
|
||||
def _write(self, content):
|
||||
fd, path = tempfile.mkstemp()
|
||||
os.close(fd)
|
||||
with open(path, "w") as f:
|
||||
f.write(content)
|
||||
self.addCleanup(os.unlink, path)
|
||||
return path
|
||||
|
||||
def test_valid_auth_file(self):
|
||||
path = self._write('user = "admin:secret"\n')
|
||||
session = soi.make_session(path)
|
||||
self.assertEqual(session.auth.username, "admin")
|
||||
self.assertEqual(session.auth.password, "secret")
|
||||
self.assertFalse(session.verify)
|
||||
|
||||
def test_missing_user_line(self):
|
||||
path = self._write("# no user line here\n")
|
||||
with self.assertRaises(RuntimeError):
|
||||
soi.make_session(path)
|
||||
|
||||
|
||||
class TestFindDetection(unittest.TestCase):
|
||||
def _session_with_response(self, payload):
|
||||
session = MagicMock()
|
||||
response = MagicMock()
|
||||
response.json.return_value = payload
|
||||
response.raise_for_status.return_value = None
|
||||
session.get.return_value = response
|
||||
return session
|
||||
|
||||
def test_found(self):
|
||||
session = self._session_with_response({"hits": {"hits": [{
|
||||
"_id": "abc", "_index": "so-detection",
|
||||
"_source": {"so_detection": {"overrides": [{"type": "suppress"}]}},
|
||||
}]}})
|
||||
doc_id, idx, existing = soi.find_detection(session, "so-detection", "2049201", "suricata")
|
||||
self.assertEqual(doc_id, "abc")
|
||||
self.assertEqual(idx, "so-detection")
|
||||
self.assertEqual(len(existing), 1)
|
||||
|
||||
def test_not_found(self):
|
||||
session = self._session_with_response({"hits": {"hits": []}})
|
||||
doc_id, idx, existing = soi.find_detection(session, "so-detection", "x", "suricata")
|
||||
self.assertIsNone(doc_id)
|
||||
self.assertIsNone(idx)
|
||||
self.assertIsNone(existing)
|
||||
|
||||
def test_no_overrides_field(self):
|
||||
session = self._session_with_response({"hits": {"hits": [{
|
||||
"_id": "abc", "_index": "so-detection",
|
||||
"_source": {"so_detection": {}},
|
||||
}]}})
|
||||
_, _, existing = soi.find_detection(session, "so-detection", "x", "suricata")
|
||||
self.assertEqual(existing, [])
|
||||
|
||||
def test_multiple_hits_warns(self):
|
||||
session = self._session_with_response({"hits": {"hits": [
|
||||
{"_id": "a", "_index": "i", "_source": {"so_detection": {"overrides": []}}},
|
||||
{"_id": "b", "_index": "i", "_source": {"so_detection": {"overrides": []}}},
|
||||
]}})
|
||||
with patch("sys.stdout", new=StringIO()) as out:
|
||||
doc_id, _, _ = soi.find_detection(session, "i", "x", "suricata")
|
||||
self.assertEqual(doc_id, "a")
|
||||
self.assertIn("WARN", out.getvalue())
|
||||
|
||||
|
||||
class TestUpdateOverrides(unittest.TestCase):
|
||||
def test_posts_to_update_endpoint(self):
|
||||
session = MagicMock()
|
||||
response = MagicMock()
|
||||
response.raise_for_status.return_value = None
|
||||
response.json.return_value = {"result": "updated"}
|
||||
session.post.return_value = response
|
||||
|
||||
result = soi.update_overrides(session, "so-detection", "abc", [{"type": "suppress"}])
|
||||
|
||||
self.assertEqual(result, {"result": "updated"})
|
||||
url = session.post.call_args[0][0]
|
||||
self.assertIn("/_update/abc", url)
|
||||
body = session.post.call_args[1]["json"]
|
||||
self.assertEqual(body["doc"]["so_detection"]["overrides"], [{"type": "suppress"}])
|
||||
|
||||
|
||||
class TestConfirmProceed(unittest.TestCase):
|
||||
def test_dry_run_skips_prompt(self):
|
||||
args = MagicMock(dry_run=True)
|
||||
with patch("sys.stdout", new=StringIO()):
|
||||
self.assertTrue(soi.confirm_proceed(args))
|
||||
|
||||
def test_yes_input(self):
|
||||
args = MagicMock(dry_run=False)
|
||||
with patch("sys.stdout", new=StringIO()):
|
||||
with patch("builtins.input", return_value="yes"):
|
||||
self.assertTrue(soi.confirm_proceed(args))
|
||||
|
||||
def test_yes_input_case_insensitive(self):
|
||||
args = MagicMock(dry_run=False)
|
||||
with patch("sys.stdout", new=StringIO()):
|
||||
with patch("builtins.input", return_value="YES"):
|
||||
self.assertTrue(soi.confirm_proceed(args))
|
||||
|
||||
def test_no_input_aborts(self):
|
||||
args = MagicMock(dry_run=False)
|
||||
with patch("sys.stdout", new=StringIO()):
|
||||
with patch("builtins.input", return_value="no"):
|
||||
self.assertFalse(soi.confirm_proceed(args))
|
||||
|
||||
def test_empty_input_aborts(self):
|
||||
args = MagicMock(dry_run=False)
|
||||
with patch("sys.stdout", new=StringIO()):
|
||||
with patch("builtins.input", return_value=""):
|
||||
self.assertFalse(soi.confirm_proceed(args))
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
unittest.main()
|
||||
Reference in New Issue
Block a user