VLAN Manager tool
You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.
 
 

335 lines
9.0 KiB

  1. #!/usr/bin/env python
  2. # -*- coding: utf-8 -*-
  3. from pysnmp.hlapi import *
  4. from pysnmp.smi.builder import MibBuilder
  5. from pysnmp.smi.view import MibViewController
  6. import importlib
  7. import itertools
  8. import mock
  9. import random
  10. import unittest
  11. _mbuilder = MibBuilder()
  12. _mvc = MibViewController(_mbuilder)
  13. #import data
  14. # received packets
  15. # pvid: dot1qPvid
  16. #
  17. # tx packets:
  18. # dot1qVlanStaticEgressPorts
  19. # dot1qVlanStaticUntaggedPorts
  20. #
  21. # vlans:
  22. # dot1qVlanCurrentTable
  23. # lists ALL vlans, including baked in ones
  24. #
  25. # note that even though an snmpwalk of dot1qVlanStaticEgressPorts
  26. # skips over other vlans (only shows statics), the other vlans (1,2,3)
  27. # are still accessible via that oid
  28. #
  29. # LLDP:
  30. # 1.0.8802.1.1.2.1.4.1.1 aka LLDP-MIB, lldpRemTable
  31. class SwitchConfig(object):
  32. def __init__(self, host, community, vlanconf):
  33. self._host = host
  34. self._community = community
  35. self._vlanconf = vlanconf
  36. @property
  37. def host(self):
  38. return self._host
  39. @property
  40. def community(self):
  41. return self._community
  42. @property
  43. def vlanconf(self):
  44. return self._vlanconf
  45. def checkchanges(module):
  46. mod = importlib.import_module(module)
  47. mods = [ i for i in mod.__dict__.itervalues() if isinstance(i, SwitchConfig) ]
  48. for i in mods:
  49. switch = SNMPSwitch(i.host, i.community)
  50. portmapping = switch.getportmapping()
  51. invportmap = { y: x for x, y in portmapping.iteritems() }
  52. portlist = getportlist(i._vlanconf, invportmap.__getitem__)
  53. ports = set(portmapping.iterkeys())
  54. if ports != portlist:
  55. raise ValueError('missing or extra ports found: %s' % `ports.symmetric_difference(portlist)`)
  56. def getpvidmapping(data, lookupfun):
  57. '''Return a mapping from vlan based table to a port: vlan
  58. dictionary.'''
  59. res = []
  60. for id in data:
  61. for i in data[id]['u']:
  62. if isinstance(i, str):
  63. i = lookupfun(i)
  64. res.append((i, id))
  65. return dict(res)
  66. def getportlist(data, lookupfun):
  67. '''Return a set of all the ports indexes in data.'''
  68. res = set()
  69. for id in data:
  70. res.update(data[id]['u'])
  71. res.update(data[id].get('t', []))
  72. # filter out the strings
  73. strports = set(x for x in res if isinstance(x, str))
  74. res.update(lookupfun(x) for x in strports)
  75. res.difference_update(strports)
  76. return res
  77. class SNMPSwitch(object):
  78. '''A class for manipulating switches via standard SNMP MIBs.'''
  79. def __init__(self, host, community):
  80. self._eng = SnmpEngine()
  81. self._cd = CommunityData(community, mpModel=0)
  82. self._targ = UdpTransportTarget((host, 161))
  83. def _get(self, oid):
  84. oid = ObjectIdentity(*oid)
  85. oid.resolveWithMib(_mvc)
  86. errorInd, errorStatus, errorIndex, varBinds = \
  87. next(getCmd(self._eng, self._cd, self._targ, ContextData(), ObjectType(oid)))
  88. if errorInd: # pragma: no cover
  89. raise ValueError(errorIndication)
  90. elif errorStatus:
  91. raise ValueError('%s at %s' %
  92. (errorStatus.prettyPrint(), errorIndex and
  93. varBinds[int(errorIndex)-1][0] or '?'))
  94. else:
  95. if len(varBinds) != 1: # pragma: no cover
  96. raise ValueError('too many return values')
  97. varBind = varBinds[0]
  98. return varBind[1]
  99. def _set(self, oid, value):
  100. oid = ObjectIdentity(*oid)
  101. oid.resolveWithMib(_mvc)
  102. if isinstance(value, (int, long)):
  103. value = Integer(value)
  104. elif isinstance(value, str):
  105. value = OctetString(value)
  106. errorInd, errorStatus, errorIndex, varBinds = \
  107. next(setCmd(self._eng, self._cd, self._targ, ContextData(), ObjectType(oid, value)))
  108. if errorInd: # pragma: no cover
  109. raise ValueError(errorIndication)
  110. elif errorStatus: # pragma: no cover
  111. raise ValueError('%s at %s' %
  112. (errorStatus.prettyPrint(), errorIndex and
  113. varBinds[int(errorIndex)-1][0] or '?'))
  114. else:
  115. for varBind in varBinds:
  116. if varBind[1] != value: # pragma: no cover
  117. raise RuntimeError('failed to set: %s' % ' = '.join([x.prettyPrint() for x in varBind]))
  118. def _walk(self, *oid):
  119. oid = ObjectIdentity(*oid)
  120. # XXX - keep these, this might stop working, no clue what managed to magically make things work
  121. # ref: http://snmplabs.com/pysnmp/examples/smi/manager/browsing-mib-tree.html#mib-objects-to-pdu-var-binds
  122. # mibdump.py --mib-source '/Users/jmg/Nextcloud/Documents/user manuals/netgear/gs7xxt-v6.3.1.19-mibs' --mib-source /usr/share/snmp/mibs --rebuild rfc1212 pbridge vlan
  123. #oid.addAsn1MibSource('/usr/share/snmp/mibs', '/Users/jmg/Nextcloud/Documents/user manuals/netgear/gs7xxt-v6.3.1.19-mibs')
  124. oid.resolveWithMib(_mvc)
  125. for (errorInd, errorStatus, errorIndex, varBinds) in nextCmd(
  126. self._eng, self._cd, self._targ, ContextData(),
  127. ObjectType(oid),
  128. lexicographicMode=False):
  129. if errorInd: # pragma: no cover
  130. raise ValueError(errorIndication)
  131. elif errorStatus: # pragma: no cover
  132. raise ValueError('%s at %s' % (errorStatus.prettyPrint(), errorIndex and varBinds[int(errorIndex)-1][0] or '?'))
  133. else:
  134. for varBind in varBinds:
  135. yield varBind
  136. def getportmapping(self):
  137. '''Return a port name mapping. Keys are the port index
  138. and the value is the name from the ifName entry.'''
  139. return { x[0][-1]: str(x[1]) for x in self._walk('IF-MIB', 'ifName') }
  140. def findport(self, name):
  141. '''Look up a port name and return it's port index. This
  142. looks up via the ifName table in IF-MIB.'''
  143. return [ x[0][-1] for x in self._walk('IF-MIB', 'ifName') if str(x[1]) == name ][0]
  144. def getvlanname(self, vlan):
  145. '''Return the name for the vlan.'''
  146. v = self._get(('Q-BRIDGE-MIB', 'dot1qVlanStaticName', vlan))
  147. return str(v).decode('utf-8')
  148. def createvlan(self, vlan, name):
  149. # createAndGo(4)
  150. self._set(('Q-BRIDGE-MIB', 'dot1qVlanStaticRowStatus',
  151. int(vlan)), 4)
  152. self._set(('Q-BRIDGE-MIB', 'dot1qVlanStaticName', int(vlan)),
  153. name)
  154. def deletevlan(self, vlan):
  155. self._set(('Q-BRIDGE-MIB', 'dot1qVlanStaticRowStatus',
  156. int(vlan)), 6) # destroy(6)
  157. def getvlans(self):
  158. '''Return an iterator with all the vlan ids.'''
  159. return (x[0][-1] for x in self._walk('Q-BRIDGE-MIB', 'dot1qVlanStatus'))
  160. def staticvlans(self):
  161. '''Return an iterator of the staticly defined/configured
  162. vlans. This sometimes excludes special built in vlans,
  163. like vlan 1.'''
  164. return (x[0][-1] for x in self._walk('Q-BRIDGE-MIB', 'dot1qVlanStaticName'))
  165. def getpvid(self):
  166. '''Returns a dictionary w/ the interface index as the key,
  167. and the pvid of the interface.'''
  168. return { x[0][-1]: int(x[1]) for x in self._walk('Q-BRIDGE-MIB', 'dot1qPvid') }
  169. class _TestMisc(unittest.TestCase):
  170. def setUp(self):
  171. import test_data
  172. self._test_data = test_data
  173. def test_pvid(self):
  174. data = {
  175. 1: {
  176. 'u': [ 1, 5, 10 ] + range(13, 20)
  177. },
  178. 10: {
  179. 'u': [ 2, 3, 6, 7, 8, 'lag2' ],
  180. },
  181. 13: {
  182. 'u': [ 4, 9 ],
  183. },
  184. }
  185. lookup = {
  186. 'lag2': 30
  187. }
  188. check = dict(itertools.chain(enumerate([ 1, 10, 10, 13, 1, 10,
  189. 10, 10, 13, 1 ], 1), enumerate([ 1 ] * 7, 13),
  190. [ (30, 10) ]))
  191. # That a pvid mapping
  192. res = getpvidmapping(data, lookup.__getitem__)
  193. # is correct
  194. self.assertEqual(res, check)
  195. self.assertEqual(getportlist(data, lookup.__getitem__),
  196. set(xrange(1, 11)) | set(xrange(13, 20)) | set([30]))
  197. @mock.patch('vlanmang.SNMPSwitch.getportmapping')
  198. @mock.patch('importlib.import_module')
  199. def test_checkchanges(self, imprt, portmapping):
  200. def tmp(*args, **kwargs):
  201. return self._test_data
  202. imprt.side_effect = tmp
  203. ports = { x: 'g%d' % x for x in xrange(1, 24) }
  204. ports[30] = 'lag1'
  205. ports[31] = 'lag2'
  206. portmapping.side_effect = [ ports, ports ]
  207. self.assertRaises(ValueError, checkchanges, 'data')
  208. imprt.assert_called_with('data')
  209. portmapping.assert_called()
  210. del ports[31]
  211. checkchanges('data')
  212. _skipSwitchTests = False
  213. class _TestSwitch(unittest.TestCase):
  214. def setUp(self):
  215. args = open('test.creds').read().split()
  216. self.switch = SNMPSwitch(*args)
  217. switchmodel = self.switch._get(('ENTITY-MIB',
  218. 'entPhysicalModelName', 1))
  219. if switchmodel != 'GS108T smartSwitch' or \
  220. _skipSwitchTests: # pragma: no cover
  221. self.skipTest('Need a GS108T switch to run these tests')
  222. def test_misc(self):
  223. switch = self.switch
  224. self.assertEqual(switch.findport('g1'), 1)
  225. self.assertEqual(switch.findport('l1'), 14)
  226. def test_portnames(self):
  227. switch = self.switch
  228. resp = dict((x, 'g%d' % x) for x in xrange(1, 9))
  229. resp.update({ 13: 'cpu' })
  230. resp.update((x, 'l%d' % (x - 13)) for x in xrange(14, 18))
  231. self.assertEqual(switch.getportmapping(), resp)
  232. def test_vlan(self):
  233. switch = self.switch
  234. existingvlans = set(switch.getvlans())
  235. while True:
  236. testvlan = random.randint(1,4095)
  237. if testvlan not in existingvlans:
  238. break
  239. # Test that getting a non-existant vlans raises an exception
  240. self.assertRaises(ValueError, switch.getvlanname, testvlan)
  241. self.assertTrue(set(switch.staticvlans()).issubset(existingvlans))
  242. pvidres = { x: 1 for x in xrange(1, 9) }
  243. pvidres.update({ x: 1 for x in xrange(14, 18) })
  244. self.assertEqual(switch.getpvid(), pvidres)
  245. testname = 'Sometestname'
  246. # Create test vlan
  247. switch.createvlan(testvlan, testname)
  248. try:
  249. # make sure the test vlan was created
  250. self.assertIn(testvlan, set(switch.staticvlans()))
  251. self.assertEqual(testname, switch.getvlanname(testvlan))
  252. finally:
  253. switch.deletevlan(testvlan)