#!/usr/bin/env python3

# Copyright Security Onion Solutions LLC and/or licensed to Security Onion Solutions LLC under one
# or more contributor license agreements. Licensed under the Elastic License 2.0 as shown at 
# https://securityonion.net/license; you may not use this file except in compliance with the
# Elastic License 2.0.



""" 
Local exit codes:
 - General error: 1
 - Invalid argument: 2
 - File error: 3
"""

import sys, os, subprocess, argparse, signal
import copy
import re
import textwrap
import yaml

# Directory walked by find_minion_pillar() to locate the manager-type pillar file.
minion_pillar_dir = '/opt/so/saltstack/local/pillar/minions'
# Assigned by the add_rem_* handlers with the value returned by check_apply()
# and inspected by sigint_handler() when Ctrl-C is received.
salt_proc: subprocess.CompletedProcess = None


def print_err(string: str):
  """Write ``string`` followed by a newline to standard error."""
  sys.stderr.write(f'{string}\n')


def check_apply(args: dict, prompt: bool = True):
  """Apply the updated configuration if requested, optionally asking first.

  ``args`` is the parsed command-line namespace; only ``args.apply`` is read.
  When ``args.apply`` is truthy the changes are applied immediately.
  Otherwise, if ``prompt`` is true, the user is asked (default answer: no)
  until a recognised reply is given.

  Returns the exit code of apply() when changes were applied, else 0.
  """
  if args.apply:
    print('Configuration updated. Applying changes:')
    return apply()

  if not prompt:
    return 0

  question = 'Configuration updated. Would you like to apply your changes now? (y/N) '
  while True:
    reply = input(question).lower()
    if reply in ('y', 'n', ''):
      break

  # An empty reply counts as "no".
  if reply != 'y':
    return 0

  print('Applying changes:')
  return apply()


def apply():
  """Sync the idstools config files via Salt, then run the rule update.

  Returns the exit code of ``so-rule-update`` when the Salt state succeeded,
  otherwise the non-zero exit code of the ``salt-call`` run.
  """
  print('Syncing config files...')
  sync = subprocess.run(
    ['salt-call', 'state.apply', '-l', 'quiet', 'idstools.sync_files', 'queue=True'],
    stdout=subprocess.DEVNULL,
  )
  if sync.returncode != 0:
    # Do not attempt a rule update when the file sync failed.
    return sync.returncode

  print('Updating rules...')
  return subprocess.run(['so-rule-update']).returncode


def find_minion_pillar(pillar_dir: str = None) -> str:
  """Return the path of the single manager-type minion pillar file.

  Walks ``pillar_dir`` (defaults to the module-level ``minion_pillar_dir``)
  looking for a file whose name ends in ``_manager.sls``,
  ``_managersearch.sls``, ``_standalone.sls``, ``_import.sls`` or
  ``_eval.sls``.

  Exits the process with code 3 when no such file, or more than one, exists.
  """
  if pillar_dir is None:
    pillar_dir = minion_pillar_dir

  # Raw string: '\.' in a normal string literal is an invalid escape sequence.
  pattern = re.compile(r'^.*_(manager|managersearch|standalone|import|eval)\.sls$')

  matches = [
    os.path.join(root, filename)
    for root, _, files in os.walk(pillar_dir)
    for filename in files
    if pattern.search(filename)
  ]

  if not matches:
    print_err('Could not find manager-type pillar (eval, standalone, manager, managersearch, import). Are you running this script on the manager?')
    sys.exit(3)

  if len(matches) > 1:
    res_str = ', '.join(f'\"{match}\"' for match in matches)
    print_err('(This should not happen, the system is in an error state if you see this message.)\n')
    print_err('More than one manager-type pillar exists, minion id\'s listed below:')
    print_err(f'  {res_str}')
    sys.exit(3)

  return matches[0]


def read_pillar(pillar: str):
  """Load the YAML pillar file at ``pillar`` and return it as a dict.

  Exits the process with code 3 after printing a single error message when
  the file cannot be read, is not valid YAML, or does not contain a mapping
  (this includes an empty file).
  """
  # Only specific exceptions are caught here: a bare ``except`` would also
  # intercept SystemExit and report a parse failure as an open failure.
  try:
    with open(pillar, 'r') as f:
      loaded_yaml = yaml.safe_load(f.read())
  except OSError:
    print_err(f'Could not open {pillar}')
    sys.exit(3)
  except (yaml.YAMLError, UnicodeDecodeError):
    print_err(f'Could not parse {pillar}')
    sys.exit(3)

  # Callers index the result like a dict, so reject empty files (None) and
  # documents whose top level is a list or scalar.
  if not isinstance(loaded_yaml, dict):
    print_err(f'Could not parse {pillar}')
    sys.exit(3)

  return loaded_yaml


def write_pillar(pillar: str, content: dict):
  """Serialize ``content`` as block-style YAML and write it to ``pillar``.

  Before writing, any of the ``disabled``/``enabled``/``modify`` lists under
  ``content['idstools']['sids']`` that is present but empty is replaced with
  ``None`` (``content`` is modified in place). Keys that are absent are left
  absent.

  Exits the process with code 3 when the content cannot be serialized or the
  file cannot be written.
  """
  try:
    sids = content['idstools']['sids']
    for key in ('disabled', 'enabled', 'modify'):
      entries = sids.get(key)
      if entries is not None and len(entries) == 0:
        sids[key] = None
    # Serialize before opening the file so a failure here cannot leave a
    # truncated pillar on disk.
    serialized = yaml.dump(content, default_flow_style=False)
  except Exception as e:
    print_err(f'Could not serialize pillar content for {pillar}: {e}')
    sys.exit(3)

  try:
    with open(pillar, 'w') as f:
      f.write(serialized)
  except OSError:
    print_err(f'Could not open {pillar}')
    sys.exit(3)


def check_sid_pattern(sid_pattern: str):
  """Return True when ``sid_pattern`` is a usable SID or regex pattern.

  A pattern starting with ``re:`` is valid when the remainder compiles as a
  regular expression. Anything else must parse as a non-negative integer.
  An explanatory message is printed to stderr when the pattern is rejected.
  """
  if sid_pattern.startswith('re:'):
    if valid_regex(sid_pattern[3:]):
      return True
    print_err('Invalid regex pattern.')
    return False

  try:
    is_valid_sid = int(sid_pattern) >= 0
  except ValueError:
    # Not an integer at all (int() on a str only raises ValueError).
    is_valid_sid = False

  if not is_valid_sid:
    print_err(f'SID {sid_pattern} is not valid, did you forget the \"re:\" prefix for a regex pattern?')
  return is_valid_sid


def valid_regex(pattern: str):
  """Return True if ``pattern`` compiles as a regular expression, else False."""
  try:
    re.compile(pattern)
  except re.error:
    return False
  else:
    return True


def sids_key_exists(pillar: dict, key: str):
  """Return True if ``key`` is present under ``pillar['idstools']['sids']``.

  A missing ``idstools`` or ``sids`` level is treated as an empty mapping.
  """
  idstools_section = pillar.get('idstools', {})
  sids_section = idstools_section.get('sids', {})
  return key in sids_section


def rem_from_sids(pillar: dict, key: str, val: str, optional = False):
  """Return a deep copy of ``pillar`` with ``val`` removed from the ``key`` list.

  The list lives at ``pillar['idstools']['sids'][key]``. When the list is
  ``None`` or does not contain ``val`` the copy is returned unchanged and,
  unless ``optional`` is true, a notice is printed. The input is never
  modified.
  """
  updated = copy.deepcopy(pillar)
  entries = updated['idstools']['sids'][key]

  if entries is not None and val in entries:
    entries.remove(val)
  elif not optional:
    print(f'{val} already does not exist in {key}')

  return updated


def add_to_sids(pillar: dict, key: str, val: str, optional = False):
  """Return a deep copy of ``pillar`` with ``val`` appended to the ``key`` list.

  The list lives at ``pillar['idstools']['sids'][key]``; a ``None`` value is
  first replaced by an empty list. When ``val`` is already present nothing
  is appended and, unless ``optional`` is true, a notice is printed. The
  input is never modified.
  """
  updated = copy.deepcopy(pillar)
  sids = updated['idstools']['sids']

  if sids[key] is None:
    sids[key] = []
  entries = sids[key]

  if val not in entries:
    entries.append(val)
  elif not optional:
    print(f'{val} already exists in {key}')

  return updated


def add_rem_disabled(args: dict):
  """Add ``args.sid_pattern`` to, or remove it from, the disabled rule list.

  ``args`` is the parsed command-line namespace (``sid_pattern``, ``remove``,
  ``pillar`` and, via check_apply(), ``apply`` are read).

  When adding, the same pattern is also dropped from the enabled list, and
  the user is offered removal of any modify actions that start with the
  pattern. If the disabled list does not change, the pillar is not rewritten
  and changes are only applied when ``--apply`` was given.

  Returns 2 for an invalid pattern, otherwise the value of check_apply().
  """
  global salt_proc

  sid = args.sid_pattern
  if not check_sid_pattern(sid):
    return 2

  current = read_pillar(args.pillar)
  if not sids_key_exists(current, 'disabled'):
    current['idstools']['sids']['disabled'] = None

  # Both helpers return a modified deep copy, leaving ``current`` untouched
  # so the two versions can be compared below.
  change = rem_from_sids if args.remove else add_to_sids
  updated = change(current, 'disabled', sid)

  if updated['idstools']['sids']['disabled'] == current['idstools']['sids']['disabled']:
    # Nothing changed: skip the write and never prompt.
    salt_proc = check_apply(args, prompt=False)
    return salt_proc

  if not args.remove:
    if sids_key_exists(updated, 'enabled'):
      updated = rem_from_sids(updated, 'enabled', sid, optional=True)

    modify = updated.get('idstools', {}).get('sids', {}).get('modify')
    if modify is not None:
      matching = [action for action in modify if action.startswith(f'{sid} ')]
      if matching:
        question = f'The above modify actions contain {sid}. Would you like to remove them? (Y/n) '
        # The list is re-printed before every (re)prompt.
        while True:
          for item in matching:
            print(f' - {item}')
          reply = input(question).lower()
          if reply in ('y', 'n', ''):
            break
        # An empty reply counts as "yes".
        if reply != 'n':
          for item in matching:
            modify.remove(item)
          updated['idstools']['sids']['modify'] = modify

  write_pillar(pillar=args.pillar, content=updated)

  salt_proc = check_apply(args)
  return salt_proc


def list_disabled_rules(args: dict):
  """Print the disabled rule patterns stored in the pillar at ``args.pillar``.

  A missing, null or empty ``idstools.sids.disabled`` value is reported as
  "No rules disabled." Always returns 0.
  """
  pillar_dict = read_pillar(args.pillar)

  # ``or {}`` also covers sections that are present in the YAML but null.
  sids = (pillar_dict.get('idstools') or {}).get('sids') or {}
  disabled = sids.get('disabled')

  if not disabled:
    print('No rules disabled.')
    return 0

  print('Disabled rules:')
  for rule in disabled:
    print(f'  - {rule}')
  return 0


def add_rem_enabled(args: dict):
  """Add ``args.sid_pattern`` to, or remove it from, the enabled rule list.

  ``args`` is the parsed command-line namespace (``sid_pattern``, ``remove``,
  ``pillar`` and, via check_apply(), ``apply`` are read).

  When adding, the same pattern is also dropped from the disabled list. If
  the enabled list does not change, the pillar is not rewritten and changes
  are only applied when ``--apply`` was given.

  Returns 2 for an invalid pattern, otherwise the value of check_apply().
  """
  global salt_proc

  sid = args.sid_pattern
  if not check_sid_pattern(sid):
    return 2

  current = read_pillar(args.pillar)
  if not sids_key_exists(current, 'enabled'):
    current['idstools']['sids']['enabled'] = None

  # Both helpers return a modified deep copy, leaving ``current`` untouched
  # so the two versions can be compared below.
  change = rem_from_sids if args.remove else add_to_sids
  updated = change(current, 'enabled', sid)

  if updated['idstools']['sids']['enabled'] == current['idstools']['sids']['enabled']:
    # Nothing changed: skip the write and never prompt.
    salt_proc = check_apply(args, prompt=False)
    return salt_proc

  if not args.remove and sids_key_exists(updated, 'disabled'):
    updated = rem_from_sids(updated, 'disabled', sid, optional=True)

  write_pillar(pillar=args.pillar, content=updated)

  salt_proc = check_apply(args)
  return salt_proc


def list_enabled_rules(args: dict):
  """Print the explicitly enabled rule patterns stored in the pillar at ``args.pillar``.

  A missing, null or empty ``idstools.sids.enabled`` value is reported as
  "No rules explicitly enabled." Always returns 0.
  """
  pillar_dict = read_pillar(args.pillar)

  # ``or {}`` also covers sections that are present in the YAML but null.
  sids = (pillar_dict.get('idstools') or {}).get('sids') or {}
  enabled = sids.get('enabled')

  if not enabled:
    print('No rules explicitly enabled.')
    return 0

  print('Enabled rules:')
  for rule in enabled:
    print(f'  - {rule}')
  return 0


def add_rem_modify(args: dict):
  """Add or remove a modify action ``SID "search" "replace"`` in the pillar.

  ``args`` is the parsed command-line namespace (``sid_pattern``,
  ``search_term``, ``replace_term``, ``remove``, ``pillar`` and, via
  check_apply(), ``apply`` are read).

  When adding, the SID pattern is also dropped from the disabled list. If the
  modify list does not change, the pillar is not rewritten and changes are
  only applied when ``--apply`` was given.

  Returns 2 for an invalid SID pattern, or for an invalid search regex when
  adding; otherwise the value of check_apply().
  """
  global salt_proc

  if not check_sid_pattern(args.sid_pattern):
    return 2

  if not valid_regex(args.search_term):
    print_err('Search term is not a valid regex pattern.')
    # Refuse to store an unusable modification. Removal is still allowed so
    # that an invalid entry already present in the pillar can be cleaned up.
    if not args.remove:
      return 2

  string_val = f'{args.sid_pattern} \"{args.search_term}\" \"{args.replace_term}\"'

  pillar_dict = read_pillar(args.pillar)

  if not sids_key_exists(pillar_dict, 'modify'):
    pillar_dict['idstools']['sids']['modify'] = None

  # Both helpers return a modified deep copy, so the original can be compared.
  if args.remove:
    temp_pillar_dict = rem_from_sids(pillar_dict, 'modify', string_val)
  else:
    temp_pillar_dict = add_to_sids(pillar_dict, 'modify', string_val)

  if temp_pillar_dict['idstools']['sids']['modify'] == pillar_dict['idstools']['sids']['modify']:
    # Nothing changed: skip the write and never prompt.
    salt_proc = check_apply(args, prompt=False)
    return salt_proc

  pillar_dict = temp_pillar_dict

  # TODO: Determine if a rule should be removed from disabled if modified.
  if not args.remove and sids_key_exists(pillar_dict, 'disabled'):
    pillar_dict = rem_from_sids(pillar_dict, 'disabled', args.sid_pattern, optional=True)

  write_pillar(pillar=args.pillar, content=pillar_dict)

  salt_proc = check_apply(args)
  return salt_proc
  

def list_modified_rules(args: dict):
  """Print the modify actions stored in the pillar at ``args.pillar``.

  A missing, null or empty ``idstools.sids.modify`` value is reported as
  "No rules currently modified." Always returns 0.
  """
  pillar_dict = read_pillar(args.pillar)

  # ``or {}`` also covers sections that are present in the YAML but null.
  sids = (pillar_dict.get('idstools') or {}).get('sids') or {}
  modify = sids.get('modify')

  if not modify:
    print('No rules currently modified.')
    return 0

  print('Modified rules + modifications:')
  for rule in modify:
    print(f'  - {rule}')
  return 0


def sigint_handler(*_):
  """SIGINT handler: announce the interruption and exit with status 0.

  If the module-level ``salt_proc`` refers to an object that can receive
  signals, SIGINT is forwarded to it first.
  """
  print('Exiting gracefully on Ctrl-C')
  # The add_rem_* handlers store check_apply()'s integer result in salt_proc,
  # so only forward the signal when the object actually supports it.
  forward_signal = getattr(salt_proc, 'send_signal', None)
  if callable(forward_signal):
    forward_signal(signal.SIGINT)
  sys.exit(0)


def main():
  """Command-line entry point: parse arguments and run the selected handler.

  Installs the SIGINT handler, requires root (exit code 1 otherwise), builds
  the ``disabled`` / ``enabled`` / ``modify`` command parsers, and exits with
  the handler's return value. When no action is selected the relevant help
  text is printed and the process exits with 0.
  """
  signal.signal(signal.SIGINT, sigint_handler)

  if os.geteuid() != 0:
    print_err('You must run this script as root')
    sys.exit(1)

  apply_help = 'After updating rule configuration, apply the idstools state.'
  sid_or_regex_help = 'A valid SID (ex: "4321") or regular expression pattern (ex: "re:heartbleed|spectre")'
  # Renders exactly as before; written with valid escapes only ('\$' is an
  # invalid escape sequence in a normal string literal).
  search_term_help = 'A properly escaped regex search term (ex: "\\\\$EXTERNAL_NET")'
  replace_term_help = 'The text to replace the search term with'

  main_parser = argparse.ArgumentParser(formatter_class=argparse.RawDescriptionHelpFormatter)

  subcommand_desc = textwrap.dedent(
    """\
      disabled            Manage and list disabled rules (add, remove, list)
      enabled             Manage and list enabled rules (add, remove, list)
      modify              Manage and list modified rules (add, remove, list)
    """
  )
  subparsers = main_parser.add_subparsers(title='commands', description=subcommand_desc, metavar='', dest='command')

  def add_change_parser(actions, action_name, func, sid_metavar='SID|REGEX', with_terms=False, **defaults):
    # Register an 'add'/'remove' action: SID positional, optional
    # search/replace positionals, then the --apply flag.
    parser = actions.add_parser(action_name)
    parser.set_defaults(func=func, **defaults)
    parser.add_argument('sid_pattern', metavar=sid_metavar, help=sid_or_regex_help)
    if with_terms:
      parser.add_argument('search_term', metavar='SEARCH_TERM', help=search_term_help)
      parser.add_argument('replace_term', metavar='REPLACE_TERM', help=replace_term_help)
    parser.add_argument('--apply', action='store_const', const=True, required=False, help=apply_help)

  # Disabled actions
  disabled = subparsers.add_parser('disabled')
  disabled_sub = disabled.add_subparsers()
  add_change_parser(disabled_sub, 'add', add_rem_disabled)
  add_change_parser(disabled_sub, 'remove', add_rem_disabled, remove=True)
  disabled_sub.add_parser('list').set_defaults(func=list_disabled_rules)

  # Enabled actions
  enabled = subparsers.add_parser('enabled')
  enabled_sub = enabled.add_subparsers()
  add_change_parser(enabled_sub, 'add', add_rem_enabled)
  add_change_parser(enabled_sub, 'remove', add_rem_enabled, remove=True)
  enabled_sub.add_parser('list').set_defaults(func=list_enabled_rules)

  # Modify actions
  modify = subparsers.add_parser('modify')
  modify_sub = modify.add_subparsers()
  add_change_parser(modify_sub, 'add', add_rem_modify, with_terms=True)
  add_change_parser(modify_sub, 'remove', add_rem_modify, sid_metavar='SID', with_terms=True, remove=True)
  modify_sub.add_parser('list').set_defaults(func=list_modified_rules)

  # Begin parse + run
  args = main_parser.parse_args(sys.argv[1:])

  if not hasattr(args, 'remove'):
    args.remove = False

  if not hasattr(args, 'func'):
    # No action selected: show help for the chosen command, or the top-level
    # help when no command was given. This does not need the pillar.
    help_parsers = {'disabled': disabled, 'enabled': enabled, 'modify': modify}
    help_parsers.get(args.command, main_parser).print_help()
    sys.exit(0)

  # Resolve the pillar only when a handler will actually run, so that help
  # output does not depend on the manager pillar being present.
  args.pillar = find_minion_pillar()

  exit_code = args.func(args)
  sys.exit(exit_code)


# Run the command-line interface only when executed as a script.
if __name__ == '__main__':
  main()
