#!/usr/bin/env python3

# Copyright 2014-2023 Security Onion Solutions, LLC
#
# This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU General Public License as published by
# the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
# GNU General Public License for more details.
#
# You should have received a copy of the GNU General Public License
# along with this program.  If not, see <http://www.gnu.org/licenses/>.

from itertools import chain
from typing import List

import signal
import sys
import os
import re
import subprocess
import argparse
import textwrap
import yaml
import multiprocessing
import docker
import pty

# Directory that find_minion_pillar walks recursively looking for the
# manager-type minion pillar file.
minion_pillar_dir = '/opt/so/saltstack/local/pillar/minions'
# File that mod_so_status rewrites; it holds one 'so-<name>' entry per line.
so_status_conf = '/opt/so/conf/so-status/so-status.conf'
# Handle that sigint_handler sends SIGINT to when it is not None.
# NOTE(review): nothing in this file rebinds this module-level name, so it
# stays None. Also, send_signal is a subprocess.Popen method rather than a
# CompletedProcess one -- confirm the intended type before relying on it.
proc: subprocess.CompletedProcess = None

# Temp store of modules, will likely be broken out into salt
def get_learn_modules():
  """Return the catalog of available ML modules, keyed by module name.

  Every entry carries the module's default settings: 'cpu_period' (taken from
  get_cpu_period(fraction=0.25)), 'enabled' (False by default) and a
  human-readable 'description'.
  """
  logscan_defaults = dict(
    cpu_period=get_cpu_period(fraction=0.25),
    enabled=False,
    description='Scan log files against pre-trained models to alert on anomalies.',
  )
  return {'logscan': logscan_defaults}


def get_cpu_period(fraction: float):
  """Return a CPU budget value scaled to a fraction of the host's cores.

  The result is 10000 multiplied by the whole number of cores obtained from
  ``cpu_count * fraction`` (rounded down). Hosts with two or fewer cores
  ignore ``fraction`` and use every core.

  At least one core is always budgeted: without that floor a small core count
  (for example 3 cores at a fraction of 0.25) rounds down to zero cores and
  produces a budget of 0.
  """
  multiplier = 10000

  num_cores = multiprocessing.cpu_count()
  if num_cores <= 2:
    fraction = 1.

  # int() truncates, so clamp to one core to avoid returning 0.
  num_used_cores = max(1, int(num_cores * fraction))
  return num_used_cores * multiplier


def sigint_handler(*_):
  """Handle SIGINT (Ctrl-C): tell the user, relay the signal, exit with 1.

  The signal is relayed through the module-level ``proc`` only when that name
  is not None. Any arguments supplied by the signal machinery are ignored.
  """
  print('Exiting gracefully on Ctrl-C')
  child = proc
  if child is not None:
    child.send_signal(signal.SIGINT)
  sys.exit(1)


def find_minion_pillar(pillar_dir: str = None) -> str:
  """Locate the single manager-type minion pillar file.

  Walks ``pillar_dir`` recursively and collects every file whose name ends in
  ``_manager.sls``, ``_managersearch.sls``, ``_standalone.sls``,
  ``_import.sls`` or ``_eval.sls``.

  Args:
    pillar_dir: Directory to search. Defaults to the module-level
      ``minion_pillar_dir`` when omitted.

  Returns:
    The path of the one matching pillar file.

  Exits:
    With status 3 (after printing to stderr) when no match or more than one
    match is found.
  """
  if pillar_dir is None:
    pillar_dir = minion_pillar_dir

  # Raw string so that '\.' reaches the regex engine as a literal-dot escape.
  regex = r'^.*_(manager|managersearch|standalone|import|eval)\.sls$'

  result = []
  for root, _, files in os.walk(pillar_dir):
    for f_minion_id in files:
      if re.search(regex, f_minion_id):
        result.append(os.path.join(root, f_minion_id))

  if not result:
    print('Could not find manager-type pillar (eval, standalone, manager, managersearch, import). Are you running this script on the manager?', file=sys.stderr)
    sys.exit(3)

  if len(result) > 1:
    # Quote each path individually; joining the list's repr would interleave
    # ', ' between every single character.
    res_str = ', '.join(f'"{path}"' for path in result)
    print('(This should not happen, the system is in an error state if you see this message.)\n', file=sys.stderr)
    print('More than one manager-type pillar exists, minion id\'s listed below:', file=sys.stderr)
    print(f'  {res_str}', file=sys.stderr)
    sys.exit(3)

  return result[0]


def read_pillar(pillar: str):
  """Load the YAML pillar file at ``pillar`` and return its parsed content.

  Args:
    pillar: Path of the pillar file to read.

  Returns:
    The object produced by ``yaml.safe_load`` for the file's content.

  Exits:
    With status 3 after printing 'Could not open ...' when the file cannot be
    read, or 'Could not parse ...' when the content is invalid YAML or parses
    to None (for example an empty file).
  """
  # Reading and parsing are handled separately so each failure produces only
  # its own message; a bare except around both would also intercept the
  # SystemExit raised for a parse failure and report it a second time.
  try:
    with open(pillar, 'r') as pillar_file:
      raw_yaml = pillar_file.read()
  except (OSError, UnicodeDecodeError):
    print(f'Could not open {pillar}', file=sys.stderr)
    sys.exit(3)

  try:
    loaded_yaml = yaml.safe_load(raw_yaml)
  except yaml.YAMLError:
    loaded_yaml = None

  if loaded_yaml is None:
    print(f'Could not parse {pillar}', file=sys.stderr)
    sys.exit(3)
  return loaded_yaml

def write_pillar(pillar: str, content: dict):
  """Serialize ``content`` as block-style YAML and write it to ``pillar``.

  Args:
    pillar: Path of the pillar file to overwrite.
    content: Data to dump.

  Exits:
    With status 3 after printing to stderr when the content cannot be
    serialized or the file cannot be written.
  """
  # Render first: opening with 'w' truncates the file, so a serialization
  # failure after that point would leave an empty pillar behind.
  try:
    rendered = yaml.dump(content, default_flow_style=False)
  except yaml.YAMLError:
    print(f'Could not serialize content for {pillar}', file=sys.stderr)
    sys.exit(3)

  try:
    with open(pillar, 'w') as pillar_file:
      pillar_file.write(rendered)
  except OSError:
    print(f'Could not open {pillar}', file=sys.stderr)
    sys.exit(3)


def mod_so_status(action: str, item: str, conf_path: str = None):
  """Add or remove the ``so-<item>`` entry in the so-status container list.

  The file is treated as one container name per line. Blank lines are dropped
  and every remaining entry is written back terminated by a newline, so a
  final line that lacked one is not glued to a newly appended entry.

  Args:
    action: 'add' appends the entry if it is absent; 'remove' deletes every
      occurrence of it. Any other value only normalizes the file.
    item: Module name; the entry written is ``so-<item>``.
    conf_path: File to modify. Defaults to the module-level ``so_status_conf``
      when omitted.
  """
  if conf_path is None:
    conf_path = so_status_conf

  entry = f'so-{item}'

  # 'a+' creates the file when it does not exist yet and still allows reading.
  with open(conf_path, 'a+') as conf:
    conf.seek(0)
    names = [line.rstrip('\n') for line in conf.readlines()]
    # Build a new list instead of removing while iterating, which skips
    # elements and leaves consecutive blank lines behind.
    names = [name for name in names if name]

    if action == 'add':
      if entry not in names:
        names.append(entry)
    elif action == 'remove':
      names = [name for name in names if name != entry]

    conf.seek(0)
    conf.truncate(0)
    conf.writelines(f'{name}\n' for name in names)


def create_pillar_if_not_exist(pillar: str, content: dict):
  """Ensure ``content`` has a ``learn.modules`` section, seeding it if absent.

  When ``content['learn']['modules']`` is missing or None, it is populated from
  get_learn_modules() and the whole pillar is written back via write_pillar.
  Other keys already present under ``learn`` are preserved.

  Args:
    pillar: Path of the pillar file, used only when a write is required.
    content: Parsed pillar data; modified in place.

  Returns:
    The same ``content`` object.
  """
  learn = content.get('learn')
  if not isinstance(learn, dict):
    # Covers both a missing key and a null/non-mapping 'learn' value.
    learn = {}
    content['learn'] = learn

  if learn.get('modules') is None:
    learn['modules'] = get_learn_modules()
    write_pillar(pillar, content)

  return content


def salt_call(module: str):
  """Apply the ``learn.<module>`` salt state and return salt-call's exit code.

  The command's stdout and stderr are discarded; a progress line is printed
  before the run and an error line afterwards when the exit code is non-zero.
  """
  command = [
    'salt-call',
    'state.apply',
    '-l',
    'quiet',
    f'learn.{module}',
    'queue=True',
  ]

  print(f'  Applying salt state for {module} module...')
  completed = subprocess.run(
    command,
    stdout=subprocess.DEVNULL,
    stderr=subprocess.DEVNULL,
  )
  if completed.returncode != 0:
    print(f'  [ERROR] Failed to apply salt state for {module} module.')

  return completed.returncode


def pull_image(module: str):
  """Make sure the ``so-<module>`` image is available from the local registry.

  Lists the non-dangling images known to the Docker daemon and looks for a
  repo tag that contains both ``so-<module>`` and ``:5000``. When none is
  found, ``so-image-pull --quiet so-<module>`` is run with its output
  discarded.

  Returns:
    0 when a matching image already exists, otherwise the exit code of
    so-image-pull.
  """
  container_basename = f'so-{module}'

  client = docker.from_env()
  image_list = client.images.list(filters={ 'dangling': False })

  # attrs.get('RepoTags') yields None when the key is absent or null (which
  # untagged images can produce); fall back to an empty list so flattening
  # does not raise TypeError.
  tag_list = list(chain.from_iterable(
    image.attrs.get('RepoTags') or [] for image in image_list
  ))
  has_local_image = any(
    container_basename in tag and ':5000' in tag for tag in tag_list
  )
  if has_local_image:
    return 0

  print(f'Pulling and verifying missing image for {module} (may take several minutes) ...')
  pull_command = ['so-image-pull', '--quiet', container_basename]

  result = subprocess.run(pull_command, stdout=subprocess.DEVNULL, stderr=subprocess.DEVNULL)
  if result.returncode != 0:
    print(f'[ERROR] Failed to pull image so-{module}, skipping state.')
  return result.returncode


def apply(module_list: List):
  """Run salt_call for every module and report the outcome.

  Returns 0 when every call returned 0, otherwise the non-zero code of the
  last module that failed.
  """
  last_failure = 0
  for name in module_list:
    code = salt_call(name)
    # A success must not overwrite an earlier failure.
    if code:
      last_failure = code
  return last_failure


def check_apply(args: dict):
  """Apply the salt states for ``args.modules`` now, or ask the user first.

  With ``args.apply`` set the states are applied immediately. Otherwise the
  user is prompted until the reply is 'y', 'n' or empty (case-insensitive);
  only 'y' triggers the apply.

  Returns:
    The result of apply() when states were applied, 0 when the user declined.
  """
  if args.apply:
    print('Configuration updated. Applying changes:')
    return apply(args.modules)

  prompt = 'Configuration updated. Would you like to apply your changes now? (y/N) '
  while True:
    reply = input(prompt).lower()
    if reply in ('y', 'n', ''):
      break

  # An empty reply counts as the default answer, "no".
  if reply != 'y':
    return 0

  print('Applying changes:')
  return apply(args.modules)

def enable_disable_modules(args, enable: bool):
  """Enable or disable the requested ML modules, then offer to apply states.

  For each affected module the pillar's ``enabled`` flag is set and the
  so-status list is updated through mod_so_status. When enabling, the module's
  image is pulled first and the module is skipped if the pull fails. The
  keyword 'all' in ``args.modules`` targets every module present in the
  pillar; otherwise names missing from the pillar are reported on stderr and
  skipped.

  Args:
    args: Parsed CLI namespace providing ``modules``, ``pillar``,
      ``pillar_dict`` and ``apply``.
    enable: True to enable, False to disable.

  Returns:
    The value returned by check_apply.
  """
  pillar_modules = args.pillar_dict.get('learn', {}).get('modules')

  action_str = 'add' if enable else 'remove'
  state_str = 'enabled' if enable else 'disabled'

  if 'all' in args.modules:
    for module, details in pillar_modules.items():
      # Same rule as the per-module branch: never mark a module enabled when
      # its image could not be pulled.
      if enable and pull_image(module) != 0:
        continue
      details['enabled'] = enable
      mod_so_status(action_str, module)
    write_pillar(args.pillar, args.pillar_dict)
  else:
    write_needed = False
    for module in args.modules:
      if module not in pillar_modules:
        print(f'{module} is not a known module, skipping.', file=sys.stderr)
        continue
      if pillar_modules[module]['enabled'] == enable:
        print(f'{module} module already {state_str}.', file=sys.stderr)
        continue
      if enable and pull_image(module) != 0:
        continue
      pillar_modules[module]['enabled'] = enable
      mod_so_status(action_str, module)
      write_needed = True
    if write_needed:
      write_pillar(args.pillar, args.pillar_dict)

  # NOTE(review): check_apply receives args.modules unfiltered, so 'all' and
  # skipped names are still passed on to the state run -- confirm intended.
  return check_apply(args)


def enable_modules(args):
  """Handle the 'enable' subcommand.

  Returns the result of enable_disable_modules so the caller can use it as
  the process exit code instead of always receiving None.
  """
  return enable_disable_modules(args, enable=True)


def disable_modules(args):
  """Handle the 'disable' subcommand.

  Returns the result of enable_disable_modules so the caller can use it as
  the process exit code instead of always receiving None.
  """
  return enable_disable_modules(args, enable=False)


def list_modules(*_):
  """Print the name and description of every available ML module.

  Any arguments are ignored. Always returns 0.
  """
  print('Available ML modules:')
  catalog = get_learn_modules()
  for name in catalog:
    summary = catalog[name]["description"]
    print(f'  - {name} : {summary}')
  return 0


def main():
  """Command-line entry point for enabling, disabling and listing ML modules.

  Installs the SIGINT handler, requires root (exit status 1 otherwise), parses
  the command line and dispatches to the selected subcommand. The process
  exits with the value returned by the subcommand handler. When no subcommand
  is given, the beta notice and usage help are printed and the process exits
  with 0 without looking up or modifying the pillar.
  """
  beta_str = 'BETA - SUBJECT TO CHANGE\n'

  apply_help='After ACTION the chosen modules, apply any necessary salt states.'
  enable_apply_help = apply_help.replace('ACTION', 'enabling')
  disable_apply_help = apply_help.replace('ACTION', 'disabling')

  signal.signal(signal.SIGINT, sigint_handler)

  if os.geteuid() != 0:
    print('You must run this script as root', file=sys.stderr)
    sys.exit(1)

  main_parser = argparse.ArgumentParser(formatter_class=argparse.RawDescriptionHelpFormatter)

  subcommand_desc = textwrap.dedent(
    """\
      enable              Enable one or more ML modules.
      disable             Disable one or more ML modules.
      list                List all available ML modules.
    """
  )

  subparsers = main_parser.add_subparsers(title='commands', description=subcommand_desc, metavar='', dest='command')

  module_help_str = 'One or more ML modules, which can be listed using \'so-learn list\'. Use the keyword \'all\' to apply the action to all available modules.'

  enable = subparsers.add_parser('enable')
  enable.set_defaults(func=enable_modules)
  enable.add_argument('modules', metavar='ML_MODULE', nargs='+', help=module_help_str)
  enable.add_argument('--apply', action='store_const', const=True, required=False, help=enable_apply_help)

  disable = subparsers.add_parser('disable')
  disable.set_defaults(func=disable_modules)
  disable.add_argument('modules', metavar='ML_MODULE', nargs='+', help=module_help_str)
  disable.add_argument('--apply', action='store_const', const=True, required=False, help=disable_apply_help)

  # Named list_parser so the builtin `list` is not shadowed inside main().
  list_parser = subparsers.add_parser('list')
  list_parser.set_defaults(func=list_modules)

  args = main_parser.parse_args(sys.argv[1:])

  # Only subcommands set `func`; without one, show usage before touching the
  # pillar so that asking for help never fails on, or writes to, pillar files.
  if not hasattr(args, 'func'):
    if args.command is None:
      print(beta_str)
      main_parser.print_help()
    sys.exit(0)

  args.pillar = find_minion_pillar()
  args.pillar_dict = create_pillar_if_not_exist(args.pillar, read_pillar(args.pillar))

  exit_code = args.func(args)
  sys.exit(exit_code)


# Run the CLI only when this file is executed as a script, not when imported.
if __name__ == '__main__':
  main()
