#!/usr/bin/env python # -*- coding: utf-8 -*- from pysnmp.hlapi import * from pysnmp.smi.builder import MibBuilder from pysnmp.smi.view import MibViewController import importlib import itertools import mock import random import unittest _mbuilder = MibBuilder() _mvc = MibViewController(_mbuilder) #import data # received packages # pvid: dot1qPvid # # tx packets: # dot1qVlanStaticEgressPorts # dot1qVlanStaticUntaggedPorts # # vlans: # dot1qVlanCurrentTable # lists ALL vlans, including baked in ones # # note that even though an snmpwalk of dot1qVlanStaticEgressPorts # skips over other vlans (only shows statics), the other vlans (1,2,3) # are still accessible via that oid # # LLDP: # 1.0.8802.1.1.2.1.4.1.1 aka LLDP-MIB, lldpRemTable class SwitchConfig(object): '''This is a simple object to store switch configuration for the checkchanges() function. host -- The host of the switch you are maintaining configuration of. authargs -- This is a dictionary of kwargs to pass to SNMPSwitch. If SNMPv1 (insecure) is used, pass dict(community='communitystr'). Example for SNMPv3 where the key for both authentication and encryption are the same: dict(username='username', authKey=key, privKey=key) vlanconf -- This is a dictionary w/ vlans as the key. Each value has a dictionary that contains keys, 'u' or 't', each of which contains the port that traffic should be sent untagged ('u') or tagged ('t'). Note that the Pvid (vlan of traffic that is received when untagged), is set to match the 'u' definition. The port is either an integer, which maps directly to the switch's index number, or it can be a string, which will be looked up via the IF-MIB::ifName table. Example specifies that VLANs 1 and 2 will be transmitted as tagged packets on the port named 'lag1'. That ports 1, 2, 3, 4 and 5 will be untagged on VLAN 1, and ports 6, 7, 8 and 9 will be untagged on VLAN 2: { 1: { 'u': [ 1, 2, 3, 4, 5 ], 't': [ 'lag1' ], }, 2: { 'u': [ 6, 7, 8, 9 ], 't': [ 'lag1' ], }, } ignports -- Ports that will be ignored and not required to be configured. List any ports that will not be active here, such as any unused lag ports. ''' def __init__(self, host, authargs, vlanconf, ignports): self._host = host self._authargs = authargs self._vlanconf = vlanconf self._ignports = ignports @property def host(self): return self._host @property def authargs(self): return self._authargs @property def vlanconf(self): return self._vlanconf @property def ignports(self): return self._ignports def getportlist(self, lookupfun): '''Return a set of all the ports indexes in data. This includes, both vlanconf and ignports. Any ports using names will be resolved by being passed to the provided lookupfun.''' res = [] for id in self._vlanconf: res.extend(self._vlanconf[id].get('u', [])) res.extend(self._vlanconf[id].get('t', [])) # add in the ignore ports res.extend(self.ignports) # eliminate dups so that lookupfun isn't called as often res = set(res) return set(getidxs(res, lookupfun)) def _octstrtobits(os): '''Convert a string into a list of bits. Easier to figure out what ports are set.''' num = 1 # leading 1 to make sure leading zeros are not stripped for i in str(os): num = (num << 8) | ord(i) return bin(num)[3:] def _intstobits(*ints): '''Convert the int args to a string of bits in the expected format that SNMP expects for them. The results will be a string of '1's and '0's where the first one represents 1, and second one representing 2 and so on.''' v = 0 for i in ints: v |= 1 << i r = list(bin(v)[2:-1]) r.reverse() return ''.join(r) def _cmpbits(a, b): '''Compare two strings of bits to make sure they are equal. Trailing 0's are ignored.''' try: last1a = a.rindex('1') except ValueError: last1a = -1 a = '' try: last1b = b.rindex('1') except ValueError: last1b = -1 b = '' if last1a != -1: a = a[:last1a + 1] if last1b != -1: b = b[:last1b + 1] return a == b import vlanmang def checkchanges(module): '''Function to check for any differences between the switch, and the configured state. The parameter module is a string to the name of a python module. It will be imported, and any names that reference a vlanmang.SwitchConfig class will be validate that the configuration matches. If it does not, the returned list will contain a set of tuples, each one containing (verb, arg1, arg2, switcharg2). verb is what needs to be changed. arg1 is either the port (for setting Pvid), or the VLAN that needs to be configured. arg2 is what it needs to be set to. switcharg2 is what the switch is currently configured to, so that you can easily see what the effect of the configuration change is. ''' mod = importlib.import_module(module) mods = [ (k, v) for k, v in mod.__dict__.iteritems() if isinstance(v, vlanmang.SwitchConfig) ] res = [] for name, i in mods: #print 'probing %s' % `name` vlans = i.vlanconf.keys() switch = SNMPSwitch(i.host, **i.authargs) switchpvid = switch.getpvid() portmapping = switch.getportmapping() invportmap = { y: x for x, y in portmapping.iteritems() } lufun = invportmap.__getitem__ # get complete set of ports portlist = i.getportlist(lufun) ports = set(portmapping.iterkeys()) # make sure switch agrees w/ them all if ports != portlist: raise ValueError('missing or extra ports found: %s' % `ports.symmetric_difference(portlist)`) # compare pvid pvidmap = getpvidmapping(i.vlanconf, lufun) switchpvid = switch.getpvid() res.extend((switch, name, 'setpvid', idx, vlan, switchpvid[idx]) for idx, vlan in pvidmap.iteritems() if switchpvid[idx] != vlan) # compare egress & untagged switchegress = switch.getegress(*vlans) egress = getegress(i.vlanconf, lufun) switchuntagged = switch.getuntagged(*vlans) untagged = getuntagged(i.vlanconf, lufun) for i in vlans: if not _cmpbits(switchegress[i], egress[i]): res.append((switch, name, 'setegress', i, egress[i], switchegress[i])) if not _cmpbits(switchuntagged[i], untagged[i]): res.append((switch, name, 'setuntagged', i, untagged[i], switchuntagged[i])) return res def getidxs(lst, lookupfun): '''Take a list of ports, and if any are a string, replace them w/ the value returned by lookupfun(s). Note that duplicates are not detected or removed, both in the original list, and the values returned by the lookup function may duplicate other values in the list.''' return [ lookupfun(i) if isinstance(i, str) else i for i in lst ] def getpvidmapping(data, lookupfun): '''Return a mapping from vlan based table to a port: vlan dictionary. This only looks at that untagged part of the vlan configuration, and is used for finding what a port's Pvid should be.''' res = [] for id in data: for i in data[id].get('u', []): if isinstance(i, str): i = lookupfun(i) res.append((i, id)) return dict(res) def getegress(data, lookupfun): '''Return a dictionary, keyed by VLAN id with a bit string of ports that need to be enabled for egress. This include both tagged and untagged traffic.''' r = {} for id in data: r[id] = _intstobits(*(getidxs(data[id].get('u', []), lookupfun) + getidxs(data[id].get('t', []), lookupfun))) return r def getuntagged(data, lookupfun): '''Return a dictionary, keyed by VLAN id with a bit string of ports that need to be enabled for untagged egress.''' r = {} for id in data: r[id] = _intstobits(*getidxs(data[id].get('u', []), lookupfun)) return r class SNMPSwitch(object): '''A class for manipulating switches via standard SNMP MIBs.''' def __init__(self, host, community=None, username=None, authKey=None, authProtocol=usmHMACSHAAuthProtocol, privKey=None, privProtocol=None): '''Create a instance to read data and program a switch via SNMP. Args: host -- Host name or IP address of the switch. community -- If using SNMPv1 (not recommended, insecure), this is the community name to authenticate. username -- The username to authenticate when using SNMPv3. This varies, some cases it can be programmed and a specific user is created, in other cases, it is hard coded to a user like 'admin'. authKey -- This is the key string used to authenticate the SNMP requests. authProtocol -- This is protocol used to authenticate the SNMP requests. It is one of the values passed to authProtocol of pysnmp's UsmUserData as documented at: http://snmplabs.com/pysnmp/docs/api-reference.html#pysnmp.hlapi.UsmUserData privKey -- This is the key string used to encrypt the SNMP requests. privProtocol -- This is protocol used to encrypt the SNMP requests. It is one of the values passed to privProtocol of pysnmp's UsmUserData as documented at: http://snmplabs.com/pysnmp/docs/api-reference.html#pysnmp.hlapi.UsmUserData ''' if community is not None and username is not None: raise ValueError('only one of community and username is allowed to be specified') self._eng = SnmpEngine() if community is not None: self._auth = CommunityData(community, mpModel=0) else: args = (username, authKey, ) kwargs = { 'authProtocol': authProtocol } if privKey is not None: args += (privKey,) kwargs['privProtocol'] = \ usmAesCfb256Protocol if privProtocol is \ None else privProtocol self._auth = UsmUserData(*args, **kwargs) self._targ = UdpTransportTarget((host, 161)) def __repr__(self): # pragma: no cover return '' % (`self._auth`, `self._targ`) def _getmany(self, *oids): woids = [ ObjectIdentity(*oid) for oid in oids ] [ oid.resolveWithMib(_mvc) for oid in woids ] errorInd, errorStatus, errorIndex, varBinds = \ next(getCmd(self._eng, self._auth, self._targ, ContextData(), *(ObjectType(oid) for oid in woids))) if errorInd: # pragma: no cover raise ValueError(errorInd) elif errorStatus: if str(errorStatus) == 'tooBig' and len(oids) > 1: # split the request in two pivot = len(oids) / 2 a = self._getmany(*oids[:pivot]) b = self._getmany(*oids[pivot:]) return a + b raise ValueError('%s at %s' % (errorStatus.prettyPrint(), errorIndex and varBinds[int(errorIndex)-1][0] or '?')) else: if len(varBinds) != len(oids): # pragma: no cover raise ValueError('too many return values') return varBinds def _get(self, oid): varBinds = self._getmany(oid) varBind = varBinds[0] return varBind[1] def _set(self, oid, value): oid = ObjectIdentity(*oid) oid.resolveWithMib(_mvc) if isinstance(value, (int, long)): value = Integer(value) elif isinstance(value, str): value = OctetString(value) errorInd, errorStatus, errorIndex, varBinds = \ next(setCmd(self._eng, self._auth, self._targ, ContextData(), ObjectType(oid, value))) if errorInd: # pragma: no cover raise ValueError(errorInd) elif errorStatus: # pragma: no cover raise ValueError('%s at %s' % (errorStatus.prettyPrint(), errorIndex and varBinds[int(errorIndex)-1][0] or '?')) else: for varBind in varBinds: if varBind[1] != value: # pragma: no cover raise RuntimeError('failed to set: %s' % ' = '.join([x.prettyPrint() for x in varBind])) def _walk(self, *oid): oid = ObjectIdentity(*oid) # XXX - keep these, this might stop working, no clue what managed to magically make things work # ref: http://snmplabs.com/pysnmp/examples/smi/manager/browsing-mib-tree.html#mib-objects-to-pdu-var-binds # mibdump.py --mib-source '/Users/jmg/Nextcloud/Documents/user manuals/netgear/gs7xxt-v6.3.1.19-mibs' --mib-source /usr/share/snmp/mibs --rebuild rfc1212 pbridge vlan #oid.addAsn1MibSource('/usr/share/snmp/mibs', '/Users/jmg/Nextcloud/Documents/user manuals/netgear/gs7xxt-v6.3.1.19-mibs') oid.resolveWithMib(_mvc) for (errorInd, errorStatus, errorIndex, varBinds) in nextCmd( self._eng, self._auth, self._targ, ContextData(), ObjectType(oid), lexicographicMode=False): if errorInd: # pragma: no cover raise ValueError(errorInd) elif errorStatus: # pragma: no cover raise ValueError('%s at %s' % (errorStatus.prettyPrint(), errorIndex and varBinds[int(errorIndex)-1][0] or '?')) else: for varBind in varBinds: yield varBind def getportmapping(self): '''Return a port name mapping. Keys are the port index and the value is the name from the IF-MIB::ifName entry.''' return { x[0][-1]: str(x[1]) for x in self._walk('IF-MIB', 'ifName') } def findport(self, name): '''Look up a port name and return it's port index. This looks up via the ifName table in IF-MIB.''' return [ x[0][-1] for x in self._walk('IF-MIB', 'ifName') if str(x[1]) == name ][0] def getvlanname(self, vlan): '''Return the name for the vlan. This returns the value in Q-BRIDGE-MIB:dot1qVlanStaticName.''' v = self._get(('Q-BRIDGE-MIB', 'dot1qVlanStaticName', vlan)) return str(v).decode('utf-8') def createvlan(self, vlan, name): # createAndGo(4) self._set(('Q-BRIDGE-MIB', 'dot1qVlanStaticRowStatus', int(vlan)), 4) self._set(('Q-BRIDGE-MIB', 'dot1qVlanStaticName', int(vlan)), name) def deletevlan(self, vlan): self._set(('Q-BRIDGE-MIB', 'dot1qVlanStaticRowStatus', int(vlan)), 6) # destroy(6) def getvlans(self): '''Return an iterator with all the vlan ids.''' return (x[0][-1] for x in self._walk('Q-BRIDGE-MIB', 'dot1qVlanStatus')) def staticvlans(self): '''Return an iterator of the staticly defined/configured vlans. This sometimes excludes special built in vlans, like vlan 1.''' return (x[0][-1] for x in self._walk('Q-BRIDGE-MIB', 'dot1qVlanStaticName')) def getpvid(self): '''Returns a dictionary w/ the interface index as the key, and the pvid of the interface.''' return { x[0][-1]: int(x[1]) for x in self._walk('Q-BRIDGE-MIB', 'dot1qPvid') } def setpvid(self, port, vlan): '''Set the port's Pvid to vlan. This means that any packet received by the port that is untagged, will be routed the the vlan.''' self._set(('Q-BRIDGE-MIB', 'dot1qPvid', int(port)), Gauge32(vlan)) def getegress(self, *vlans): '''Get a dictionary keyed by the specified VLANs, where each value is a bit string that preresents what ports that particular VLAN will be transmitted on.''' r = { x[-1]: _octstrtobits(y) for x, y in self._getmany(*(('Q-BRIDGE-MIB', 'dot1qVlanStaticEgressPorts', x) for x in vlans)) } return r def setegress(self, vlan, ports): '''Set the ports which the specified VLAN will have packets transmitted as either tagged, if unset in untagged, or untagged, if set in untagged, to bit bit string specified by ports.''' value = OctetString.fromBinaryString(ports) self._set(('Q-BRIDGE-MIB', 'dot1qVlanStaticEgressPorts', int(vlan)), value) def getuntagged(self, *vlans): '''Get a dictionary keyed by the specified VLANs, where each value is a bit string that preresents what ports that particular VLAN will be transmitted on as an untagged packet.''' r = { x[-1]: _octstrtobits(y) for x, y in self._getmany(*(('Q-BRIDGE-MIB', 'dot1qVlanStaticUntaggedPorts', x) for x in vlans)) } return r def setuntagged(self, vlan, ports): '''Set the ports which the specified VLAN will have packets transmitted as untagged to the bit string specified by ports.''' value = OctetString.fromBinaryString(ports) self._set(('Q-BRIDGE-MIB', 'dot1qVlanStaticUntaggedPorts', int(vlan)), value) def main(): import pprint import sys changes = checkchanges('data') if not changes: print 'No changes to apply.' sys.exit(0) pprint.pprint([ x[1:] for x in changes ]) res = raw_input('Apply the changes? (type yes to apply): ') if res != 'yes': print 'not applying changes.' sys.exit(1) print 'applying...' failed = [] prevname = None for switch, name, verb, arg1, arg2, oldarg in changes: if prevname != name: print 'Configuring switch %s...' % `name` prevname = name print '%s: %s %s' % (verb, arg1, `arg2`) try: fun = getattr(switch, verb) fun(arg1, arg2) pass except Exception as e: print 'failed' failed.append((verb, arg1, arg2, e)) if failed: print '%d failed to apply, they are:' % len(failed) for verb, arg1, arg2, e in failed: print '%s: %s %s: %s' % (verb, arg1, arg2, `e`) if __name__ == '__main__': # pragma: no cover main() class _TestMisc(unittest.TestCase): def setUp(self): import test_data self.skipTest('foo') self._test_data = test_data def test_intstobits(self): self.assertEqual(_intstobits(1, 5, 10), '1000100001') self.assertEqual(_intstobits(3, 4, 9), '001100001') def test_octstrtobits(self): self.assertEqual(_octstrtobits('\x00'), '0' * 8) self.assertEqual(_octstrtobits('\xff'), '1' * 8) self.assertEqual(_octstrtobits('\xf0'), '1' * 4 + '0' * 4) self.assertEqual(_octstrtobits('\x0f'), '0' * 4 + '1' * 4) def test_cmpbits(self): self.assertTrue(_cmpbits('111000', '111')) self.assertTrue(_cmpbits('000111000', '000111')) self.assertTrue(_cmpbits('11', '11')) self.assertTrue(_cmpbits('0', '000')) self.assertFalse(_cmpbits('0011', '11')) self.assertFalse(_cmpbits('11', '0011')) self.assertFalse(_cmpbits('10', '000')) self.assertFalse(_cmpbits('0', '1000')) self.assertFalse(_cmpbits('00010', '000')) self.assertFalse(_cmpbits('0', '001000')) def test_pvidegressuntagged(self): data = { 1: { 'u': [ 1, 5, 10 ] + range(13, 20), 't': [ 'lag2', 6, 7 ], }, 10: { 'u': [ 2, 3, 6, 7, 8, 'lag2' ], }, 13: { 'u': [ 4, 9 ], 't': [ 'lag2', 6, 7 ], }, 14: { 't': [ 'lag2' ], }, } swconf = SwitchConfig('', {}, data, [ 'lag3' ]) lookup = { 'lag2': 30, 'lag3': 31, } lufun = lookup.__getitem__ check = dict(itertools.chain(enumerate([ 1, 10, 10, 13, 1, 10, 10, 10, 13, 1 ], 1), enumerate([ 1 ] * 7, 13), [ (30, 10) ])) # That a pvid mapping res = getpvidmapping(data, lufun) # is correct self.assertEqual(res, check) self.assertEqual(swconf.getportlist(lufun), set(xrange(1, 11)) | set(xrange(13, 20)) | set(lookup.values())) checkegress = { 1: '1000111001001111111' + '0' * (30 - 20) + '1', 10: '01100111' + '0' * (30 - 9) + '1', 13: '000101101' + '0' * (30 - 10) + '1', 14: '0' * (30 - 1) + '1', } self.assertEqual(getegress(data, lufun), checkegress) checkuntagged = { 1: '1000100001001111111', 10: '01100111' + '0' * (30 - 9) + '1', 13: '000100001', 14: '', } self.assertEqual(getuntagged(data, lufun), checkuntagged) @mock.patch('vlanmang.CommunityData') @mock.patch('vlanmang.getCmd') def test_v1auth(self, gc, cd): # That the CommunityData class returns an object cdobj = object() cd.side_effect = [ cdobj ] # That a switch passed a community string commstr = 'foobar' switch = SNMPSwitch(None, community=commstr) # That getCmd returns a valid object vb = [ [ None, None ] ] gc.side_effect = [ iter([[ None ] * 3 + [ vb ] ]) ] r = switch.getvlanname(1) # That getCmd was called gc.assert_called() # with the correct auth object calledcd = gc.call_args.args[1] self.assertIs(calledcd, cdobj) # and that CommunityData was called w/ the correct args cd.assert_called_with(commstr, mpModel=0) def test_badauth(self): # that when both community and username are provided # it raises a ValueError self.assertRaises(ValueError, SNMPSwitch, 'somehost', community='foo', username='bar') @mock.patch('vlanmang.UsmUserData') @mock.patch('vlanmang.getCmd') def test_v3auth(self, gc, uud): # That the UsmUserData class returns an object uudobj = object() uud.side_effect = [ uudobj ] * 5 # That a switch passed v3 auth data username = 'someuser' authKey = 'authKey' switch = SNMPSwitch(None, username=username, authKey=authKey) # That getCmd returns a valid object vb = [ [ None, None ] ] gc.side_effect = [ iter([[ None ] * 3 + [ vb ] ]) ] * 10 r = switch.getvlanname(1) # That getCmd was called gc.assert_called() # with the correct auth object calleduud = gc.call_args.args[1] self.assertIs(calleduud, uudobj) # and that UsmUserData was called w/ the correct args uud.assert_called_with(username, authKey, authProtocol=usmHMACSHAAuthProtocol) # Reset the usm data uud.reset_mock() # that it can be called with a privKey privKey = 'privKey' switch = SNMPSwitch(None, username=username, authKey=authKey, privKey=privKey) # and that UsmUserData was called w/ the correct args uud.assert_called_with(username, authKey, privKey, authProtocol=usmHMACSHAAuthProtocol, privProtocol=usmAesCfb256Protocol) # Reset the usm data uud.reset_mock() # that it can be called with an alternate privProtocol switch = SNMPSwitch(None, username=username, authKey=authKey, privKey=privKey, privProtocol=usmDESPrivProtocol) # and that UsmUserData was called w/ the correct args uud.assert_called_with(username, authKey, privKey, authProtocol=usmHMACSHAAuthProtocol, privProtocol=usmDESPrivProtocol) # Reset the usm data uud.reset_mock() # that it can be called with an alternate authProtocol switch = SNMPSwitch(None, username=username, authKey=authKey, authProtocol=usmHMACMD5AuthProtocol, privKey=privKey, privProtocol=usmDESPrivProtocol) # and that UsmUserData was called w/ the correct args uud.assert_called_with(username, authKey, privKey, authProtocol=usmHMACMD5AuthProtocol, privProtocol=usmDESPrivProtocol) #@unittest.skip('foo') @mock.patch('vlanmang.SNMPSwitch.getuntagged') @mock.patch('vlanmang.SNMPSwitch.getegress') @mock.patch('vlanmang.SNMPSwitch.getpvid') @mock.patch('vlanmang.SNMPSwitch.getportmapping') @mock.patch('importlib.import_module') def test_checkchanges(self, imprt, portmapping, gpvid, gegress, guntagged): # that import returns the test data imprt.side_effect = itertools.repeat(self._test_data) # that getportmapping returns the following dict ports = { x: 'g%d' % x for x in xrange(1, 24) } ports[30] = 'lag1' ports[31] = 'lag2' ports[32] = 'lag3' portmapping.side_effect = itertools.repeat(ports) # that the switch's pvid returns spvid = { x: 283 for x in xrange(1, 24) } spvid[30] = 5 gpvid.side_effect = itertools.repeat(spvid) # the the extra port is caught self.assertRaises(ValueError, checkchanges, 'data') # that the functions were called imprt.assert_called_with('data') portmapping.assert_called() # XXX - check that an ignore statement is honored # delete the extra port del ports[32] # that the egress data provided gegress.side_effect = [ { 1: '1' * 10, 5: '1' * 10, 283: '00000000111111111110011000000100000', } ] # that the untagged data provided guntagged.side_effect = [ { 1: '1' * 10, 5: '1' * 8 + '0' * 10, 283: '00000000111111111110011', } ] res = checkchanges('data') # Make sure that the first one are all instances of SNMPSwitch # XXX make sure args for them are correct. self.assertTrue(all(isinstance(x[0], SNMPSwitch) for x in res)) # Make sure that the name provided is correct self.assertTrue(all(x[1] == 'distswitch' for x in res)) res = [ x[2:] for x in res ] validres = [ ('setpvid', x, 5, 283) for x in xrange(1, 9) ] + \ [ ('setpvid', 20, 1, 283), ('setpvid', 21, 1, 283), ('setpvid', 30, 1, 5), ('setegress', 1, '0' * 19 + '11' + '0' * 8 + '1', '1' * 10), ('setuntagged', 1, '0' * 19 + '11' + '0' * 8 + '1', '1' * 10), ('setegress', 5, '1' * 8 + '0' * 11 + '11' + '0' * 8 + '1', '1' * 10), ] self.assertEqual(set(res), set(validres)) class _TestSNMPSwitch(unittest.TestCase): def setUp(self): self.skipTest('foo') @mock.patch('vlanmang.SNMPSwitch._getmany') def test_get(self, gm): # that a switch switch = SNMPSwitch(None, community=None) # when _getmany returns this structure retval = object() gm.side_effect = [[[ None, retval ]]] arg = object() # will return the correct value self.assertIs(switch._get(arg), retval) # and call _getmany w/ the correct arg gm.assert_called_with(arg) @mock.patch('pysnmp.hlapi.ContextData') @mock.patch('vlanmang.getCmd') def test_getmany(self, gc, cd): # that a switch switch = SNMPSwitch(None, community=None) lookup = { x: chr(x) for x in xrange(1, 10) } # when getCmd returns tooBig when too many oids are asked for def custgetcmd(eng, cd, targ, contextdata, *oids): # induce a too big error if len(oids) > 3: res = ( None, 'tooBig', None, None ) else: #import pdb; pdb.set_trace() [ oid.resolveWithMib(_mvc) for oid in oids ] oids = [ ObjectType(x[0], OctetString(lookup[x[0][-1]])) for x in oids ] [ oid.resolveWithMib(_mvc) for oid in oids ] res = ( None, None, None, oids ) return iter([res]) gc.side_effect = custgetcmd #import pdb; pdb.set_trace() res = switch.getegress(*xrange(1, 10)) # will still return the complete set of results self.assertEqual(res, { x: _octstrtobits(lookup[x]) for x in xrange(1, 10) }) _skipSwitchTests = False class _TestSwitch(unittest.TestCase): def setUp(self): # If we don't have it, pretend it's true for now and # we'll recheck it later model = 'GS108T smartSwitch' if getattr(self, 'switchmodel', model) != model or \ _skipSwitchTests: # pragma: no cover self.skipTest('Need a GS108T switch to run these tests') host, authkey, privkey = open('test.creds').read().split() self.switch = SNMPSwitch(host, authKey=authkey, privKey=privkey, privProtocol=usmDESPrivProtocol) self.switchmodel = self.switch._get(('ENTITY-MIB', 'entPhysicalModelName', 1)) if self.switchmodel != model: # pragma: no cover self.skipTest('Need a GS108T switch to run these tests') def test_misc(self): switch = self.switch self.assertEqual(switch.findport('g1'), 1) self.assertEqual(switch.findport('l1'), 14) def test_portnames(self): switch = self.switch resp = dict((x, 'g%d' % x) for x in xrange(1, 9)) resp.update({ 13: 'cpu' }) resp.update((x, 'l%d' % (x - 13)) for x in xrange(14, 18)) self.assertEqual(switch.getportmapping(), resp) def test_egress(self): switch = self.switch egress = switch.getegress(1, 2, 3) checkegress = { 1: '1' * 8 + '0' * 5 + '1' * 4 + '0' * 23, 2: '0' * 8 * 5, 3: '0' * 8 * 5, } self.assertEqual(egress, checkegress) def test_untagged(self): switch = self.switch untagged = switch.getuntagged(1, 2, 3) checkuntagged = { 1: '1' * 8 * 5, 2: '1' * 8 * 5, 3: '1' * 8 * 5, } self.assertEqual(untagged, checkuntagged) def test_vlan(self): switch = self.switch existingvlans = set(switch.getvlans()) while True: testvlan = random.randint(1,4095) if testvlan not in existingvlans: break # Test that getting a non-existant vlans raises an exception self.assertRaises(ValueError, switch.getvlanname, testvlan) self.assertTrue(set(switch.staticvlans()).issubset(existingvlans)) pvidres = { x: 1 for x in xrange(1, 9) } pvidres.update({ x: 1 for x in xrange(14, 18) }) self.assertEqual(switch.getpvid(), pvidres) testname = 'Sometestname' # Create test vlan switch.createvlan(testvlan, testname) testport = None try: # make sure the test vlan was created self.assertIn(testvlan, set(switch.staticvlans())) self.assertEqual(testname, switch.getvlanname(testvlan)) switch.setegress(testvlan, '00100') pvidmap = switch.getpvid() testport = 3 egressports = switch.getegress(testvlan) self.assertEqual(egressports[testvlan], '00100000' + '0' * 8 * 4) switch.setuntagged(testvlan, '00100') untaggedports = switch.getuntagged(testvlan) self.assertEqual(untaggedports[testvlan], '00100000' + '0' * 8 * 4) switch.setpvid(testport, testvlan) self.assertEqual(switch.getpvid()[testport], testvlan) finally: if testport: switch.setpvid(testport, pvidmap[3]) switch.deletevlan(testvlan)