|
- #!/usr/bin/env python
- # -*- coding: utf-8 -*-
-
- from pysnmp.hlapi import *
- from pysnmp.smi.builder import MibBuilder
- from pysnmp.smi.view import MibViewController
-
- import importlib
- import itertools
- import mock
- import random
- import unittest
-
- _mbuilder = MibBuilder()  # module-level MIB builder backing the view controller below
- _mvc = MibViewController(_mbuilder)  # handed to ObjectIdentity.resolveWithMib() by SNMPSwitch
-
- #import data
-
- # received packets
- # pvid: dot1qPvid
- #
- # tx packets:
- # dot1qVlanStaticEgressPorts
- # dot1qVlanStaticUntaggedPorts
- #
- # vlans:
- # dot1qVlanCurrentTable
- # lists ALL vlans, including baked in ones
- #
- # note that even though an snmpwalk of dot1qVlanStaticEgressPorts
- # skips over other vlans (only shows statics), the other vlans (1,2,3)
- # are still accessible via that oid
- #
- # LLDP:
- # 1.0.8802.1.1.2.1.4.1.1 aka LLDP-MIB, lldpRemTable
-
class SwitchConfig(object):
    '''Read-only description of one switch to check.

    host and community are what checkchanges() hands to SNMPSwitch.
    vlanconf is the vlan table that getportlist() walks: a mapping
    of vlan id to a dict with a 'u' list of ports and, optionally,
    a 't' list of ports.
    '''

    def __init__(self, host, community, vlanconf):
        self._host, self._community, self._vlanconf = \
            host, community, vlanconf

    # No setters are defined, so the three values can only be
    # supplied through the constructor.
    host = property(lambda self: self._host,
        doc='The host given to the constructor.')
    community = property(lambda self: self._community,
        doc='The SNMP community given to the constructor.')
    vlanconf = property(lambda self: self._vlanconf,
        doc='The vlan configuration given to the constructor.')
-
def checkchanges(module):
    '''Check each switch configured in module against the ports the
    switch itself reports.

    module is the name of a module to import.  Every module-level
    SwitchConfig instance found in it is checked in turn: the switch
    is asked for its port index to port name mapping, and the set of
    port indexes it reports must equal the set of ports used by the
    vlan configuration.  Port names in the configuration are turned
    into indexes through the inverse of the reported mapping.

    Returns None when every switch matches.  Raises ValueError,
    listing the differing ports, on the first switch that has
    missing or extra ports, and KeyError when the configuration
    names a port the switch does not report.
    '''

    mod = importlib.import_module(module)
    configs = [ v for v in vars(mod).values()
        if isinstance(v, SwitchConfig) ]

    for conf in configs:
        switch = SNMPSwitch(conf.host, conf.community)
        portmapping = switch.getportmapping()

        # name -> index, so string ports can be resolved
        invportmap = { name: idx for idx, name in portmapping.items() }

        portlist = getportlist(conf.vlanconf, invportmap.__getitem__)

        ports = set(portmapping)

        if ports != portlist:
            raise ValueError('missing or extra ports found: %s' %
                repr(ports.symmetric_difference(portlist)))
-
def getpvidmapping(data, lookupfun):
    '''Turn a vlan keyed table into a port: vlan dictionary.

    data maps each vlan id to a dict whose 'u' entry lists the ports
    of that vlan.  A port given as a str is translated by calling
    lookupfun on it; any other port is used as the key unchanged.
    If a port is listed under more than one vlan, the vlan iterated
    last is the one that is kept.
    '''

    pvids = {}

    for vlan in data:
        for port in data[vlan]['u']:
            key = lookupfun(port) if isinstance(port, str) else port
            pvids[key] = vlan

    return pvids
-
def getportlist(data, lookupfun):
    '''Collect the set of every port index used in data.

    data maps each vlan id to a dict with a 'u' list of ports and an
    optional 't' list of ports.  Ports given as str are replaced by
    the result of calling lookupfun on them, once per distinct name.
    Every such name is dropped from the result, even if lookupfun
    happens to return that same str for another name.
    '''

    seen = set()

    for vlan in data:
        conf = data[vlan]
        seen.update(conf['u'], conf.get('t', []))

    # split off the names, resolve them, and swap them for the
    # resolved values
    names = { port for port in seen if isinstance(port, str) }
    resolved = { lookupfun(name) for name in names }

    return (seen | resolved) - names
-
class SNMPSwitch(object):
    '''A class for manipulating switches via standard SNMP MIBs.'''

    def __init__(self, host, community):
        '''Set up the pysnmp engine, the community data (mpModel=0)
        and a UDP transport target for host on port 161.'''

        self._eng = SnmpEngine()
        self._cd = CommunityData(community, mpModel=0)
        self._targ = UdpTransportTarget((host, 161))

    @staticmethod
    def _checkerror(errorInd, errorStatus, errorIndex, varBinds):
        '''Raise ValueError if a pysnmp response reports a failure.

        The arguments are the four items pysnmp yields for each
        response.  A true errorInd is raised as is.  A true
        errorStatus is raised as "<status> at <oid>", where the oid
        is taken from the var-bind errorIndex points at (1 based),
        or '?' when errorIndex is zero/empty.  Returns None when
        neither is set.'''

        if errorInd: # pragma: no cover
            raise ValueError(errorInd)

        if errorStatus:
            where = '?'
            if errorIndex:
                where = varBinds[int(errorIndex) - 1][0]

            raise ValueError('%s at %s' % (errorStatus.prettyPrint(),
                where))

    def _get(self, oid):
        '''Fetch a single value from the switch.

        oid is a sequence of arguments for ObjectIdentity, e.g.
        (mib name, object name, index).  Returns the value part of
        the one var-bind in the reply.  Raises ValueError on an SNMP
        error or when the reply does not hold exactly one
        var-bind.'''

        oid = ObjectIdentity(*oid)
        oid.resolveWithMib(_mvc)

        errorInd, errorStatus, errorIndex, varBinds = next(getCmd(
            self._eng, self._cd, self._targ, ContextData(),
            ObjectType(oid)))

        self._checkerror(errorInd, errorStatus, errorIndex, varBinds)

        if len(varBinds) != 1: # pragma: no cover
            raise ValueError('too many return values')

        return varBinds[0][1]

    def _set(self, oid, value):
        '''Set oid to value on the switch.

        oid is a sequence of arguments for ObjectIdentity.  An
        integer value is wrapped in Integer and a str in
        OctetString; anything else is passed on unchanged.  Raises
        ValueError on an SNMP error and RuntimeError when a var-bind
        in the reply does not carry the value that was sent.'''

        # local import: covers int (and long on Python 2) without
        # naming long, which does not exist on Python 3
        import numbers

        oid = ObjectIdentity(*oid)
        oid.resolveWithMib(_mvc)

        if isinstance(value, numbers.Integral):
            value = Integer(value)
        elif isinstance(value, str):
            value = OctetString(value)

        errorInd, errorStatus, errorIndex, varBinds = next(setCmd(
            self._eng, self._cd, self._targ, ContextData(),
            ObjectType(oid, value)))

        self._checkerror(errorInd, errorStatus, errorIndex, varBinds)

        for varBind in varBinds:
            if varBind[1] != value: # pragma: no cover
                raise RuntimeError('failed to set: %s' %
                    ' = '.join([x.prettyPrint() for x in varBind]))

    def _walk(self, *oid):
        '''Generator yielding each var-bind found below oid.

        The arguments are passed to ObjectIdentity.
        lexicographicMode=False is given to nextCmd, which the
        pysnmp documentation describes as stopping the walk at the
        end of the requested subtree.  Raises ValueError on an SNMP
        error.'''

        oid = ObjectIdentity(*oid)
        # XXX - keep these, this might stop working, no clue what managed to magically make things work
        # ref: http://snmplabs.com/pysnmp/examples/smi/manager/browsing-mib-tree.html#mib-objects-to-pdu-var-binds
        # mibdump.py --mib-source '/Users/jmg/Nextcloud/Documents/user manuals/netgear/gs7xxt-v6.3.1.19-mibs' --mib-source /usr/share/snmp/mibs --rebuild rfc1212 pbridge vlan
        #oid.addAsn1MibSource('/usr/share/snmp/mibs', '/Users/jmg/Nextcloud/Documents/user manuals/netgear/gs7xxt-v6.3.1.19-mibs')

        oid.resolveWithMib(_mvc)

        for errorInd, errorStatus, errorIndex, varBinds in nextCmd(
            self._eng, self._cd, self._targ, ContextData(),
            ObjectType(oid),
            lexicographicMode=False):
            self._checkerror(errorInd, errorStatus, errorIndex,
                varBinds)

            for varBind in varBinds:
                yield varBind

    def getportmapping(self):
        '''Return a port name mapping.  Keys are the port index
        and the value is the name from the ifName entry.'''

        return { x[0][-1]: str(x[1])
            for x in self._walk('IF-MIB', 'ifName') }

    def findport(self, name):
        '''Look up a port name and return its port index.  This
        looks up via the ifName table in IF-MIB.

        Raises IndexError if no port has that name.'''

        for oid, ifname in self._walk('IF-MIB', 'ifName'):
            if str(ifname) == name:
                return oid[-1]

        raise IndexError('no port named %r' % (name,))

    def getvlanname(self, vlan):
        '''Return the name for the vlan, decoded as UTF-8.

        Raises ValueError if the switch reports an error for that
        vlan.'''

        v = self._get(('Q-BRIDGE-MIB', 'dot1qVlanStaticName', vlan))

        # asOctets() gives the raw bytes on both Python 2 and 3
        return v.asOctets().decode('utf-8')

    def createvlan(self, vlan, name):
        '''Create vlan on the switch and then set its name.'''

        # createAndGo(4)
        self._set(('Q-BRIDGE-MIB', 'dot1qVlanStaticRowStatus',
            int(vlan)), 4)
        self._set(('Q-BRIDGE-MIB', 'dot1qVlanStaticName', int(vlan)),
            name)

    def deletevlan(self, vlan):
        '''Remove vlan from the switch.'''

        self._set(('Q-BRIDGE-MIB', 'dot1qVlanStaticRowStatus',
            int(vlan)), 6) # destroy(6)

    def getvlans(self):
        '''Return an iterator with all the vlan ids.'''

        return (x[0][-1]
            for x in self._walk('Q-BRIDGE-MIB', 'dot1qVlanStatus'))

    def staticvlans(self):
        '''Return an iterator of the statically defined/configured
        vlans.  This sometimes excludes special built in vlans,
        like vlan 1.'''

        return (x[0][-1]
            for x in self._walk('Q-BRIDGE-MIB', 'dot1qVlanStaticName'))

    def getpvid(self):
        '''Returns a dictionary w/ the interface index as the key,
        and the pvid of the interface.'''

        return { x[0][-1]: int(x[1])
            for x in self._walk('Q-BRIDGE-MIB', 'dot1qPvid') }
-
class _TestMisc(unittest.TestCase):
    '''Tests for the helper functions.  The SNMP access in
    checkchanges() is mocked out, so no switch is contacted.'''

    def setUp(self):
        # Imported here rather than at module level.  This is the
        # module test_checkchanges() makes import_module return;
        # presumably it defines the SwitchConfig objects under test.
        import test_data

        self._test_data = test_data

    def test_pvid(self):
        '''getpvidmapping() and getportlist() on a small table that
        mixes port indexes with one port name.'''

        data = {
            1: {
                'u': [ 1, 5, 10 ] + list(range(13, 20))
            },
            10: {
                'u': [ 2, 3, 6, 7, 8, 'lag2' ],
            },
            13: {
                'u': [ 4, 9 ],
            },
        }
        lookup = {
            'lag2': 30
        }

        # ports 1-10 with their vlan, ports 13-19 on vlan 1, and
        # lag2 (index 30) on vlan 10
        check = dict(itertools.chain(enumerate([ 1, 10, 10, 13, 1, 10,
            10, 10, 13, 1 ], 1), enumerate([ 1 ] * 7, 13),
            [ (30, 10) ]))

        # That a pvid mapping
        res = getpvidmapping(data, lookup.__getitem__)

        # is correct
        self.assertEqual(res, check)

        self.assertEqual(getportlist(data, lookup.__getitem__),
            set(range(1, 11)) | set(range(13, 20)) | { 30 })

    @mock.patch('vlanmang.SNMPSwitch.getportmapping')
    @mock.patch('importlib.import_module')
    def test_checkchanges(self, imprt, portmapping):
        '''checkchanges() raises while the switch reports a port the
        configuration does not use, and passes once it is gone.'''

        def tmp(*args, **kwargs):
            return self._test_data
        imprt.side_effect = tmp

        ports = { x: 'g%d' % x for x in range(1, 24) }
        ports[30] = 'lag1'
        ports[31] = 'lag2'

        # The same dict object is handed out for both calls, so the
        # del below changes what the second call sees.
        portmapping.side_effect = [ ports, ports ]

        self.assertRaises(ValueError, checkchanges, 'data')

        imprt.assert_called_with('data')
        portmapping.assert_called()

        del ports[31]

        checkchanges('data')
-
# Set to True to skip the tests below that need a real switch.
_skipSwitchTests = False

class _TestSwitch(unittest.TestCase):
    '''Tests that talk to a real GS108T switch.

    The file test.creds in the current directory supplies the
    whitespace separated arguments for SNMPSwitch (host, then
    community).'''

    def setUp(self):
        # Check the flag first so that skipping needs neither the
        # creds file nor a reachable switch.
        if _skipSwitchTests: # pragma: no cover
            self.skipTest('Need a GS108T switch to run these tests')

        with open('test.creds') as fp:
            args = fp.read().split()
        self.switch = SNMPSwitch(*args)

        switchmodel = self.switch._get(('ENTITY-MIB',
            'entPhysicalModelName', 1))
        if switchmodel != 'GS108T smartSwitch': # pragma: no cover
            self.skipTest('Need a GS108T switch to run these tests')

    def test_misc(self):
        '''findport() resolves names to the expected indexes.'''

        switch = self.switch

        self.assertEqual(switch.findport('g1'), 1)
        self.assertEqual(switch.findport('l1'), 14)

    def test_portnames(self):
        '''getportmapping() returns the expected index: name table.'''

        switch = self.switch

        resp = dict((x, 'g%d' % x) for x in range(1, 9))
        resp.update({ 13: 'cpu' })
        resp.update((x, 'l%d' % (x - 13)) for x in range(14, 18))

        self.assertEqual(switch.getportmapping(), resp)

    def test_vlan(self):
        '''Query, create and delete a vlan that is not yet in use.'''

        switch = self.switch

        existingvlans = set(switch.getvlans())

        # Pick an unused vlan id.  4095 is reserved; the VlanId type
        # in Q-BRIDGE-MIB only allows 1..4094.
        while True:
            testvlan = random.randint(1, 4094)
            if testvlan not in existingvlans:
                break

        # Test that getting a non-existent vlan raises an exception
        self.assertRaises(ValueError, switch.getvlanname, testvlan)

        self.assertTrue(set(switch.staticvlans()).issubset(existingvlans))

        pvidres = { x: 1 for x in range(1, 9) }
        pvidres.update({ x: 1 for x in range(14, 18) })
        self.assertEqual(switch.getpvid(), pvidres)

        testname = 'Sometestname'

        # Create test vlan
        switch.createvlan(testvlan, testname)
        try:
            # make sure the test vlan was created
            self.assertIn(testvlan, set(switch.staticvlans()))

            self.assertEqual(testname, switch.getvlanname(testvlan))
        finally:
            # always clean up, even if an assertion failed
            switch.deletevlan(testvlan)
|