#!/usr/bin/env python
# -*- coding: utf-8 -*-

from pysnmp.hlapi import *
from pysnmp.smi.builder import MibBuilder
from pysnmp.smi.view import MibViewController

import importlib
import itertools
import mock
import random
import unittest

_mbuilder = MibBuilder()
_mvc = MibViewController(_mbuilder)

#import data

# received packets
# pvid: dot1qPvid
#
# tx packets:
#	dot1qVlanStaticEgressPorts
#	dot1qVlanStaticUntaggedPorts
#
# vlans:
#	dot1qVlanCurrentTable
#	lists ALL vlans, including baked in ones
#
# note that even though an snmpwalk of dot1qVlanStaticEgressPorts
# skips over other vlans (only shows statics), the other vlans (1,2,3)
# are still accessible via that oid
#
# LLDP:
#	1.0.8802.1.1.2.1.4.1.1 aka LLDP-MIB, lldpRemTable


class SwitchConfig(object):
    '''Desired vlan configuration for a single switch.

    host       -- address of the switch
    community  -- SNMP community string (or auth object) for the switch
    vlanconf   -- dict mapping a vlan id to a dict that may have a 'u'
                  (untagged ports) and/or 't' (tagged ports) list; ports
                  are either integer port indexes or port names (str)
    ignports   -- iterable of ports (indexes or names) that exist on the
                  switch but are intentionally not configured
    '''

    def __init__(self, host, community, vlanconf, ignports):
        self._host = host
        self._community = community
        self._vlanconf = vlanconf
        self._ignports = ignports

    @property
    def host(self):
        return self._host

    @property
    def community(self):
        return self._community

    @property
    def vlanconf(self):
        return self._vlanconf

    @property
    def ignports(self):
        return self._ignports

    def getportlist(self, lookupfun):
        '''Return the set of all port indexes referenced by this config.

        This is every untagged ('u') and tagged ('t') member of every
        vlan, plus the ignored ports.  Port names (str) are converted to
        port indexes by calling lookupfun(name).
        '''

        res = set()

        for vlan in self._vlanconf:
            res.update(self._vlanconf[vlan].get('u', []))
            res.update(self._vlanconf[vlan].get('t', []))

        # add in the ignore ports
        res.update(self.ignports)

        # replace the port names with their indexes
        strports = set(x for x in res if isinstance(x, str))
        res.update(lookupfun(x) for x in strports)
        res.difference_update(strports)

        return res


def _octstrtobits(os):
    '''Convert an octet string into a string of '0'/'1' characters.

    Each byte contributes 8 characters, most significant bit first, so
    the first character corresponds to the high bit of the first byte.
    An empty octet string returns ''.
    '''

    return ''.join(format(ord(c), '08b') for c in str(os))


def _intstobits(*ints):
    '''Convert port indexes into a '0'/'1' string.

    Character k-1 of the result is '1' when port k is in ints, so the
    string is as long as the highest port given.  Index 0 has no
    position in the result (bin()[2:-1] drops bit 0), and no arguments
    give ''.
    '''

    v = 0
    for i in ints:
        v |= 1 << i

    # strip '0b' and bit 0, then reverse so port 1 comes first
    r = list(bin(v)[2:-1])
    r.reverse()

    return ''.join(r)


def _cmpbits(a, b):
    '''Compare two '0'/'1' strings, ignoring trailing zeros.

    The switch pads its port lists out to whole octets while locally
    generated bit strings stop at the last set port, so '111000' and
    '111' compare equal.  Assumes both strings contain only '0' and '1'.
    '''

    return a.rstrip('0') == b.rstrip('0')
import vlanmang def checkchanges(module): mod = importlib.import_module(module) mods = [ i for i in mod.__dict__.itervalues() if isinstance(i, vlanmang.SwitchConfig) ] res = [] for i in mods: vlans = i.vlanconf.keys() switch = SNMPSwitch(i.host, i.community) portmapping = switch.getportmapping() invportmap = { y: x for x, y in portmapping.iteritems() } lufun = invportmap.__getitem__ # get complete set of ports portlist = i.getportlist(lufun) ports = set(portmapping.iterkeys()) # make sure switch agrees w/ them all if ports != portlist: raise ValueError('missing or extra ports found: %s' % `ports.symmetric_difference(portlist)`) # compare pvid pvidmap = getpvidmapping(i.vlanconf, lufun) switchpvid = switch.getpvid() res.extend(('setpvid', idx, vlan, switchpvid[idx]) for idx, vlan in pvidmap.iteritems() if switchpvid[idx] != vlan) # compare egress & untagged switchegress = switch.getegress(*vlans) egress = getegress(i.vlanconf, lufun) switchuntagged = switch.getuntagged(*vlans) untagged = getuntagged(i.vlanconf, lufun) for i in vlans: if not _cmpbits(switchegress[i], egress[i]): res.append(('setegress', i, egress[i], switchegress[i])) if not _cmpbits(switchuntagged[i], untagged[i]): res.append(('setuntagged', i, untagged[i], switchuntagged[i])) return res, switch def getidxs(lst, lookupfun): return [ lookupfun(i) if isinstance(i, str) else i for i in lst ] def getpvidmapping(data, lookupfun): '''Return a mapping from vlan based table to a port: vlan dictionary.''' res = [] for id in data: for i in data[id].get('u', []): if isinstance(i, str): i = lookupfun(i) res.append((i, id)) return dict(res) def getegress(data, lookupfun): r = {} for id in data: r[id] = _intstobits(*(getidxs(data[id].get('u', []), lookupfun) + getidxs(data[id].get('t', []), lookupfun))) return r def getuntagged(data, lookupfun): r = {} for id in data: r[id] = _intstobits(*getidxs(data[id].get('u', []), lookupfun)) return r class SNMPSwitch(object): '''A class for manipulating switches via 
standard SNMP MIBs.''' def __init__(self, host, auth): self._eng = SnmpEngine() if isinstance(auth, str): self._cd = CommunityData(auth, mpModel=0) else: self._cd = auth self._targ = UdpTransportTarget((host, 161)) def _getmany(self, *oids): woids = [ ObjectIdentity(*oid) for oid in oids ] [ oid.resolveWithMib(_mvc) for oid in woids ] errorInd, errorStatus, errorIndex, varBinds = \ next(getCmd(self._eng, self._cd, self._targ, ContextData(), *(ObjectType(oid) for oid in woids))) if errorInd: # pragma: no cover raise ValueError(errorIndication) elif errorStatus: if str(errorStatus) == 'tooBig' and len(oids) > 1: # split the request in two pivot = len(oids) / 2 a = self._getmany(*oids[:pivot]) b = self._getmany(*oids[pivot:]) return a + b raise ValueError('%s at %s' % (errorStatus.prettyPrint(), errorIndex and varBinds[int(errorIndex)-1][0] or '?')) else: if len(varBinds) != len(oids): # pragma: no cover raise ValueError('too many return values') return varBinds def _get(self, oid): varBinds = self._getmany(oid) varBind = varBinds[0] return varBind[1] def _set(self, oid, value): oid = ObjectIdentity(*oid) oid.resolveWithMib(_mvc) if isinstance(value, (int, long)): value = Integer(value) elif isinstance(value, str): value = OctetString(value) errorInd, errorStatus, errorIndex, varBinds = \ next(setCmd(self._eng, self._cd, self._targ, ContextData(), ObjectType(oid, value))) if errorInd: # pragma: no cover raise ValueError(errorIndication) elif errorStatus: # pragma: no cover raise ValueError('%s at %s' % (errorStatus.prettyPrint(), errorIndex and varBinds[int(errorIndex)-1][0] or '?')) else: for varBind in varBinds: if varBind[1] != value: # pragma: no cover raise RuntimeError('failed to set: %s' % ' = '.join([x.prettyPrint() for x in varBind])) def _walk(self, *oid): oid = ObjectIdentity(*oid) # XXX - keep these, this might stop working, no clue what managed to magically make things work # ref: 
http://snmplabs.com/pysnmp/examples/smi/manager/browsing-mib-tree.html#mib-objects-to-pdu-var-binds # mibdump.py --mib-source '/Users/jmg/Nextcloud/Documents/user manuals/netgear/gs7xxt-v6.3.1.19-mibs' --mib-source /usr/share/snmp/mibs --rebuild rfc1212 pbridge vlan #oid.addAsn1MibSource('/usr/share/snmp/mibs', '/Users/jmg/Nextcloud/Documents/user manuals/netgear/gs7xxt-v6.3.1.19-mibs') oid.resolveWithMib(_mvc) for (errorInd, errorStatus, errorIndex, varBinds) in nextCmd( self._eng, self._cd, self._targ, ContextData(), ObjectType(oid), lexicographicMode=False): if errorInd: # pragma: no cover raise ValueError(errorInd) elif errorStatus: # pragma: no cover raise ValueError('%s at %s' % (errorStatus.prettyPrint(), errorIndex and varBinds[int(errorIndex)-1][0] or '?')) else: for varBind in varBinds: yield varBind def getportmapping(self): '''Return a port name mapping. Keys are the port index and the value is the name from the ifName entry.''' return { x[0][-1]: str(x[1]) for x in self._walk('IF-MIB', 'ifName') } def findport(self, name): '''Look up a port name and return it's port index. This looks up via the ifName table in IF-MIB.''' return [ x[0][-1] for x in self._walk('IF-MIB', 'ifName') if str(x[1]) == name ][0] def getvlanname(self, vlan): '''Return the name for the vlan.''' v = self._get(('Q-BRIDGE-MIB', 'dot1qVlanStaticName', vlan)) return str(v).decode('utf-8') def createvlan(self, vlan, name): # createAndGo(4) self._set(('Q-BRIDGE-MIB', 'dot1qVlanStaticRowStatus', int(vlan)), 4) self._set(('Q-BRIDGE-MIB', 'dot1qVlanStaticName', int(vlan)), name) def deletevlan(self, vlan): self._set(('Q-BRIDGE-MIB', 'dot1qVlanStaticRowStatus', int(vlan)), 6) # destroy(6) def getvlans(self): '''Return an iterator with all the vlan ids.''' return (x[0][-1] for x in self._walk('Q-BRIDGE-MIB', 'dot1qVlanStatus')) def staticvlans(self): '''Return an iterator of the staticly defined/configured vlans. 
This sometimes excludes special built in vlans, like vlan 1.''' return (x[0][-1] for x in self._walk('Q-BRIDGE-MIB', 'dot1qVlanStaticName')) def getpvid(self): '''Returns a dictionary w/ the interface index as the key, and the pvid of the interface.''' return { x[0][-1]: int(x[1]) for x in self._walk('Q-BRIDGE-MIB', 'dot1qPvid') } def setpvid(self, port, vlan): self._set(('Q-BRIDGE-MIB', 'dot1qPvid', int(port)), Gauge32(vlan)) def getegress(self, *vlans): r = { x[-1]: _octstrtobits(y) for x, y in self._getmany(*(('Q-BRIDGE-MIB', 'dot1qVlanStaticEgressPorts', x) for x in vlans)) } return r def setegress(self, vlan, ports): value = OctetString.fromBinaryString(ports) self._set(('Q-BRIDGE-MIB', 'dot1qVlanStaticEgressPorts', int(vlan)), value) def getuntagged(self, *vlans): r = { x[-1]: _octstrtobits(y) for x, y in self._getmany(*(('Q-BRIDGE-MIB', 'dot1qVlanStaticUntaggedPorts', x) for x in vlans)) } return r def setuntagged(self, vlan, ports): value = OctetString.fromBinaryString(ports) self._set(('Q-BRIDGE-MIB', 'dot1qVlanStaticUntaggedPorts', int(vlan)), value) if __name__ == '__main__': # pragma: no cover import pprint import sys changes, switch = checkchanges('data') if not changes: print 'No changes to apply.' sys.exit(0) pprint.pprint(changes) res = raw_input('Apply the changes? (type yes to apply): ') if res != 'yes': print 'not applying changes.' sys.exit(1) print 'applying...' 
failed = [] for verb, arg1, arg2, oldarg in changes: print '%s: %s %s' % (verb, arg1, `arg2`) try: fun = getattr(switch, verb) fun(arg1, arg2) pass except Exception as e: print 'failed' failed.append((verb, arg1, arg2, e)) if failed: print '%d failed to apply, they are:' % len(failed) for verb, arg1, arg2, e in failed: print '%s: %s %s: %s' % (verb, arg1, arg2, `e`) class _TestMisc(unittest.TestCase): def setUp(self): import test_data self._test_data = test_data def test_intstobits(self): self.assertEqual(_intstobits(1, 5, 10), '1000100001') self.assertEqual(_intstobits(3, 4, 9), '001100001') def test_octstrtobits(self): self.assertEqual(_octstrtobits('\x00'), '0' * 8) self.assertEqual(_octstrtobits('\xff'), '1' * 8) self.assertEqual(_octstrtobits('\xf0'), '1' * 4 + '0' * 4) self.assertEqual(_octstrtobits('\x0f'), '0' * 4 + '1' * 4) def test_cmpbits(self): self.assertTrue(_cmpbits('111000', '111')) self.assertTrue(_cmpbits('000111000', '000111')) self.assertTrue(_cmpbits('11', '11')) self.assertTrue(_cmpbits('0', '000')) self.assertFalse(_cmpbits('0011', '11')) self.assertFalse(_cmpbits('11', '0011')) self.assertFalse(_cmpbits('10', '000')) self.assertFalse(_cmpbits('0', '1000')) self.assertFalse(_cmpbits('00010', '000')) self.assertFalse(_cmpbits('0', '001000')) def test_pvidegressuntagged(self): data = { 1: { 'u': [ 1, 5, 10 ] + range(13, 20), 't': [ 'lag2', 6, 7 ], }, 10: { 'u': [ 2, 3, 6, 7, 8, 'lag2' ], }, 13: { 'u': [ 4, 9 ], 't': [ 'lag2', 6, 7 ], }, 14: { 't': [ 'lag2' ], }, } swconf = SwitchConfig('', '', data, [ 'lag3' ]) lookup = { 'lag2': 30, 'lag3': 31, } lufun = lookup.__getitem__ check = dict(itertools.chain(enumerate([ 1, 10, 10, 13, 1, 10, 10, 10, 13, 1 ], 1), enumerate([ 1 ] * 7, 13), [ (30, 10) ])) # That a pvid mapping res = getpvidmapping(data, lufun) # is correct self.assertEqual(res, check) self.assertEqual(swconf.getportlist(lufun), set(xrange(1, 11)) | set(xrange(13, 20)) | set(lookup.values())) checkegress = { 1: '1000111001001111111' + 
'0' * (30 - 20) + '1', 10: '01100111' + '0' * (30 - 9) + '1', 13: '000101101' + '0' * (30 - 10) + '1', 14: '0' * (30 - 1) + '1', } self.assertEqual(getegress(data, lufun), checkegress) checkuntagged = { 1: '1000100001001111111', 10: '01100111' + '0' * (30 - 9) + '1', 13: '000100001', 14: '', } self.assertEqual(getuntagged(data, lufun), checkuntagged) #@unittest.skip('foo') @mock.patch('vlanmang.SNMPSwitch.getuntagged') @mock.patch('vlanmang.SNMPSwitch.getegress') @mock.patch('vlanmang.SNMPSwitch.getpvid') @mock.patch('vlanmang.SNMPSwitch.getportmapping') @mock.patch('importlib.import_module') def test_checkchanges(self, imprt, portmapping, gpvid, gegress, guntagged): # that import returns the test data imprt.side_effect = itertools.repeat(self._test_data) # that getportmapping returns the following dict ports = { x: 'g%d' % x for x in xrange(1, 24) } ports[30] = 'lag1' ports[31] = 'lag2' ports[32] = 'lag3' portmapping.side_effect = itertools.repeat(ports) # that the switch's pvid returns spvid = { x: 283 for x in xrange(1, 24) } spvid[30] = 5 gpvid.side_effect = itertools.repeat(spvid) # the the extra port is caught self.assertRaises(ValueError, checkchanges, 'data') # that the functions were called imprt.assert_called_with('data') portmapping.assert_called() # XXX - check that an ignore statement is honored # delete the extra port del ports[32] # that the egress data provided gegress.side_effect = [ { 1: '1' * 10, 5: '1' * 10, 283: '00000000111111111110011000000100000', } ] # that the untagged data provided guntagged.side_effect = [ { 1: '1' * 10, 5: '1' * 8 + '0' * 10, 283: '00000000111111111110011', } ] res, switch = checkchanges('data') self.assertIsInstance(switch, SNMPSwitch) validres = [ ('setpvid', x, 5, 283) for x in xrange(1, 9) ] + \ [ ('setpvid', 20, 1, 283), ('setpvid', 21, 1, 283), ('setpvid', 30, 1, 5), ('setegress', 1, '0' * 19 + '11' + '0' * 8 + '1', '1' * 10), ('setuntagged', 1, '0' * 19 + '11' + '0' * 8 + '1', '1' * 10), ('setegress', 5, '1' * 8 
+ '0' * 11 + '11' + '0' * 8 + '1', '1' * 10), ] self.assertEqual(set(res), set(validres)) class _TestSNMPSwitch(unittest.TestCase): def test_splitmany(self): # make sure that if we get a tooBig error that we split the # _getmany request switch = SNMPSwitch(None, None) @mock.patch('vlanmang.SNMPSwitch._getmany') def test_get(self, gm): # that a switch switch = SNMPSwitch(None, None) # when _getmany returns this structure retval = object() gm.side_effect = [[[ None, retval ]]] arg = object() # will return the correct value self.assertIs(switch._get(arg), retval) # and call _getmany w/ the correct arg gm.assert_called_with(arg) _skipSwitchTests = True class _TestSwitch(unittest.TestCase): def setUp(self): # If we don't have it, pretend it's true for now and # we'll recheck it later model = 'GS108T smartSwitch' if getattr(self, 'switchmodel', model) != model or \ _skipSwitchTests: # pragma: no cover self.skipTest('Need a GS108T switch to run these tests') args = open('test.creds').read().split() self.switch = SNMPSwitch(*args) self.switchmodel = self.switch._get(('ENTITY-MIB', 'entPhysicalModelName', 1)) if self.switchmodel != model: # pragma: no cover self.skipTest('Need a GS108T switch to run these tests') def test_misc(self): switch = self.switch self.assertEqual(switch.findport('g1'), 1) self.assertEqual(switch.findport('l1'), 14) def test_portnames(self): switch = self.switch resp = dict((x, 'g%d' % x) for x in xrange(1, 9)) resp.update({ 13: 'cpu' }) resp.update((x, 'l%d' % (x - 13)) for x in xrange(14, 18)) self.assertEqual(switch.getportmapping(), resp) def test_egress(self): switch = self.switch egress = switch.getegress(1, 2, 3) checkegress = { 1: '1' * 8 + '0' * 5 + '1' * 4 + '0' * 23, 2: '0' * 8 * 5, 3: '0' * 8 * 5, } self.assertEqual(egress, checkegress) def test_untagged(self): switch = self.switch untagged = switch.getuntagged(1, 2, 3) checkuntagged = { 1: '1' * 8 * 5, 2: '1' * 8 * 5, 3: '1' * 8 * 5, } self.assertEqual(untagged, checkuntagged) def 
test_vlan(self): switch = self.switch existingvlans = set(switch.getvlans()) while True: testvlan = random.randint(1,4095) if testvlan not in existingvlans: break # Test that getting a non-existant vlans raises an exception self.assertRaises(ValueError, switch.getvlanname, testvlan) self.assertTrue(set(switch.staticvlans()).issubset(existingvlans)) pvidres = { x: 1 for x in xrange(1, 9) } pvidres.update({ x: 1 for x in xrange(14, 18) }) self.assertEqual(switch.getpvid(), pvidres) testname = 'Sometestname' # Create test vlan switch.createvlan(testvlan, testname) testport = None try: # make sure the test vlan was created self.assertIn(testvlan, set(switch.staticvlans())) self.assertEqual(testname, switch.getvlanname(testvlan)) switch.setegress(testvlan, '00100') pvidmap = switch.getpvid() testport = 3 egressports = switch.getegress(testvlan) self.assertEqual(egressports[testvlan], '00100000' + '0' * 8 * 4) switch.setuntagged(testvlan, '00100') untaggedports = switch.getuntagged(testvlan) self.assertEqual(untaggedports[testvlan], '00100000' + '0' * 8 * 4) switch.setpvid(testport, testvlan) self.assertEqual(switch.getpvid()[testport], testvlan) finally: if testport: switch.setpvid(testport, pvidmap[3]) switch.deletevlan(testvlan)