#!/usr/bin/env python3

# Copyright 2014,2015,2016,2017,2018,2019,2020 Security Onion Solutions, LLC
#
# This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU General Public License as published by
# the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
# GNU General Public License for more details.
#
# You should have received a copy of the GNU General Public License
# along with this program.  If not, see <http://www.gnu.org/licenses/>.

import subprocess
import sys
import yaml

# YAML file holding the host group definitions; read and rewritten by the
# host-related commands (includehost, excludehost, removehost, addhostgroup, ...).
hostgroupsFilename = "/opt/so/saltstack/local/salt/firewall/hostgroups.local.yaml"
# YAML file holding the port group definitions; read and rewritten by the
# port-related commands (addport, removeport, addportgroup, ...).
portgroupsFilename = "/opt/so/saltstack/local/salt/firewall/portgroups.local.yaml"
# Protocols accepted for the <PORT_PROTOCOL> argument; every new port group
# is created with one (initially empty) port list per protocol listed here.
supportedProtocols = ['tcp', 'udp']

def showUsage(args):
  """Print the command-line usage text to stdout and exit with status 1.

  The args parameter is accepted so this function has the same call shape
  as the other command handlers; its value is not used.
  """
  usageLines = (
    'Usage: {} [OPTIONS] <COMMAND> [ARGS...]'.format(sys.argv[0]),
    '  Options:',
    '   --apply       - After updating the firewall configuration files, apply the new firewall state',
    '',
    '  Available commands:',
    '   help          - Prints this usage information.',
    '   includedhosts - Lists the IPs included in the given group. Args: <GROUP_NAME>',
    '   excludedhosts - Lists the IPs excluded from the given group. Args: <GROUP_NAME>',
    '   includehost   - Includes the given IP in the given group. Args: <GROUP_NAME> <IP>',
    '   excludehost   - Excludes the given IP from the given group. Args: <GROUP_NAME> <IP>',
    '   removehost    - Removes an excluded IP from the given group. Args: <GROUP_NAME> <IP>',
    '   addhostgroup  - Adds a new, custom host group. Args: <GROUP_NAME>',
    '   listports     - Lists ports in the given group and protocol. Args: <GROUP_NAME> <PORT_PROTOCOL>',
    '   addport       - Adds a PORT to the given group. Args: <GROUP_NAME> <PORT_PROTOCOL> <PORT>',
    '   removeport    - Removes a PORT from the given group. Args: <GROUP_NAME> <PORT_PROTOCOL> <PORT>',
    '   addportgroup  - Adds a new, custom port group. Args: <GROUP_NAME>',
    '',
    '  Where:',
    '   GROUP_NAME    - The name of an alias group (Ex: analyst)',
    '   IP            - Either a single IP address (Ex: 8.8.8.8) or a CIDR block (Ex: 10.23.0.0/16).',
    '   PORT_PROTOCOL - Must be one of the following: ' + str(supportedProtocols),
    '   PORT          - Either a single numeric port (Ex: 443), or a port range (Ex: 8000:8002).',
  )
  for usageLine in usageLines:
    print(usageLine)
  # Usage is only shown for "help" or on a bad invocation; always stop here.
  sys.exit(1)

def loadYaml(filename):
  """Read the given YAML file and return its parsed content.

  The file is opened in a with-block so the handle is always closed, even
  if parsing fails.

  yaml.safe_load is used instead of yaml.load: calling yaml.load without an
  explicit Loader is rejected by PyYAML 6+ and, on older versions, may
  construct arbitrary Python objects from tagged input. The firewall files
  handled by this script contain only plain mappings, lists and scalars.
  """
  with open(filename, "r") as yamlFile:
    return yaml.safe_load(yamlFile.read())

def writeYaml(filename, content):
  """Serialize content as YAML into the given file, replacing its contents.

  The file is opened in a with-block so the data is flushed and the handle
  closed before returning, rather than relying on garbage collection.

  Returns whatever yaml.dump returns for a stream target.
  """
  with open(filename, "w") as yamlFile:
    return yaml.dump(content, yamlFile)

def listIps(name, mode):
  """Print the IPs stored under the given mode of a host group, one per line.

  name is the host group name; mode is the key under the group's 'ips'
  mapping (the callers in this file pass 'insert' or 'delete').

  Returns 0 on success, or 4 if the host group does not exist.
  """
  hostgroups = loadYaml(hostgroupsFilename)['firewall']['hostgroups']
  if name in hostgroups:
    entries = hostgroups[name]['ips'][mode]
    # A mode with no entries may be stored as None rather than a list.
    for entry in (entries if entries is not None else []):
      print(entry)
    return 0
  print('Host group does not exist', file=sys.stderr)
  return 4

def addIp(name, ip, mode):
  """Add an IP to the given mode list of a host group and save the file.

  Returns 0 on success, 3 if the IP is already in the list, or 4 if the
  host group does not exist. The file is only written on success.
  """
  content = loadYaml(hostgroupsFilename)
  hostgroups = content['firewall']['hostgroups']
  if name not in hostgroups:
    print('Host group does not exist', file=sys.stderr)
    return 4

  modes = hostgroups[name]['ips']
  # Normalize a None placeholder into a real list so it can be appended to.
  if modes[mode] is None:
    modes[mode] = []
  entries = modes[mode]

  if ip in entries:
    print('Already exists', file=sys.stderr)
    return 3

  entries.append(ip)
  writeYaml(hostgroupsFilename, content)
  return 0

def removeIp(name, ip, mode, silence = False):
  """Remove an IP from the given mode list of a host group and save the file.

  When the IP is not in the list and silence is False, an error is printed
  and 3 is returned without writing. When silence is True a missing IP is
  not an error: the file is still rewritten and 0 is returned.

  Returns 0 on success, 3 for a missing IP (non-silent), or 4 if the host
  group does not exist.
  """
  content = loadYaml(hostgroupsFilename)
  hostgroups = content['firewall']['hostgroups']
  if name not in hostgroups:
    print('Host group does not exist', file=sys.stderr)
    return 4

  modes = hostgroups[name]['ips']
  # Normalize a None placeholder into a real list before searching it.
  if modes[mode] is None:
    modes[mode] = []
  entries = modes[mode]

  present = ip in entries
  if not present and not silence:
    print('IP does not exist', file=sys.stderr)
    return 3
  if present:
    entries.remove(ip)

  writeYaml(hostgroupsFilename, content)
  return 0

def createProtocolMap():
  """Return a new dict mapping each supported protocol to its own empty list."""
  return {protocol: [] for protocol in supportedProtocols}

def addhostgroup(args):
  """Create a new, empty host group. Args: [GROUP_NAME].

  Returns 0 on success, or 3 if a group with that name already exists.
  A wrong argument count prints an error and the usage text (which exits).
  """
  if len(args) != 1:
    print('Missing host group name argument', file=sys.stderr)
    showUsage(args)

  groupName = args[0]
  content = loadYaml(hostgroupsFilename)
  hostgroups = content['firewall']['hostgroups']
  if groupName not in hostgroups:
    # New groups start with empty include ('insert') and exclude ('delete') lists.
    hostgroups[groupName] = { 'ips': { 'insert': [], 'delete': [] }}
    writeYaml(hostgroupsFilename, content)
    return 0

  print('Already exists', file=sys.stderr)
  return 3

def addportgroup(args):
  """Create a new, empty port group. Args: [GROUP_NAME].

  Returns 0 on success, or 3 if a group with that name already exists.
  A wrong argument count prints an error and the usage text (which exits).
  """
  if len(args) != 1:
    print('Missing port group name argument', file=sys.stderr)
    showUsage(args)

  groupName = args[0]
  content = loadYaml(portgroupsFilename)
  aliases = content['firewall']['aliases']
  # The 'ports' entry may be None when no port groups have been defined yet.
  if aliases['ports'] is None:
    aliases['ports'] = {}
  portGroups = aliases['ports']

  if groupName in portGroups:
    print('Already exists', file=sys.stderr)
    return 3

  portGroups[groupName] = createProtocolMap()
  writeYaml(portgroupsFilename, content)
  return 0

def listports(args):
  """Print the ports of a port group for one protocol. Args: [GROUP_NAME, PORT_PROTOCOL].

  Returns 0 on success, 3 if the port group does not exist, or 5 if the
  protocol is not supported. A wrong argument count prints an error and
  the usage text (which exits).
  """
  if len(args) != 2:
    print('Missing port group name or port protocol', file=sys.stderr)
    showUsage(args)

  name, protocol = args[0], args[1]
  if protocol not in supportedProtocols:
    print('Port protocol is not supported', file=sys.stderr)
    return 5

  content = loadYaml(portgroupsFilename)
  aliases = content['firewall']['aliases']
  # The 'ports' entry may be None when no port groups have been defined yet.
  if aliases['ports'] is None:
    aliases['ports'] = {}
  portGroups = aliases['ports']

  if name not in portGroups:
    print('Port group does not exist', file=sys.stderr)
    return 3

  entries = portGroups[name][protocol]
  # A protocol with no ports may be stored as None rather than a list.
  for entry in (entries if entries is not None else []):
    print(entry)
  return 0

def addport(args):
  """Add a port to a port group. Args: [GROUP_NAME, PORT_PROTOCOL, PORT].

  Returns 0 on success, 3 if the port group does not exist or the port is
  already listed, or 5 if the protocol is not supported. The file is only
  written on success. A wrong argument count prints an error and the usage
  text (which exits).
  """
  if len(args) != 3:
    print('Missing port group name or port protocol, or port argument', file=sys.stderr)
    showUsage(args)

  name, protocol, port = args[0], args[1], args[2]
  if protocol not in supportedProtocols:
    print('Port protocol is not supported', file=sys.stderr)
    return 5

  content = loadYaml(portgroupsFilename)
  aliases = content['firewall']['aliases']
  # The 'ports' entry may be None when no port groups have been defined yet.
  if aliases['ports'] is None:
    aliases['ports'] = {}
  portGroups = aliases['ports']

  if name not in portGroups:
    print('Port group does not exist', file=sys.stderr)
    return 3

  group = portGroups[name]
  # Normalize a None placeholder into a real list so it can be appended to.
  if group[protocol] is None:
    group[protocol] = []
  entries = group[protocol]

  if port in entries:
    print('Already exists', file=sys.stderr)
    return 3

  entries.append(port)
  writeYaml(portgroupsFilename, content)
  return 0

def removeport(args):
  """Remove a port from a port group. Args: [GROUP_NAME, PORT_PROTOCOL, PORT].

  Returns 0 on success, 3 if the port group or the port does not exist, or
  5 if the protocol is not supported. The file is only written on success.
  A wrong argument count prints an error and the usage text (which exits).
  """
  if len(args) != 3:
    print('Missing port group name or port protocol, or port argument', file=sys.stderr)
    showUsage(args)

  name, protocol, port = args[0], args[1], args[2]
  if protocol not in supportedProtocols:
    print('Port protocol is not supported', file=sys.stderr)
    return 5

  content = loadYaml(portgroupsFilename)
  aliases = content['firewall']['aliases']
  # The 'ports' entry may be None when no port groups have been defined yet.
  if aliases['ports'] is None:
    aliases['ports'] = {}
  portGroups = aliases['ports']

  if name not in portGroups:
    print('Port group does not exist', file=sys.stderr)
    return 3

  entries = portGroups[name][protocol]
  if entries is not None and port in entries:
    entries.remove(port)
    writeYaml(portgroupsFilename, content)
    return 0

  print('Port does not exist', file=sys.stderr)
  return 3

def includedhosts(args):
  """List the IPs in a host group's 'insert' list. Args: [GROUP_NAME].

  Returns the status code produced by listIps.
  """
  argCount = len(args)
  if argCount != 1:
    print('Missing host group name argument', file=sys.stderr)
    showUsage(args)
  groupName = args[0]
  return listIps(groupName, 'insert')

def excludedhosts(args):
  """List the IPs in a host group's 'delete' list. Args: [GROUP_NAME].

  Returns the status code produced by listIps.
  """
  argCount = len(args)
  if argCount != 1:
    print('Missing host group name argument', file=sys.stderr)
    showUsage(args)
  groupName = args[0]
  return listIps(groupName, 'delete')

def includehost(args):
  """Include an IP in a host group. Args: [GROUP_NAME, IP].

  Adds the IP to the group's 'insert' list and, if that succeeded, silently
  drops it from the 'delete' list so it is not both included and excluded.

  Returns the status code produced by addIp.
  """
  if len(args) != 2:
    print('Missing host group name or ip argument', file=sys.stderr)
    showUsage(args)

  groupName, ip = args[0], args[1]
  status = addIp(groupName, ip, 'insert')
  if status != 0:
    return status
  # Best-effort cleanup of the opposite list; its result is intentionally ignored.
  removeIp(groupName, ip, 'delete', True)
  return status

def excludehost(args):
  """Exclude an IP from a host group. Args: [GROUP_NAME, IP].

  Adds the IP to the group's 'delete' list and, if that succeeded, silently
  drops it from the 'insert' list so it is not both included and excluded.

  Returns the status code produced by addIp.
  """
  if len(args) != 2:
    print('Missing host group name or ip argument', file=sys.stderr)
    showUsage(args)

  groupName, ip = args[0], args[1]
  status = addIp(groupName, ip, 'delete')
  if status != 0:
    return status
  # Best-effort cleanup of the opposite list; its result is intentionally ignored.
  removeIp(groupName, ip, 'insert', True)
  return status

def removehost(args):
  """Remove an IP from a host group's 'delete' list. Args: [GROUP_NAME, IP].

  Returns the status code produced by removeIp.
  """
  argCount = len(args)
  if argCount != 2:
    print('Missing host group name or ip argument', file=sys.stderr)
    showUsage(args)
  groupName, ip = args[0], args[1]
  return removeIp(groupName, ip, 'delete')

def apply():
  """Run `salt-call state.apply firewall queue=True` and return its exit code."""
  command = ['salt-call', 'state.apply', 'firewall', 'queue=True']
  return subprocess.run(command).returncode

def main():
  """Parse the command line, run the requested command, and exit.

  Arguments starting with "--" are treated as options and may appear
  anywhere on the command line; the remaining arguments are the command
  name followed by its own arguments. Unknown commands (and "help") show
  the usage text. If the command returns 0 and "--apply" was given, the
  firewall state is applied and its exit code becomes the script's exit
  code; otherwise the command's own return code is used.
  """
  argv = sys.argv[1:]
  # Build the two lists separately instead of removing items from the list
  # being iterated: removing during iteration skips the element after each
  # removed option, which left a second consecutive "--" option in args
  # where it was then mistaken for the command name.
  options = [arg for arg in argv if arg.startswith("--")]
  args = [arg for arg in argv if not arg.startswith("--")]

  if len(args) == 0:
    showUsage(None)

  commands = {
    "help": showUsage,
    "includedhosts": includedhosts,
    "excludedhosts": excludedhosts,
    "includehost": includehost,
    "excludehost": excludehost,
    "removehost": removehost,
    "listports": listports,
    "addport": addport,
    "removeport": removeport,
    "addhostgroup": addhostgroup,
    "addportgroup": addportgroup
  }

  # Anything that is not a known command falls back to the usage text.
  cmd = commands.get(args[0], showUsage)
  code = cmd(args[1:])

  # Only apply the new firewall state when the update itself succeeded.
  if code == 0 and "--apply" in options:
    code = apply()

  sys.exit(code)

if __name__ == "__main__":
  main()
