| @@ -52,7 +52,14 @@ class SwitchConfig(object): | |||||
| def vlanconf(self): | def vlanconf(self): | ||||
| return self._vlanconf | return self._vlanconf | ||||
| def intstobits(*ints): | |||||
| def _octstrtobits(os): | |||||
| num = 1 | |||||
| for i in str(os): | |||||
| num = (num << 8) | ord(i) | |||||
| return bin(num)[3:] | |||||
| def _intstobits(*ints): | |||||
| v = 0 | v = 0 | ||||
| for i in ints: | for i in ints: | ||||
| v |= 1 << i | v |= 1 << i | ||||
| @@ -69,25 +76,40 @@ def checkchanges(module): | |||||
| res = [] | res = [] | ||||
| for i in mods: | for i in mods: | ||||
| vlans = i.vlanconf.keys() | |||||
| switch = SNMPSwitch(i.host, i.community) | switch = SNMPSwitch(i.host, i.community) | ||||
| portmapping = switch.getportmapping() | portmapping = switch.getportmapping() | ||||
| invportmap = { y: x for x, y in portmapping.iteritems() } | invportmap = { y: x for x, y in portmapping.iteritems() } | ||||
| lufun = invportmap.__getitem__ | lufun = invportmap.__getitem__ | ||||
| portlist = getportlist(i._vlanconf, lufun) | |||||
| # get complete set of ports | |||||
| portlist = getportlist(i.vlanconf, lufun) | |||||
| ports = set(portmapping.iterkeys()) | ports = set(portmapping.iterkeys()) | ||||
| # make sure switch agrees w/ them all | |||||
| if ports != portlist: | if ports != portlist: | ||||
| raise ValueError('missing or extra ports found: %s' % | raise ValueError('missing or extra ports found: %s' % | ||||
| `ports.symmetric_difference(portlist)`) | `ports.symmetric_difference(portlist)`) | ||||
| pvidmap = getpvidmapping(i._vlanconf, lufun) | |||||
| # compare pvid | |||||
| pvidmap = getpvidmapping(i.vlanconf, lufun) | |||||
| switchpvid = switch.getpvid() | switchpvid = switch.getpvid() | ||||
| res.extend(('setpvid', idx, vlan, switchpvid[idx]) for idx, vlan in | res.extend(('setpvid', idx, vlan, switchpvid[idx]) for idx, vlan in | ||||
| pvidmap.iteritems() if switchpvid[idx] != vlan) | pvidmap.iteritems() if switchpvid[idx] != vlan) | ||||
| # compare egress & untagged | |||||
| switchegress = switch.getegress(*vlans) | |||||
| egress = getegress(i.vlanconf, lufun) | |||||
| switchuntagged = switch.getuntagged(*vlans) | |||||
| untagged = getuntagged(i.vlanconf, lufun) | |||||
| for i in vlans: | |||||
| if switchegress[i] != egress[i]: | |||||
| res.append(('setegress', i, egress[i], switchegress[i])) | |||||
| if switchuntagged[i] != untagged[i]: | |||||
| res.append(('setuntagged', i, untagged[i], switchuntagged[i])) | |||||
| return res | return res | ||||
| def getidxs(lst, lookupfun): | def getidxs(lst, lookupfun): | ||||
| @@ -109,7 +131,7 @@ def getpvidmapping(data, lookupfun): | |||||
| def getegress(data, lookupfun): | def getegress(data, lookupfun): | ||||
| r = {} | r = {} | ||||
| for id in data: | for id in data: | ||||
| r[id] = intstobits(*(getidxs(data[id]['u'], lookupfun) + | |||||
| r[id] = _intstobits(*(getidxs(data[id]['u'], lookupfun) + | |||||
| getidxs(data[id].get('t', []), lookupfun))) | getidxs(data[id].get('t', []), lookupfun))) | ||||
| return r | return r | ||||
| @@ -117,7 +139,7 @@ def getegress(data, lookupfun): | |||||
| def getuntagged(data, lookupfun): | def getuntagged(data, lookupfun): | ||||
| r = {} | r = {} | ||||
| for id in data: | for id in data: | ||||
| r[id] = intstobits(*getidxs(data[id]['u'], lookupfun)) | |||||
| r[id] = _intstobits(*getidxs(data[id]['u'], lookupfun)) | |||||
| return r | return r | ||||
| @@ -146,6 +168,25 @@ class SNMPSwitch(object): | |||||
| self._cd = CommunityData(community, mpModel=0) | self._cd = CommunityData(community, mpModel=0) | ||||
| self._targ = UdpTransportTarget((host, 161)) | self._targ = UdpTransportTarget((host, 161)) | ||||
| def _getmany(self, *oids): | |||||
| oids = [ ObjectIdentity(*oid) for oid in oids ] | |||||
| [ oid.resolveWithMib(_mvc) for oid in oids ] | |||||
| errorInd, errorStatus, errorIndex, varBinds = \ | |||||
| next(getCmd(self._eng, self._cd, self._targ, ContextData(), *(ObjectType(oid) for oid in oids))) | |||||
| if errorInd: # pragma: no cover | |||||
| raise ValueError(errorIndication) | |||||
| elif errorStatus: # pragma: no cover | |||||
| raise ValueError('%s at %s' % | |||||
| (errorStatus.prettyPrint(), errorIndex and | |||||
| varBinds[int(errorIndex)-1][0] or '?')) | |||||
| else: | |||||
| if len(varBinds) != len(oids): # pragma: no cover | |||||
| raise ValueError('too many return values') | |||||
| return varBinds | |||||
| def _get(self, oid): | def _get(self, oid): | ||||
| oid = ObjectIdentity(*oid) | oid = ObjectIdentity(*oid) | ||||
| oid.resolveWithMib(_mvc) | oid.resolveWithMib(_mvc) | ||||
| @@ -258,6 +299,20 @@ class SNMPSwitch(object): | |||||
| return { x[0][-1]: int(x[1]) for x in self._walk('Q-BRIDGE-MIB', 'dot1qPvid') } | return { x[0][-1]: int(x[1]) for x in self._walk('Q-BRIDGE-MIB', 'dot1qPvid') } | ||||
| def getegress(self, *vlans): | |||||
| r = { x[-1]: _octstrtobits(y) for x, y in | |||||
| self._getmany(*(('Q-BRIDGE-MIB', | |||||
| 'dot1qVlanStaticEgressPorts', x) for x in vlans)) } | |||||
| return r | |||||
| def getuntagged(self, *vlans): | |||||
| r = { x[-1]: _octstrtobits(y) for x, y in | |||||
| self._getmany(*(('Q-BRIDGE-MIB', | |||||
| 'dot1qVlanStaticUntaggedPorts', x) for x in vlans)) } | |||||
| return r | |||||
| class _TestMisc(unittest.TestCase): | class _TestMisc(unittest.TestCase): | ||||
| def setUp(self): | def setUp(self): | ||||
| import test_data | import test_data | ||||
| @@ -265,8 +320,14 @@ class _TestMisc(unittest.TestCase): | |||||
| self._test_data = test_data | self._test_data = test_data | ||||
| def test_intstobits(self): | def test_intstobits(self): | ||||
| self.assertEqual(intstobits(1, 5, 10), '1000100001') | |||||
| self.assertEqual(intstobits(3, 4, 9), '001100001') | |||||
| self.assertEqual(_intstobits(1, 5, 10), '1000100001') | |||||
| self.assertEqual(_intstobits(3, 4, 9), '001100001') | |||||
| def test_octstrtobits(self): | |||||
| self.assertEqual(_octstrtobits('\x00'), '0' * 8) | |||||
| self.assertEqual(_octstrtobits('\xff'), '1' * 8) | |||||
| self.assertEqual(_octstrtobits('\xf0'), '1' * 4 + '0' * 4) | |||||
| self.assertEqual(_octstrtobits('\x0f'), '0' * 4 + '1' * 4) | |||||
| def test_pvidegressuntagged(self): | def test_pvidegressuntagged(self): | ||||
| data = { | data = { | ||||
| @@ -315,11 +376,13 @@ class _TestMisc(unittest.TestCase): | |||||
| } | } | ||||
| self.assertEqual(getuntagged(data, lufun), checkuntagged) | self.assertEqual(getuntagged(data, lufun), checkuntagged) | ||||
| @unittest.skip('foo') | |||||
| #@unittest.skip('foo') | |||||
| @mock.patch('vlanmang.SNMPSwitch.getuntagged') | |||||
| @mock.patch('vlanmang.SNMPSwitch.getegress') | |||||
| @mock.patch('vlanmang.SNMPSwitch.getpvid') | @mock.patch('vlanmang.SNMPSwitch.getpvid') | ||||
| @mock.patch('vlanmang.SNMPSwitch.getportmapping') | @mock.patch('vlanmang.SNMPSwitch.getportmapping') | ||||
| @mock.patch('importlib.import_module') | @mock.patch('importlib.import_module') | ||||
| def test_checkchanges(self, imprt, portmapping, gpvid): | |||||
| def test_checkchanges(self, imprt, portmapping, gpvid, gegress, guntagged): | |||||
| # that import returns the test data | # that import returns the test data | ||||
| imprt.side_effect = itertools.repeat(self._test_data) | imprt.side_effect = itertools.repeat(self._test_data) | ||||
| @@ -346,30 +409,50 @@ class _TestMisc(unittest.TestCase): | |||||
| # delete the extra port | # delete the extra port | ||||
| del ports[31] | del ports[31] | ||||
| # that the egress data provided | |||||
| gegress.side_effect = [ { | |||||
| 1: '1' * 10, | |||||
| 5: '1' * 10, | |||||
| 283: '000000001111111111100110000001', | |||||
| } ] | |||||
| # that the untagged data provided | |||||
| guntagged.side_effect = [ { | |||||
| 1: '1' * 10, | |||||
| 5: '1' * 8, | |||||
| 283: '00000000111111111110011', | |||||
| } ] | |||||
| res = checkchanges('data') | res = checkchanges('data') | ||||
| validres = [ ('setpvid', x, 5, 283) for x in xrange(1, 9) ] + \ | validres = [ ('setpvid', x, 5, 283) for x in xrange(1, 9) ] + \ | ||||
| [ ('setpvid', 20, 1, 283), | [ ('setpvid', 20, 1, 283), | ||||
| ('setpvid', 21, 1, 283), | ('setpvid', 21, 1, 283), | ||||
| ('setpvid', 30, 1, 5), | ('setpvid', 30, 1, 5), | ||||
| ('setegress', 1, '0' * 19 + '11' + '0' * 8 + '1', ''), | |||||
| ('setuntagged', 1, '0' * 19 + '11' + '0' * 8 + '1', ''), | |||||
| ('setegress', 5, '1' * 8 + '0' * 11 + '11' + '0' * 8 + '1', ''), | |||||
| ('setegress', 1, '0' * 19 + '11' + '0' * 8 + '1', '1' * 10), | |||||
| ('setuntagged', 1, '0' * 19 + '11' + '0' * 8 + '1', '1' * 10), | |||||
| ('setegress', 5, '1' * 8 + '0' * 11 + '11' + '0' * 8 + '1', '1' * 10), | |||||
| ] | ] | ||||
| self.assertEqual(set(res), set(validres)) | self.assertEqual(set(res), set(validres)) | ||||
| _skipSwitchTests = True | |||||
| _skipSwitchTests = False | |||||
| class _TestSwitch(unittest.TestCase): | class _TestSwitch(unittest.TestCase): | ||||
| def setUp(self): | def setUp(self): | ||||
| # If we don't have it, pretend it's true for now and | |||||
| # we'll recheck it later | |||||
| model = 'GS108T smartSwitch' | |||||
| if getattr(self, 'switchmodel', model) != model or \ | |||||
| _skipSwitchTests: # pragma: no cover | |||||
| self.skipTest('Need a GS108T switch to run these tests') | |||||
| args = open('test.creds').read().split() | args = open('test.creds').read().split() | ||||
| self.switch = SNMPSwitch(*args) | self.switch = SNMPSwitch(*args) | ||||
| switchmodel = self.switch._get(('ENTITY-MIB', | |||||
| self.switchmodel = self.switch._get(('ENTITY-MIB', | |||||
| 'entPhysicalModelName', 1)) | 'entPhysicalModelName', 1)) | ||||
| if switchmodel != 'GS108T smartSwitch' or \ | |||||
| _skipSwitchTests: # pragma: no cover | |||||
| if self.switchmodel != model: # pragma: no cover | |||||
| self.skipTest('Need a GS108T switch to run these tests') | self.skipTest('Need a GS108T switch to run these tests') | ||||
| def test_misc(self): | def test_misc(self): | ||||
| @@ -387,6 +470,32 @@ class _TestSwitch(unittest.TestCase): | |||||
| self.assertEqual(switch.getportmapping(), resp) | self.assertEqual(switch.getportmapping(), resp) | ||||
| def test_egress(self): | |||||
| switch = self.switch | |||||
| egress = switch.getegress(1, 2, 3) | |||||
| checkegress = { | |||||
| 1: '1' * 8 + '0' * 5 + '1' * 4 + '0' * 23, | |||||
| 2: '0' * 8 * 5, | |||||
| 3: '0' * 8 * 5, | |||||
| } | |||||
| self.assertEqual(egress, checkegress) | |||||
| def test_untagged(self): | |||||
| switch = self.switch | |||||
| untagged = switch.getuntagged(1, 2, 3) | |||||
| checkuntagged = { | |||||
| 1: '1' * 8 * 5, | |||||
| 2: '1' * 8 * 5, | |||||
| 3: '1' * 8 * 5, | |||||
| } | |||||
| self.assertEqual(untagged, checkuntagged) | |||||
| def test_vlan(self): | def test_vlan(self): | ||||
| switch = self.switch | switch = self.switch | ||||