#!/usr/bin/env python3

# Copyright 2014-2023 Security Onion Solutions, LLC
#
# This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU General Public License as published by
# the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
# GNU General Public License for more details.
#
# You should have received a copy of the GNU General Public License
# along with this program.  If not, see <http://www.gnu.org/licenses/>.

import ipaddress
import textwrap
import os
import subprocess
import sys
import argparse
import re
from lxml import etree as ET
from xml.dom import minidom


# Root of the locally managed Salt tree; the pillar directory is looked up beneath it.
LOCAL_SALT_DIR = '/opt/so/saltstack/local'

# Wazuh (OSSEC) configuration file that whitelist entries are removed from.
WAZUH_CONF = '/nsm/wazuh/etc/ossec.conf'

# Single-letter selector -> firewall role name and human-readable description.
# The selector doubles as the command-line flag and as the interactive menu
# choice; insertion order is the order the menu is printed in.
VALID_ROLES = {
  selector: {'role': role_name, 'desc': description}
  for selector, role_name, description in (
    ('a', 'analyst', 'Analyst - 80/tcp, 443/tcp'),
    ('b', 'beats_endpoint', 'Logstash Beat - 5044/tcp'),
    ('e', 'elasticsearch_rest', 'Elasticsearch REST API - 9200/tcp'),
    ('f', 'strelka_frontend', 'Strelka frontend - 57314/tcp'),
    ('o', 'osquery_endpoint', 'Osquery endpoint - 8090/tcp'),
    ('s', 'syslog', 'Syslog device - 514/tcp/udp'),
    ('w', 'wazuh_agent', 'Wazuh agent - 1514/tcp/udp'),
    ('p', 'wazuh_api', 'Wazuh API - 55000/tcp'),
    ('r', 'wazuh_authd', 'Wazuh registration service - 1515/tcp'),
  )
}


def validate_ip_cidr(ip_cidr: str) -> bool:
  """Return True when *ip_cidr* parses as an IP address or as an IP network.

  The value is first tried as a single address and then as a network in
  CIDR notation. Network parsing uses the ``ipaddress`` default of strict
  mode, so a CIDR block with host bits set (e.g. ``10.10.0.1/16``) is
  rejected.
  """
  for parse in (ipaddress.ip_address, ipaddress.ip_network):
    try:
      parse(ip_cidr)
    except ValueError:
      # Not valid in this form; fall through to the next parser.
      continue
    return True
  return False


def role_prompt() -> str:
  """Interactively ask which role to deny and return its role name.

  Prints one menu line per entry in VALID_ROLES, reads a selector from
  standard input, and returns the matching ``'role'`` value. An unknown
  selector prints an error to stderr and exits the process with status 1.
  """
  print()
  print('Choose the role for the IP or Range you would like to deny')
  print()
  for selector, details in VALID_ROLES.items():
    print(f'[{selector}] - {details["desc"]}')
  print()

  choice = input('Please enter your selection: ')
  if choice not in VALID_ROLES:
    print(f'Invalid role \'{choice}\', please try again.', file=sys.stderr)
    sys.exit(1)
  return VALID_ROLES[choice]['role']
  

def ip_prompt() -> str:
  """Interactively ask for the IP address or CIDR block to deny.

  Returns the text exactly as entered when validate_ip_cidr accepts it;
  otherwise prints an error to stderr and exits the process with status 1.
  """
  entered = input('Enter a single ip address or range to deny (ex: 10.10.10.10 or 10.10.0.0/16): ')
  if not validate_ip_cidr(entered):
    print(f'Invalid IP address or CIDR block \'{entered}\', please try again.', file=sys.stderr)
    sys.exit(1)
  return entered


def wazuh_enabled(pillar_dir=None) -> bool:
  """Report whether any pillar file contains the text ``wazuh: 1``.

  Only regular files directly inside the pillar directory are inspected;
  sub-directories are skipped.

  Args:
    pillar_dir: Directory to scan. Defaults to ``<LOCAL_SALT_DIR>/pillar``.

  Returns:
    True as soon as one file containing ``wazuh: 1`` is found, else False.

  Raises:
    OSError: If the directory cannot be listed or a file cannot be read.
  """
  if pillar_dir is None:
    pillar_dir = f'{LOCAL_SALT_DIR}/pillar'

  for name in os.listdir(pillar_dir):
    # os.listdir yields bare entry names, so they must be joined with the
    # directory; opening the bare name would resolve against the current
    # working directory instead.
    path = os.path.join(pillar_dir, name)
    if not os.path.isfile(path):
      continue
    with open(path, 'r') as pillar:
      if 'wazuh: 1' in pillar.read():
        return True
  return False


def root_to_str(root: ET.ElementTree) -> str:
    """Serialize *root* to a pretty-printed XML string with two-space indent.

    The element is first flattened onto a single line with the whitespace
    following each ``>`` removed, then re-indented by ``minidom``. The
    returned text starts with the XML declaration that ``toprettyxml`` emits.
    """
    flat = ET.tostring(root, encoding='unicode', method='xml')
    flat = flat.replace('\n', '')

    substitutions = (
        # Drop runs of spaces that directly follow a closing '>'.
        (r'(?:(?<=>) *)', ''),
        # Remove specific substrings to better format comments on initial parse/write
        (r'       -', ''),
        (r'      -->', ' -->'),
    )
    for pattern, replacement in substitutions:
        flat = re.sub(pattern, replacement, flat)

    return minidom.parseString(flat).toprettyxml(indent="  ")


def rem_wl(ip):
    """Remove the Wazuh whitelist block(s) for *ip* from WAZUH_CONF.

    Every ``<global>`` child of the document root that contains a
    ``<white_list>`` whose text equals *ip* is removed, together with the
    comment node immediately preceding it, if any. The file is rewritten
    only when at least one such block was found.
    """
    parser = ET.XMLParser(remove_blank_text=True)
    with open(WAZUH_CONF, 'rb') as wazuh_conf:
        root = ET.parse(wazuh_conf, parser).getroot()

    # The trailing '/..' selects the parent <global> of each matching entry.
    matching_blocks = root.findall(f"global/white_list[. = '{ip}']/..")
    if not matching_blocks:
        return

    for block in matching_blocks:
        position = list(root).index(block)
        if position > 0:
            preceding = root[position - 1]
            # A comment directly above the block is removed along with it.
            if preceding.tag == ET.Comment:
                root.remove(preceding)
        root.remove(block)

    with open(WAZUH_CONF, 'w') as out:
        out.write(root_to_str(root))


def apply(role: str, ip: str) -> int:
  """Remove *ip* from the firewall *role* and apply the change.

  Steps, each of which stops the sequence on failure:
    1. ``so-firewall excludehost <role> <ip>``
    2. ``salt-call state.apply ... firewall`` to push the new rules
    3. For the ``analyst`` role with Wazuh enabled: drop the matching
       whitelist entry from WAZUH_CONF and restart Wazuh.

  Args:
    role: Firewall role name (a ``'role'`` value from VALID_ROLES).
    ip: IP address or CIDR block to remove.

  Returns:
    0 on success; otherwise the return code of the command that failed,
    or 1 when editing the Wazuh configuration failed.
  """
  firewall_cmd = ['so-firewall', 'excludehost', role, ip]
  salt_cmd = ['salt-call', 'state.apply', '-l', 'quiet', 'firewall', 'queue=True']
  restart_wazuh_cmd = ['so-wazuh-restart']
  print(f'Removing {ip} from the {role} role. This can take a few seconds...')

  cmd = subprocess.run(firewall_cmd)
  if cmd.returncode != 0:
    return cmd.returncode

  cmd = subprocess.run(salt_cmd, stdout=subprocess.DEVNULL)
  if cmd.returncode != 0:
    print(f'Commmand \'{" ".join(salt_cmd)}\' failed.', file=sys.stderr)
    return cmd.returncode

  # Only the analyst role has a corresponding Wazuh whitelist entry.
  if role != 'analyst':
    return cmd.returncode

  # wazuh_enabled must be *called*; the bare function object is always truthy.
  try:
    wazuh_on = wazuh_enabled()
  except (OSError, UnicodeDecodeError):
    # If the pillar files cannot be inspected, still attempt the cleanup
    # rather than risk leaving a stale whitelist entry behind.
    wazuh_on = True
  if not wazuh_on:
    return cmd.returncode

  try:
    rem_wl(ip)
    print(f'Removed whitelist entry for {ip} from {WAZUH_CONF}', file=sys.stderr)
  except Exception as e:
    print(f'Failed to remove whitelist entry for {ip} from {WAZUH_CONF}', file=sys.stderr)
    print(e)
    return 1

  print('Restarting OSSEC Server...')
  cmd = subprocess.run(restart_wazuh_cmd)
  if cmd.returncode != 0:
    print('Failed to restart OSSEC server.')
  return cmd.returncode


def _exit_code_from(exc: Exception) -> int:
  """Translate an unexpected exception into a non-zero exit code.

  Uses the exception's ``errno`` when it is a non-zero integer (as on most
  ``OSError`` instances); every other exception maps to 1.
  """
  code = getattr(exc, 'errno', None)
  if isinstance(code, int) and code != 0:
    return code
  return 1


def main():
  """Parse the command line and deny an IP/CIDR for the selected role(s).

  Must run as root. With no role flag the script runs interactively,
  prompting for one role and one address. With role flags the address comes
  from ``-i`` or, failing that, the ``IP`` environment variable; if neither
  is set the help text is printed. Roles are processed in the order given,
  stopping at the first failure. Always terminates via ``sys.exit``.
  """
  if os.geteuid() != 0:
    print('You must run this script as root', file=sys.stderr)
    sys.exit(1)

  main_parser = argparse.ArgumentParser(
    formatter_class=argparse.RawDescriptionHelpFormatter,
    epilog=textwrap.dedent('''\
        additional information:
                      To use this script in interactive mode call it with no arguments
        '''
  ))

  # One flag per role; every flag appends its role name to args.roles.
  group = main_parser.add_argument_group(title='roles')
  for selector, details in VALID_ROLES.items():
    group.add_argument(f'-{selector}', dest='roles', action='append_const', const=details['role'], help=details['desc'])

  ip_g = main_parser.add_argument_group(title='allow')
  ip_g.add_argument('-i', help="IP or CIDR block to disallow connections from, requires at least one role argument", metavar='', dest='ip')

  args = main_parser.parse_args(sys.argv[1:])

  # Interactive mode: no role flag was supplied.
  if args.roles is None:
    role = role_prompt()
    ip = ip_prompt()
    try:
      return_code = apply(role, ip)
    except Exception as e:
      print(f'Unexpected exception occurred: {e}', file=sys.stderr)
      # Not every exception carries an errno; never exit 0 after a failure.
      return_code = _exit_code_from(e)
    sys.exit(return_code)

  # Roles were given without -i: fall back to the IP environment variable.
  if args.ip is None:
    if os.environ.get('IP') is None:
      main_parser.print_help()
      sys.exit(1)
    args.ip = os.environ['IP']

  if not validate_ip_cidr(args.ip):
    print(f'Invalid IP address or CIDR block \'{args.ip}\', please try again.', file=sys.stderr)
    sys.exit(1)

  return_code = 0
  try:
    for role in args.roles:
      return_code = apply(role, args.ip)
      # Subprocess return codes are negative when the child was killed by a
      # signal, so any non-zero value counts as failure.
      if return_code != 0:
        break
  except Exception as e:
    print(f'Unexpected exception occurred: {e}', file=sys.stderr)
    return_code = _exit_code_from(e)

  sys.exit(return_code)


# Script entry point: run main() and turn Ctrl-C into a plain exit status 1
# instead of a KeyboardInterrupt traceback.
if __name__ == '__main__':
  try:
    main()
  except KeyboardInterrupt:
    raise SystemExit(1)
