#!/usr/bin/env python3

# Copyright Security Onion Solutions LLC and/or licensed to Security Onion Solutions LLC under one
# or more contributor license agreements. Licensed under the Elastic License 2.0 as shown at 
# https://securityonion.net/license; you may not use this file except in compliance with the
# Elastic License 2.0.

import os
import subprocess
import sys
import time
import yaml
import logging

# Module-wide logger: every record goes to a timestamped log file and is
# echoed to the console.
logger = logging.getLogger('so-firewall')
logger.setLevel(logging.INFO)

# Persistent log file; timestamps make entries easy to correlate later.
file_handler = logging.FileHandler('/opt/so/log/so-firewall.log')
# Console output: level and message only.
console_handler = logging.StreamHandler()

file_handler.setFormatter(logging.Formatter(fmt='%(asctime)s - %(levelname)s - %(message)s'))
console_handler.setFormatter(logging.Formatter(fmt='%(levelname)s - %(message)s'))
logger.addHandler(file_handler)
logger.addHandler(console_handler)

# Created with exclusive-create mode by main() to serialize concurrent runs.
lockFile = "/tmp/so-firewall.lock"
# Local pillar file holding hostgroup IPs; read and rewritten by the host commands.
hostgroupsFilename = "/opt/so/saltstack/local/pillar/firewall/soc_firewall.sls"
# Shipped firewall defaults; only read, to learn which hostgroups exist.
defaultsFilename = "/opt/so/saltstack/default/salt/firewall/defaults.yaml"

def showUsage(options, args):
  """Log the command-line usage text at ERROR level and exit with status 1.

  *options* and *args* are unused; the signature matches the command
  dispatch table in main(), where this function is also the fallback for
  unknown commands. Only commands that main() actually dispatches are listed.
  """
  usage = f'''Usage: {sys.argv[0]} [OPTIONS] <COMMAND> [ARGS...]
  Options:
   --apply         - After updating the firewall configuration files, apply the new firewall state with queue=True

  General commands:
    help           - Prints this usage information.
    apply          - Apply the firewall state.

  Host commands:
    includehost    - Includes the given IP in the given group. Args: <GROUP_NAME> <IP>
    removehost     - Removes the given IP from all hostgroups. Args: <IP>

  Where:
   GROUP_NAME    - The name of an alias group (Ex: analyst)
   IP            - Either a single IP address (Ex: 8.8.8.8) or a CIDR block (Ex: 10.23.0.0/16).'''
  logger.error(usage)
  sys.exit(1)

def checkApplyOption(options):
  """Apply the firewall state if ``--apply`` is among *options*.

  Returns the exit code from apply() when ``--apply`` was given, otherwise 0.
  Callers therefore always receive an integer exit code, never None.
  """
  if "--apply" in options:
    return apply(None, None)
  return 0

def loadYaml(filename):
  """Read *filename* and parse it with yaml.safe_load.

  Returns the parsed document; an empty file yields a falsy value (PyYAML
  returns None for an empty document). The file handle is closed before
  parsing, so no descriptor is leaked.
  """
  with open(filename, "r") as file:
    content = file.read()
  return yaml.safe_load(content)

def writeYaml(filename, content):
  """Serialize *content* as YAML into *filename*, replacing the file's contents.

  The file is closed, and therefore flushed, before this returns. Any
  follow-up step, such as applying the Salt state, then sees the new data
  without relying on garbage collection to close the handle. Returns
  yaml.dump's result, which is None when dumping to a stream.
  """
  with open(filename, "w") as file:
    return yaml.dump(content, file)

def addIp(name, ip):
  """Add *ip* to hostgroup *name* in the local firewall pillar file.

  Only hostgroups defined in the default firewall configuration are
  accepted. The built-in groups listed in ``unallowedHostgroups`` are
  excluded.

  Returns:
    0 if the IP was added and the pillar file rewritten,
    3 if the IP is already in the hostgroup (file left untouched),
    4 if the hostgroup is not defined in defaults or is not editable.
  """
  content = loadYaml(hostgroupsFilename)
  defaults = loadYaml(defaultsFilename)
  unallowedHostgroups = {'anywhere', 'dockernet', 'localhost', 'self'}
  # Build the allowed set without mutating the loaded defaults. A built-in
  # group missing from defaults no longer raises KeyError.
  allowedHostgroups = set(defaults['firewall']['hostgroups']) - unallowedHostgroups
  if name not in allowedHostgroups:
    logger.error(f"Host group {name} not defined in defaults or is unallowed")
    return 4

  # The pillar file may be empty, or may have null/missing 'firewall' or
  # 'hostgroups' levels; create them so the new entry has somewhere to live.
  if not content:
    content = {}
  firewall = content.get('firewall') or {}
  content['firewall'] = firewall
  hostgroups = firewall.get('hostgroups') or {}
  firewall['hostgroups'] = hostgroups

  ips = hostgroups.get(name)
  if ips is None:
    # Attach the new list to the document itself. Otherwise a null entry
    # would get a detached list and the IP would never be written.
    ips = []
    hostgroups[name] = ips

  if ip in ips:
    logger.warning(f"IP {ip} already exists in hostgroup {name}")
    return 3

  ips.append(ip)
  writeYaml(hostgroupsFilename, content)
  logger.info(f"Successfully added IP {ip} to hostgroup {name}")
  return 0

def includehost(options, args):
  """Handle ``includehost <GROUP_NAME> <IP>``: add the IP, then optionally apply.

  Exits via showUsage when not given exactly two arguments. Returns addIp's
  non-zero code on failure. On success, returns the apply result if
  ``--apply`` was given, otherwise 0.
  """
  if len(args) != 2:
    logger.error('Missing host group name or ip argument')
    showUsage(options, args)
  code = addIp(args[0], args[1])
  if code == 0:
    # Normalize a None result (no --apply) to 0 so an int is always returned.
    code = checkApplyOption(options) or 0
  return code

def apply(options, args):
  """Run ``salt-call state.apply firewall queue=True`` and return its exit code.

  *options* and *args* are unused; the signature matches the command
  dispatch table in main(). Returns 1 if salt-call cannot be started at all.
  """
  logger.info("Applying firewall configuration changes")
  salt_args = ['salt-call', 'state.apply', 'firewall', 'queue=True']
  try:
    proc = subprocess.run(salt_args)
  except OSError as err:
    # For example, salt-call is not on PATH. Report it as a failed apply
    # instead of crashing with a traceback.
    logger.error(f"Failed to apply firewall changes: {err}")
    return 1
  if proc.returncode != 0:
    logger.error("Failed to apply firewall changes")
  else:
    logger.info("Successfully applied firewall changes")
  return proc.returncode

def removehost(options, args):
  """Remove an IP from every hostgroup in the local pillar file, optionally applying.

  Exits via showUsage when not given exactly one argument.

  Returns:
    4 if the pillar file has no usable firewall hostgroups mapping,
    the apply result if the IP was removed and ``--apply`` was given,
    0 otherwise, including when the IP was not found (an error is logged).
  """
  if len(args) != 1:
    logger.error('Missing IP argument')
    showUsage(options, args)

  ip = args[0]
  content = loadYaml(hostgroupsFilename)
  # Check every level. A null 'firewall:' or 'hostgroups:' entry would
  # otherwise raise TypeError/AttributeError instead of being reported.
  firewall = content.get('firewall') if isinstance(content, dict) else None
  hostgroups = firewall.get('hostgroups') if isinstance(firewall, dict) else None
  if not isinstance(hostgroups, dict):
    logger.error("Invalid firewall configuration structure")
    return 4

  removed_from = []
  for group_name, ips in hostgroups.items():
    if isinstance(ips, list) and ip in ips:
      # Drop every occurrence in place; list.remove() only drops the first.
      ips[:] = [entry for entry in ips if entry != ip]
      removed_from.append(group_name)

  if not removed_from:
    logger.error(f"IP {ip} not found in any hostgroups")
    return 0

  writeYaml(hostgroupsFilename, content)
  logger.info(f"Successfully removed IP {ip} from hostgroups: {', '.join(removed_from)}")
  # Same --apply handling as includehost; 'or 0' keeps the result an int.
  return checkApplyOption(options) or 0

def main():
  """Parse the command line, take the lock file, dispatch the command, and exit.

  Arguments starting with ``--`` are treated as options; the first remaining
  argument selects the command, and the rest are passed to it. The process
  exits with the command's return code.
  """
  argv = sys.argv[1:]
  # Partition instead of removing from the list while iterating it; the old
  # loop skipped the element after each option it removed.
  options = [arg for arg in argv if arg.startswith("--")]
  args = [arg for arg in argv if not arg.startswith("--")]

  if len(args) == 0:
    showUsage(options, None)

  commands = {
    "help": showUsage,
    "includehost": includehost,
    "removehost": removehost,
    "apply": apply
  }

  code = 1
  maxAttempts = 30
  locked = False

  try:
    for attempt in range(1, maxAttempts + 1):
      try:
        # Exclusive-create mode fails if the file already exists, so only
        # one process can hold the lock at a time.
        with open(lockFile, "x"):
          pass
        locked = True
        break
      except OSError:
        if attempt < maxAttempts:
          time.sleep(2)

    if not locked:
      logger.error(f"Lock file ({lockFile}) could not be created - proceeding without lock")

    cmd = commands.get(args[0], showUsage)
    code = cmd(options, args[1:])
  finally:
    # Only remove the lock file if this process created it; otherwise we
    # would delete a lock still held by another run.
    if locked:
      try:
        os.remove(lockFile)
      except OSError:
        logger.error(f"Lock file ({lockFile}) already removed")

  sys.exit(code)

if __name__ == "__main__":
  main()
