#!/usr/bin/env python3

# Copyright 2014-2022 Security Onion Solutions, LLC
#
# This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU General Public License as published by
# the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
# GNU General Public License for more details.
#
# You should have received a copy of the GNU General Public License
# along with this program. If not, see <https://www.gnu.org/licenses/>.

"""Command-line tool for editing the firewall host group and port group YAML files.

Every command handler takes ``(options, args)`` and returns a numeric exit
code: 0 on success, 3 when an entry already exists or cannot be found,
4 when a host group is missing and 5 when a port protocol is unsupported.
"""

import os
import re
import subprocess
import sys
import time

import yaml

lockFile = "/tmp/so-firewall.lock"
hostgroupsFilename = "/opt/so/saltstack/local/salt/firewall/hostgroups.local.yaml"
portgroupsFilename = "/opt/so/saltstack/local/salt/firewall/portgroups.local.yaml"
defaultPortgroupsFilename = "/opt/so/saltstack/default/salt/firewall/portgroups.yaml"
supportedProtocols = ['tcp', 'udp']

# Set by loadYaml() once a file containing Jinja templating has been read;
# writeYaml() refuses to write while this flag is set.
readonly = False


def showUsage(options, args):
    """Print the usage text and exit the process with status 1.

    Both parameters are unused; they exist so this function can be dispatched
    like any other command handler.
    """
    usageLines = [
        'Usage: {} [OPTIONS] <COMMAND> [ARGS...]'.format(sys.argv[0]),
        ' Options:',
        ' --apply - After updating the firewall configuration files, apply the new firewall state',
        ' --defaultports - Read port groups from default configuration files instead of local configuration.',
        '',
        ' General commands:',
        ' help - Prints this usage information.',
        ' apply - Apply the firewall state.',
        '',
        ' Host commands:',
        ' listhostgroups - Lists the known host groups.',
        ' includedhosts - Lists the IPs included in the given group. Args: <GROUP_NAME>',
        ' excludedhosts - Lists the IPs excluded from the given group. Args: <GROUP_NAME>',
        ' includehost - Includes the given IP in the given group. Args: <GROUP_NAME> <IP>',
        ' excludehost - Excludes the given IP from the given group. Args: <GROUP_NAME> <IP>',
        ' removehost - Removes an excluded IP from the given group. Args: <GROUP_NAME> <IP>',
        ' addhostgroup - Adds a new, custom host group. Args: <GROUP_NAME>',
        '',
        ' Port commands:',
        ' listportgroups - Lists the known port groups.',
        ' listports - Lists ports in the given group and protocol. Args: <GROUP_NAME> <PORT_PROTOCOL>',
        ' addport - Adds a PORT to the given group. Args: <GROUP_NAME> <PORT_PROTOCOL> <PORT>',
        ' removeport - Removes a PORT from the given group. Args: <GROUP_NAME> <PORT_PROTOCOL> <PORT>',
        ' addportgroup - Adds a new, custom port group. Args: <GROUP_NAME>',
        '',
        ' Where:',
        ' GROUP_NAME - The name of an alias group (Ex: analyst)',
        ' IP - Either a single IP address (Ex: 8.8.8.8) or a CIDR block (Ex: 10.23.0.0/16).',
        ' PORT_PROTOCOL - Must be one of the following: ' + str(supportedProtocols),
        ' PORT - Either a single numeric port (Ex: 443), or a port range (Ex: 8000:8002).',
    ]
    for line in usageLines:
        print(line)
    sys.exit(1)


def splitOptions(argv):
    """Split command-line words into ``(options, args)``.

    Words starting with ``--`` are options; everything else is a positional
    argument. Relative order is preserved within each list. Two new lists are
    built so that nothing is removed from a list while it is being iterated.
    """
    options = [word for word in argv if word.startswith("--")]
    args = [word for word in argv if not word.startswith("--")]
    return options, args


def checkDefaultPortsOption(options):
    """Point port group commands at the default file when --defaultports is given."""
    global portgroupsFilename
    if "--defaultports" in options:
        portgroupsFilename = defaultPortgroupsFilename


def checkApplyOption(options):
    """Apply the firewall state if --apply was given.

    Returns the exit code of the apply step, or 0 when --apply is absent.
    """
    if "--apply" in options:
        return apply(None, None)
    return 0


def loadYaml(filename):
    """Read and parse a YAML file, stripping Jinja templating first.

    If the file contains Jinja markers, every line containing one is blanked
    out (after substituting 22 for ``{{ ssh_port }}``) and the module-level
    ``readonly`` flag is set so the lossy result cannot be written back.
    """
    global readonly

    with open(filename, "r") as file:
        content = file.read()

    # Remove Jinja templating (for read-only operations)
    if "{%" in content or "{{" in content:
        content = content.replace("{{ ssh_port }}", "22")
        pattern = r'.*({%|{{|}}|%}).*'
        content = re.sub(pattern, "", content)
        readonly = True

    return yaml.safe_load(content)


def writeYaml(filename, content):
    """Serialize content to filename as YAML.

    Raises Exception if a templated file has been loaded (``readonly`` set).
    """
    global readonly

    if readonly:
        raise Exception("Cannot write yaml file that has been flagged as read-only")

    with open(filename, "w") as file:
        return yaml.dump(content, file)


def _findHostGroup(content, name):
    """Return the named host group from parsed content.

    Prints an error and returns None when the group is absent or has no value,
    including the case where the file defines no host groups at all.
    """
    hostgroups = content['firewall']['hostgroups'] or {}
    hostgroup = hostgroups.get(name)
    if hostgroup is None:
        print('Host group does not exist', file=sys.stderr)
    return hostgroup


def _loadPortGroups():
    """Load the port groups file and return ``(content, portgroups)``.

    A null ``ports`` entry is replaced in content with an empty dict so callers
    can always treat the port groups as a mapping.
    """
    content = loadYaml(portgroupsFilename)
    aliases = content['firewall']['aliases']
    if aliases['ports'] is None:
        aliases['ports'] = {}
    return content, aliases['ports']


def listHostGroups():
    """Print the name of every host group, one per line. Returns 0."""
    content = loadYaml(hostgroupsFilename)
    hostgroups = content['firewall']['hostgroups']
    if hostgroups is not None:
        for group in hostgroups:
            print(group)
    return 0


def listIps(name, mode):
    """Print the IPs stored under mode ('insert' or 'delete') in a host group.

    Returns 0 on success or 4 if the host group does not exist.
    """
    content = loadYaml(hostgroupsFilename)
    hostgroup = _findHostGroup(content, name)
    if hostgroup is None:
        return 4

    ips = hostgroup['ips'][mode]
    if ips is not None:
        for ip in ips:
            print(ip)
    return 0


def addIp(name, ip, mode):
    """Add ip to the mode list of a host group and write the file.

    Returns 0 on success, 3 if the IP is already listed, or 4 if the host
    group does not exist.
    """
    content = loadYaml(hostgroupsFilename)
    hostgroup = _findHostGroup(content, name)
    if hostgroup is None:
        return 4

    ips = hostgroup['ips'][mode]
    if ips is None:
        ips = []
        hostgroup['ips'][mode] = ips

    if ip in ips:
        print('Already exists', file=sys.stderr)
        return 3

    ips.append(ip)
    writeYaml(hostgroupsFilename, content)
    return 0


def removeIp(name, ip, mode, silence = False):
    """Remove ip from the mode list of a host group and write the file.

    Returns 0 on success, 3 if the IP is not listed (the message is suppressed
    when silence is true), or 4 if the host group does not exist.
    """
    content = loadYaml(hostgroupsFilename)
    hostgroup = _findHostGroup(content, name)
    if hostgroup is None:
        return 4

    ips = hostgroup['ips'][mode]
    if ips is None:
        ips = []
        hostgroup['ips'][mode] = ips

    if ip not in ips:
        if not silence:
            print('IP does not exist', file=sys.stderr)
        return 3

    ips.remove(ip)
    writeYaml(hostgroupsFilename, content)
    return 0


def createProtocolMap():
    """Return a new dict mapping each supported protocol to an empty list."""
    return {protocol: [] for protocol in supportedProtocols}


def listPortGroups():
    """Print the name of every port group, one per line. Returns 0."""
    content = loadYaml(portgroupsFilename)
    portgroups = content['firewall']['aliases']['ports']
    if portgroups is not None:
        for group in portgroups:
            print(group)
    return 0


def addhostgroup(options, args):
    """Command: create an empty host group. Args: the group name.

    Returns 0 on success or 3 if the group already exists.
    """
    if len(args) != 1:
        print('Missing host group name argument', file=sys.stderr)
        showUsage(options, args)

    name = args[0]
    content = loadYaml(hostgroupsFilename)
    hostgroups = content['firewall']['hostgroups']
    if hostgroups is None:
        # No host groups defined yet; start a fresh mapping.
        hostgroups = {}
        content['firewall']['hostgroups'] = hostgroups

    if name in hostgroups:
        print('Already exists', file=sys.stderr)
        return 3

    hostgroups[name] = {'ips': {'insert': [], 'delete': []}}
    writeYaml(hostgroupsFilename, content)
    return 0


def listportgroups(options, args):
    """Command: list port group names. Takes no arguments."""
    if len(args) != 0:
        print('Unexpected arguments', file=sys.stderr)
        showUsage(options, args)

    checkDefaultPortsOption(options)
    return listPortGroups()


def addportgroup(options, args):
    """Command: create a port group with an empty list per supported protocol.

    Args: the group name. Returns 0 on success or 3 if the group exists.
    """
    if len(args) != 1:
        print('Missing port group name argument', file=sys.stderr)
        showUsage(options, args)

    name = args[0]
    content, portgroups = _loadPortGroups()
    if name in portgroups:
        print('Already exists', file=sys.stderr)
        return 3

    portgroups[name] = createProtocolMap()
    writeYaml(portgroupsFilename, content)
    return 0


def listports(options, args):
    """Command: print the ports of a group for one protocol.

    Args: group name, protocol. Returns 0 on success, 3 if the group or the
    protocol entry is missing, or 5 if the protocol is unsupported.
    """
    if len(args) != 2:
        print('Missing port group name or port protocol', file=sys.stderr)
        showUsage(options, args)

    checkDefaultPortsOption(options)
    name, protocol = args
    if protocol not in supportedProtocols:
        print('Port protocol is not supported', file=sys.stderr)
        return 5

    _, portgroups = _loadPortGroups()
    if name not in portgroups:
        print('Port group does not exist', file=sys.stderr)
        return 3
    if protocol not in portgroups[name]:
        print('Port group does not contain protocol', file=sys.stderr)
        return 3

    ports = portgroups[name][protocol]
    if ports is not None:
        for port in ports:
            print(port)
    return 0


def addport(options, args):
    """Command: add a port to a group, then apply if --apply was given.

    Args: group name, protocol, port. Returns 3 if the group is missing or the
    port is already listed, 5 if the protocol is unsupported, otherwise the
    result of checkApplyOption().
    """
    if len(args) != 3:
        print('Missing port group name or port protocol, or port argument', file=sys.stderr)
        showUsage(options, args)

    name, protocol, port = args
    if protocol not in supportedProtocols:
        print('Port protocol is not supported', file=sys.stderr)
        return 5

    content, portgroups = _loadPortGroups()
    if name not in portgroups:
        print('Port group does not exist', file=sys.stderr)
        return 3

    group = portgroups[name]
    # The protocol key may be absent or null; either way start an empty list.
    ports = group.get(protocol)
    if ports is None:
        ports = []
        group[protocol] = ports

    if port in ports:
        print('Already exists', file=sys.stderr)
        return 3

    ports.append(port)
    writeYaml(portgroupsFilename, content)
    return checkApplyOption(options)


def removeport(options, args):
    """Command: remove a port from a group, then apply if --apply was given.

    Args: group name, protocol, port. Returns 3 if the group or the port is
    missing, 5 if the protocol is unsupported, otherwise the result of
    checkApplyOption().
    """
    if len(args) != 3:
        print('Missing port group name or port protocol, or port argument', file=sys.stderr)
        showUsage(options, args)

    name, protocol, port = args
    if protocol not in supportedProtocols:
        print('Port protocol is not supported', file=sys.stderr)
        return 5

    content, portgroups = _loadPortGroups()
    if name not in portgroups:
        print('Port group does not exist', file=sys.stderr)
        return 3

    # A group without this protocol key simply has no such port.
    ports = portgroups[name].get(protocol)
    if ports is None or port not in ports:
        print('Port does not exist', file=sys.stderr)
        return 3

    ports.remove(port)
    writeYaml(portgroupsFilename, content)
    return checkApplyOption(options)


def listhostgroups(options, args):
    """Command: list host group names. Takes no arguments."""
    if len(args) != 0:
        print('Unexpected arguments', file=sys.stderr)
        showUsage(options, args)

    return listHostGroups()


def includedhosts(options, args):
    """Command: list the 'insert' IPs of a host group. Args: the group name."""
    if len(args) != 1:
        print('Missing host group name argument', file=sys.stderr)
        showUsage(options, args)

    return listIps(args[0], 'insert')


def excludedhosts(options, args):
    """Command: list the 'delete' IPs of a host group. Args: the group name."""
    if len(args) != 1:
        print('Missing host group name argument', file=sys.stderr)
        showUsage(options, args)

    return listIps(args[0], 'delete')


def includehost(options, args):
    """Command: add an IP to a group's 'insert' list. Args: group name, IP.

    On success the IP is also dropped from the 'delete' list (silently if it
    was not there) and the firewall is applied if --apply was given.
    """
    if len(args) != 2:
        print('Missing host group name or ip argument', file=sys.stderr)
        showUsage(options, args)

    name, ip = args
    code = addIp(name, ip, 'insert')
    if code == 0:
        removeIp(name, ip, 'delete', True)
        code = checkApplyOption(options)
    return code


def excludehost(options, args):
    """Command: add an IP to a group's 'delete' list. Args: group name, IP.

    On success the IP is also dropped from the 'insert' list (silently if it
    was not there) and the firewall is applied if --apply was given.
    """
    if len(args) != 2:
        print('Missing host group name or ip argument', file=sys.stderr)
        showUsage(options, args)

    name, ip = args
    code = addIp(name, ip, 'delete')
    if code == 0:
        removeIp(name, ip, 'insert', True)
        code = checkApplyOption(options)
    return code


def removehost(options, args):
    """Command: remove an IP from a group's 'delete' list. Args: group name, IP.

    On success the firewall is applied if --apply was given.
    """
    if len(args) != 2:
        print('Missing host group name or ip argument', file=sys.stderr)
        showUsage(options, args)

    code = removeIp(args[0], args[1], 'delete')
    if code == 0:
        code = checkApplyOption(options)
    return code


def apply(options, args):
    """Command: run the Salt firewall state and return salt-call's exit code."""
    proc = subprocess.run(['salt-call', 'state.apply', 'firewall', 'queue=True'])
    return proc.returncode


def _acquireLock(maxAttempts=30, delaySeconds=2):
    """Try to create the lock file exclusively, retrying after each failure.

    Returns True as soon as the file is created, or False once every attempt
    has failed. Success is tracked explicitly so a lock obtained on the final
    attempt is not mistaken for a failure.
    """
    for _ in range(maxAttempts):
        try:
            # Mode "x" fails if the file already exists.
            with open(lockFile, "x"):
                pass
            return True
        except OSError:
            time.sleep(delaySeconds)
    return False


def main():
    """Parse the command line, run the requested command under the lock file, and exit."""
    options, args = splitOptions(sys.argv[1:])

    if len(args) == 0:
        showUsage(options, None)

    commands = {
        "help": showUsage,
        "listhostgroups": listhostgroups,
        "includedhosts": includedhosts,
        "excludedhosts": excludedhosts,
        "includehost": includehost,
        "excludehost": excludehost,
        "removehost": removehost,
        "listportgroups": listportgroups,
        "listports": listports,
        "addport": addport,
        "removeport": removeport,
        "addhostgroup": addhostgroup,
        "addportgroup": addportgroup,
        "apply": apply,
    }

    code = 1

    try:
        if not _acquireLock():
            print("Lock file (" + lockFile + ") could not be created; proceeding without lock.")

        # Unknown commands fall back to the usage text.
        cmd = commands.get(args[0], showUsage)
        code = cmd(options, args[1:])
    finally:
        # The lock file is removed whether or not this process created it.
        try:
            os.remove(lockFile)
        except OSError:
            print("Lock file (" + lockFile + ") already removed")

    sys.exit(code)


if __name__ == "__main__":
    main()