VLAN Manager tool
You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.
 
 

246 lines
6.6 KiB

  1. #!/usr/bin/env python
  2. # -*- coding: utf-8 -*-
  3. from pysnmp.hlapi import *
  4. from pysnmp.smi.builder import MibBuilder
  5. from pysnmp.smi.view import MibViewController
  6. import itertools
  7. import random
  8. import unittest
# Shared MIB builder and view controller.  SNMPSwitch passes _mvc to
# ObjectIdentity.resolveWithMib() so that symbolic names such as
# ('Q-BRIDGE-MIB', 'dot1qVlanStaticName', vlan) resolve to numeric OIDs.
_mbuilder = MibBuilder()
_mvc = MibViewController(_mbuilder)
#import data

# Reference notes: which MIB objects hold which piece of VLAN state.
#
# Received (rx) packets:
#   pvid (port VLAN id): dot1qPvid
#
# Transmitted (tx) packets:
#   dot1qVlanStaticEgressPorts
#   dot1qVlanStaticUntaggedPorts
#
# VLANs:
#   dot1qVlanCurrentTable lists ALL vlans, including built-in ones.
#
# Note: even though an snmpwalk of dot1qVlanStaticEgressPorts skips the
# other vlans (it only shows the static ones), those other vlans (1, 2, 3)
# are still accessible via that OID.
#
# LLDP:
#   1.0.8802.1.1.2.1.4.1.1, aka LLDP-MIB lldpRemTable
  29. def getpvidmapping(data, lookupfun):
  30. '''Return a mapping from vlan based table to a port: vlan
  31. dictionary.'''
  32. res = []
  33. for id in data:
  34. for i in data[id]['u']:
  35. if isinstance(i, str):
  36. i = lookupfun(i)
  37. res.append((i, id))
  38. return dict(res)
  39. def getportlist(data, lookupfun):
  40. '''Return a set of all the ports indexes in data.'''
  41. res = set()
  42. for id in data:
  43. res.update(data[id]['u'])
  44. res.update(data[id].get('t', []))
  45. # filter out the strings
  46. strports = set(x for x in res if isinstance(x, str))
  47. res.update(lookupfun(x) for x in strports)
  48. res.difference_update(strports)
  49. return res
  50. class SNMPSwitch(object):
  51. '''A class for manipulating switches via standard SNMP MIBs.'''
  52. def __init__(self, host, community):
  53. self._eng = SnmpEngine()
  54. self._cd = CommunityData(community, mpModel=0)
  55. self._targ = UdpTransportTarget((host, 161))
  56. def _get(self, oid):
  57. oid = ObjectIdentity(*oid)
  58. oid.resolveWithMib(_mvc)
  59. errorInd, errorStatus, errorIndex, varBinds = \
  60. next(getCmd(self._eng, self._cd, self._targ, ContextData(), ObjectType(oid)))
  61. if errorInd: # pragma: no cover
  62. raise ValueError(errorIndication)
  63. elif errorStatus:
  64. raise ValueError('%s at %s' %
  65. (errorStatus.prettyPrint(), errorIndex and
  66. varBinds[int(errorIndex)-1][0] or '?'))
  67. else:
  68. if len(varBinds) != 1: # pragma: no cover
  69. raise ValueError('too many return values')
  70. varBind = varBinds[0]
  71. return varBind[1]
  72. def _set(self, oid, value):
  73. oid = ObjectIdentity(*oid)
  74. oid.resolveWithMib(_mvc)
  75. if isinstance(value, (int, long)):
  76. value = Integer(value)
  77. elif isinstance(value, str):
  78. value = OctetString(value)
  79. errorInd, errorStatus, errorIndex, varBinds = \
  80. next(setCmd(self._eng, self._cd, self._targ, ContextData(), ObjectType(oid, value)))
  81. if errorInd: # pragma: no cover
  82. raise ValueError(errorIndication)
  83. elif errorStatus: # pragma: no cover
  84. raise ValueError('%s at %s' %
  85. (errorStatus.prettyPrint(), errorIndex and
  86. varBinds[int(errorIndex)-1][0] or '?'))
  87. else:
  88. for varBind in varBinds:
  89. if varBind[1] != value: # pragma: no cover
  90. raise RuntimeError('failed to set: %s' % ' = '.join([x.prettyPrint() for x in varBind]))
  91. def _walk(self, *oid):
  92. oid = ObjectIdentity(*oid)
  93. # XXX - keep these, this might stop working, no clue what managed to magically make things work
  94. # ref: http://snmplabs.com/pysnmp/examples/smi/manager/browsing-mib-tree.html#mib-objects-to-pdu-var-binds
  95. # mibdump.py --mib-source '/Users/jmg/Nextcloud/Documents/user manuals/netgear/gs7xxt-v6.3.1.19-mibs' --mib-source /usr/share/snmp/mibs --rebuild rfc1212 pbridge vlan
  96. #oid.addAsn1MibSource('/usr/share/snmp/mibs', '/Users/jmg/Nextcloud/Documents/user manuals/netgear/gs7xxt-v6.3.1.19-mibs')
  97. oid.resolveWithMib(_mvc)
  98. for (errorInd, errorStatus, errorIndex, varBinds) in nextCmd(
  99. self._eng, self._cd, self._targ, ContextData(),
  100. ObjectType(oid),
  101. lexicographicMode=False):
  102. if errorInd: # pragma: no cover
  103. raise ValueError(errorIndication)
  104. elif errorStatus: # pragma: no cover
  105. raise ValueError('%s at %s' % (errorStatus.prettyPrint(), errorIndex and varBinds[int(errorIndex)-1][0] or '?'))
  106. else:
  107. for varBind in varBinds:
  108. yield varBind
  109. def findport(self, name):
  110. '''Look up a port name and return it's port index. This
  111. looks up via the ifName table in IF-MIB.'''
  112. return [ x[0][-1] for x in self._walk('IF-MIB', 'ifName') if str(x[1]) == name ][0]
  113. def getvlanname(self, vlan):
  114. '''Return the name for the vlan.'''
  115. v = self._get(('Q-BRIDGE-MIB', 'dot1qVlanStaticName', vlan))
  116. return str(v).decode('utf-8')
  117. def createvlan(self, vlan, name):
  118. # createAndGo(4)
  119. self._set(('Q-BRIDGE-MIB', 'dot1qVlanStaticRowStatus',
  120. int(vlan)), 4)
  121. self._set(('Q-BRIDGE-MIB', 'dot1qVlanStaticName', int(vlan)),
  122. name)
  123. def deletevlan(self, vlan):
  124. self._set(('Q-BRIDGE-MIB', 'dot1qVlanStaticRowStatus',
  125. int(vlan)), 6) # destroy(6)
  126. def getvlans(self):
  127. '''Return all the vlans.'''
  128. return (x[0][-1] for x in self._walk('Q-BRIDGE-MIB', 'dot1qVlanStatus'))
  129. def staticvlans(self):
  130. '''Return the staticly defined/configured vlans. This
  131. sometimes excludes special built in vlans, like vlan 1.'''
  132. return (x[0][-1] for x in self._walk('Q-BRIDGE-MIB', 'dot1qVlanStaticName'))
  133. class _TestMisc(unittest.TestCase):
  134. def test_pvid(self):
  135. data = {
  136. 1: {
  137. 'u': [ 1, 5, 10 ] + range(13, 20)
  138. },
  139. 10: {
  140. 'u': [ 2, 3, 6, 7, 8, 'lag2' ],
  141. },
  142. 13: {
  143. 'u': [ 4, 9 ],
  144. },
  145. }
  146. lookup = {
  147. 'lag2': 30
  148. }
  149. check = dict(itertools.chain(enumerate([ 1, 10, 10, 13, 1, 10,
  150. 10, 10, 13, 1 ], 1), enumerate([ 1 ] * 7, 13),
  151. [ (30, 10) ]))
  152. # That a pvid mapping
  153. res = getpvidmapping(data, lookup.__getitem__)
  154. # is correct
  155. self.assertEqual(res, check)
  156. self.assertEqual(getportlist(data, lookup.__getitem__),
  157. set(xrange(1, 11)) | set(xrange(13, 20)) | set([30]))
  158. _skipSwitchTests = True
  159. class _TestSwitch(unittest.TestCase):
  160. def setUp(self):
  161. args = open('test.creds').read().split()
  162. self.switch = SNMPSwitch(*args)
  163. def test_unpacktable(self):
  164. pass
  165. @unittest.skipIf(_skipSwitchTests, 'slow')
  166. def test_misc(self):
  167. switch = self.switch
  168. self.assertEqual(switch.findport('g1'), 1)
  169. self.assertEqual(switch.findport('l1'), 14)
  170. @unittest.skipIf(_skipSwitchTests, 'slow')
  171. def test_vlan(self):
  172. switch = self.switch
  173. existingvlans = set(switch.getvlans())
  174. while True:
  175. testvlan = random.randint(1,4095)
  176. if testvlan not in existingvlans:
  177. break
  178. # Test that getting a non-existant vlans raises an exception
  179. self.assertRaises(ValueError, switch.getvlanname, testvlan)
  180. self.assertTrue(set(switch.staticvlans()).issubset(existingvlans))
  181. testname = 'Sometestname'
  182. # Create test vlan
  183. switch.createvlan(testvlan, testname)
  184. try:
  185. # make sure the test vlan was created
  186. self.assertIn(testvlan, set(switch.staticvlans()))
  187. self.assertEqual(testname, switch.getvlanname(testvlan))
  188. finally:
  189. switch.deletevlan(testvlan)