From 23be399a680c5d9c7f06755c9ad857e0ca2bdb22 Mon Sep 17 00:00:00 2001 From: Jason Ertel Date: Wed, 10 Jun 2020 15:16:32 -0400 Subject: [PATCH] Ensure host doesn't exist in both include and exclude lists; add support for port management; add support for removing host from exclude list --- salt/common/tools/sbin/so-firewall | 211 +++++++++++++++++++++++++---- 1 file changed, 184 insertions(+), 27 deletions(-) diff --git a/salt/common/tools/sbin/so-firewall b/salt/common/tools/sbin/so-firewall index 56b07e2f2..d87fd847e 100755 --- a/salt/common/tools/sbin/so-firewall +++ b/salt/common/tools/sbin/so-firewall @@ -19,16 +19,29 @@ import sys import yaml hostgroupsFilename = "/opt/so/saltstack/local/salt/firewall/hostgroups.local.yaml" +portgroupsFilename = "/opt/so/saltstack/local/salt/firewall/portgroups.local.yaml" +supportedProtocols = ['tcp', 'udp'] def showUsage(args): print('Usage: {} [ARGS...]'.format(sys.argv[0])) - print(' Available commands:'); - print(' help - Prints this usage information.'); - print(' included - Lists the IPs included in the given hostgroup. Args: '); - print(' excluded - Lists the IPs excluded from the given hostgroup. Args: '); - print(' include - Adds the given IP (or CIDR) to the given hostgroup. Args: '); - print(' exclude - Removes the given IP (or CIDR) from the given hostgroup. Args: '); - print(' addgroup - Adds a new hostgroup. Args: '); + print(' Available commands:') + print(' help - Prints this usage information.') + print(' includedhosts - Lists the IPs included in the given group. Args: ') + print(' excludedhosts - Lists the IPs excluded from the given group. Args: ') + print(' includehost - Includes the given IP in the given group. Args: ') + print(' excludehost - Excludes the given IP from the given group. Args: ') + print(' removehost - Removes an excluded IP from the given group. Args: ') + print(' addhostgroup - Adds a new, custom host group. Args: ') + print(' listports - Lists ports in the given group and protocol. Args: ') + print(' addport - Adds a PORT to the given group. Args: ') + print(' removeport - Removes a PORT from the given group. Args: ') + print(' addportgroup - Adds a new, custom port group. Args: ') + print('') + print(' Where:') + print(' GROUP_NAME - The name of an alias group (Ex: analyst)') + print(' IP - Either a single IP address (Ex: 8.8.8.8) or a CIDR block (Ex: 10.23.0.0/16).') + print(' PORT_PROTOCOL - Must be one of the following: ' + str(supportedProtocols)) + print(' PORT - Either a single numeric port (Ex: 443), or a port range (Ex: 8000:8002).') sys.exit(1) def loadYaml(filename): @@ -42,7 +55,7 @@ def writeYaml(filename, content): def listIps(name, mode): content = loadYaml(hostgroupsFilename) if name not in content['firewall']['hostgroups']: - print('Hostgroup does not exist', file=sys.stderr) + print('Host group does not exist', file=sys.stderr) return 4 hostgroup = content['firewall']['hostgroups'][name] ips = hostgroup['ips'][mode] @@ -54,7 +67,7 @@ def listIps(name, mode): def addIp(name, ip, mode): content = loadYaml(hostgroupsFilename) if name not in content['firewall']['hostgroups']: - print('Hostgroup does not exist', file=sys.stderr) + print('Host group does not exist', file=sys.stderr) return 4 hostgroup = content['firewall']['hostgroups'][name] ips = hostgroup['ips'][mode] @@ -69,12 +82,37 @@ def addIp(name, ip, mode): writeYaml(hostgroupsFilename, content) return 0 -def addgroup(args): +def removeIp(name, ip, mode, silence = False): + content = loadYaml(hostgroupsFilename) + if name not in content['firewall']['hostgroups']: + print('Host group does not exist', file=sys.stderr) + return 4 + hostgroup = content['firewall']['hostgroups'][name] + ips = hostgroup['ips'][mode] + if ips is None: + ips = [] + hostgroup['ips'][mode] = ips + if ip in ips: + ips.remove(ip) + else: + if not silence: + print('IP does not exist', file=sys.stderr) + return 3 + writeYaml(hostgroupsFilename, content) + return 0 + +def createProtocolMap(): + map = {} + for protocol in supportedProtocols: + map[protocol] = [] + return map + +def addhostgroup(args): if len(args) != 1: - print('Missing hostgroup name argument', file=sys.stderr) + print('Missing host group name argument', file=sys.stderr) showUsage(args) - name = args[0] + name = args[1] content = loadYaml(hostgroupsFilename) if name in content['firewall']['hostgroups']: print('Already exists', file=sys.stderr) @@ -83,29 +121,143 @@ def addgroup(args): writeYaml(hostgroupsFilename, content) return 0 -def included(args): +def addportgroup(args): if len(args) != 1: - print('Missing hostgroup name argument', file=sys.stderr) + print('Missing port group name argument', file=sys.stderr) + showUsage(args) + + name = args[0] + content = loadYaml(portgroupsFilename) + ports = content['firewall']['aliases']['ports'] + if ports is None: + ports = {} + content['firewall']['aliases']['ports'] = ports + if name in ports: + print('Already exists', file=sys.stderr) + return 3 + ports[name] = createProtocolMap() + writeYaml(portgroupsFilename, content) + return 0 + +def listports(args): + if len(args) != 2: + print('Missing port group name or port protocol', file=sys.stderr) + showUsage(args) + + name = args[0] + protocol = args[1] + if protocol not in supportedProtocols: + print('Port protocol is not supported', file=sys.stderr) + return 5 + + content = loadYaml(portgroupsFilename) + ports = content['firewall']['aliases']['ports'] + if ports is None: + ports = {} + content['firewall']['aliases']['ports'] = ports + if name not in ports: + print('Port group does not exist', file=sys.stderr) + return 3 + ports = ports[name][protocol] + if ports is not None: + for port in ports: + print(port) + return 0 + +def addport(args): + if len(args) != 3: + print('Missing port group name or port protocol, or port argument', file=sys.stderr) + showUsage(args) + + name = args[0] + protocol = args[1] + port = args[2] + if protocol not in supportedProtocols: + print('Port protocol is not supported', file=sys.stderr) + return 5 + + content = loadYaml(portgroupsFilename) + ports = content['firewall']['aliases']['ports'] + if ports is None: + ports = {} + content['firewall']['aliases']['ports'] = ports + if name not in ports: + print('Port group does not exist', file=sys.stderr) + return 3 + ports = ports[name][protocol] + if ports is None: + ports = [] + content['firewall']['aliases']['ports'][name][protocol] = ports + if port in ports: + print('Already exists', file=sys.stderr) + return 3 + ports.append(port) + writeYaml(portgroupsFilename, content) + return 0 + +def removeport(args): + if len(args) != 3: + print('Missing port group name or port protocol, or port argument', file=sys.stderr) + showUsage(args) + + name = args[0] + protocol = args[1] + port = args[2] + if protocol not in supportedProtocols: + print('Port protocol is not supported', file=sys.stderr) + return 5 + + content = loadYaml(portgroupsFilename) + ports = content['firewall']['aliases']['ports'] + if ports is None: + ports = {} + content['firewall']['aliases']['ports'] = ports + if name not in ports: + print('Port group does not exist', file=sys.stderr) + return 3 + ports = ports[name][protocol] + if ports is None or port not in ports: + print('Port does not exist', file=sys.stderr) + return 3 + ports.remove(port) + writeYaml(portgroupsFilename, content) + return 0 + +def includedhosts(args): + if len(args) != 1: + print('Missing host group name argument', file=sys.stderr) showUsage(args) return listIps(args[0], 'insert') -def excluded(args): +def excludedhosts(args): if len(args) != 1: - print('Missing hostgroup name argument', file=sys.stderr) + print('Missing host group name argument', file=sys.stderr) showUsage(args) return listIps(args[0], 'delete') -def include(args): +def includehost(args): if len(args) != 2: - print('Missing hostgroup name or ip argument', file=sys.stderr) + print('Missing host group name or ip argument', file=sys.stderr) showUsage(args) - return addIp(args[0], args[1], 'insert') + result = addIp(args[0], args[1], 'insert') + if result == 0: + removeIp(args[0], args[1], 'delete', True) + return result -def exclude(args): +def excludehost(args): if len(args) != 2: - print('Missing hostgroup name or ip argument', file=sys.stderr) + print('Missing host group name or ip argument', file=sys.stderr) showUsage(args) - return addIp(args[0], args[1], 'delete') + result = addIp(args[0], args[1], 'delete') + if result == 0: + removeIp(args[0], args[1], 'insert', True) + return result + +def removehost(args): + if len(args) != 2: + print('Missing host group name or ip argument', file=sys.stderr) + showUsage(args) + return removeIp(args[0], args[1], 'delete') def main(): args = sys.argv[1:] @@ -114,11 +266,16 @@ def main(): commands = { "help": showUsage, - "included": included, - "excluded": excluded, - "include": include, - "exclude": exclude, - "addgroup": addgroup + "includedhosts": includedhosts, + "excludedhosts": excludedhosts, + "includehost": includehost, + "excludehost": excludehost, + "removehost": removehost, + "listports": listports, + "addport": addport, + "removeport": removeport, + "addhostgroup": addhostgroup, + "addportgroup": addportgroup } cmd = commands.get(args[0], showUsage)