VLAN Manager tool
You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.
 
 

568 lines
15 KiB

  1. #!/usr/bin/env python
  2. # -*- coding: utf-8 -*-
  3. from pysnmp.hlapi import *
  4. from pysnmp.smi.builder import MibBuilder
  5. from pysnmp.smi.view import MibViewController
  6. import importlib
  7. import itertools
  8. import mock
  9. import random
  10. import unittest
# Shared MIB builder and view controller; every ObjectIdentity built in
# this module is resolved against _mvc via resolveWithMib(_mvc).
_mbuilder = MibBuilder()
_mvc = MibViewController(_mbuilder)
#import data
# Notes on the SNMP objects this tool uses:
#
# received packets (ingress):
#   pvid: dot1qPvid
#
# transmitted packets (egress):
#   dot1qVlanStaticEgressPorts
#   dot1qVlanStaticUntaggedPorts
#
# vlans:
#   dot1qVlanCurrentTable
#   lists ALL vlans, including built-in ones
#
# Note that even though an snmpwalk of dot1qVlanStaticEgressPorts
# skips over the other vlans (it only shows static ones), those other
# vlans (e.g. 1, 2, 3) are still accessible via that OID.
#
# LLDP:
#   1.0.8802.1.1.2.1.4.1.1 aka LLDP-MIB, lldpRemTable
  31. class SwitchConfig(object):
  32. def __init__(self, host, community, vlanconf, ignports):
  33. self._host = host
  34. self._community = community
  35. self._vlanconf = vlanconf
  36. self._ignports = ignports
  37. @property
  38. def host(self):
  39. return self._host
  40. @property
  41. def community(self):
  42. return self._community
  43. @property
  44. def vlanconf(self):
  45. return self._vlanconf
  46. @property
  47. def ignports(self):
  48. return self._ignports
  49. def getportlist(self, lookupfun):
  50. '''Return a set of all the ports indexes in data.'''
  51. res = set()
  52. for id in self._vlanconf:
  53. res.update(self._vlanconf[id].get('u', []))
  54. res.update(self._vlanconf[id].get('t', []))
  55. # add in the ignore ports
  56. res.update(self.ignports)
  57. # filter out the strings
  58. strports = set(x for x in res if isinstance(x, str))
  59. res.update(lookupfun(x) for x in strports)
  60. res.difference_update(strports)
  61. return res
  62. def _octstrtobits(os):
  63. num = 1
  64. for i in str(os):
  65. num = (num << 8) | ord(i)
  66. return bin(num)[3:]
  67. def _intstobits(*ints):
  68. v = 0
  69. for i in ints:
  70. v |= 1 << i
  71. r = list(bin(v)[2:-1])
  72. r.reverse()
  73. return ''.join(r)
  74. def _cmpbits(a, b):
  75. try:
  76. last1a = a.rindex('1')
  77. except ValueError:
  78. last1a = -1
  79. a = ''
  80. try:
  81. last1b = b.rindex('1')
  82. except ValueError:
  83. last1b = -1
  84. b = ''
  85. if last1a != -1:
  86. a = a[:last1a + 1]
  87. if last1b != -1:
  88. b = b[:last1b + 1]
  89. return a == b
  90. import vlanmang
  91. def checkchanges(module):
  92. mod = importlib.import_module(module)
  93. mods = [ i for i in mod.__dict__.itervalues() if isinstance(i, vlanmang.SwitchConfig) ]
  94. res = []
  95. for i in mods:
  96. vlans = i.vlanconf.keys()
  97. switch = SNMPSwitch(i.host, i.community)
  98. portmapping = switch.getportmapping()
  99. invportmap = { y: x for x, y in portmapping.iteritems() }
  100. lufun = invportmap.__getitem__
  101. # get complete set of ports
  102. portlist = i.getportlist(lufun)
  103. ports = set(portmapping.iterkeys())
  104. # make sure switch agrees w/ them all
  105. if ports != portlist:
  106. raise ValueError('missing or extra ports found: %s' %
  107. `ports.symmetric_difference(portlist)`)
  108. # compare pvid
  109. pvidmap = getpvidmapping(i.vlanconf, lufun)
  110. switchpvid = switch.getpvid()
  111. res.extend(('setpvid', idx, vlan, switchpvid[idx]) for idx, vlan in
  112. pvidmap.iteritems() if switchpvid[idx] != vlan)
  113. # compare egress & untagged
  114. switchegress = switch.getegress(*vlans)
  115. egress = getegress(i.vlanconf, lufun)
  116. switchuntagged = switch.getuntagged(*vlans)
  117. untagged = getuntagged(i.vlanconf, lufun)
  118. for i in vlans:
  119. if not _cmpbits(switchegress[i], egress[i]):
  120. res.append(('setegress', i, egress[i], switchegress[i]))
  121. if not _cmpbits(switchuntagged[i], untagged[i]):
  122. res.append(('setuntagged', i, untagged[i], switchuntagged[i]))
  123. return res
  124. def getidxs(lst, lookupfun):
  125. return [ lookupfun(i) if isinstance(i, str) else i for i in lst ]
  126. def getpvidmapping(data, lookupfun):
  127. '''Return a mapping from vlan based table to a port: vlan
  128. dictionary.'''
  129. res = []
  130. for id in data:
  131. for i in data[id].get('u', []):
  132. if isinstance(i, str):
  133. i = lookupfun(i)
  134. res.append((i, id))
  135. return dict(res)
  136. def getegress(data, lookupfun):
  137. r = {}
  138. for id in data:
  139. r[id] = _intstobits(*(getidxs(data[id].get('u', []),
  140. lookupfun) + getidxs(data[id].get('t', []), lookupfun)))
  141. return r
  142. def getuntagged(data, lookupfun):
  143. r = {}
  144. for id in data:
  145. r[id] = _intstobits(*getidxs(data[id].get('u', []), lookupfun))
  146. return r
  147. class SNMPSwitch(object):
  148. '''A class for manipulating switches via standard SNMP MIBs.'''
  149. def __init__(self, host, community):
  150. self._eng = SnmpEngine()
  151. self._cd = CommunityData(community, mpModel=0)
  152. self._targ = UdpTransportTarget((host, 161))
  153. def _getmany(self, *oids):
  154. oids = [ ObjectIdentity(*oid) for oid in oids ]
  155. [ oid.resolveWithMib(_mvc) for oid in oids ]
  156. errorInd, errorStatus, errorIndex, varBinds = \
  157. next(getCmd(self._eng, self._cd, self._targ, ContextData(), *(ObjectType(oid) for oid in oids)))
  158. if errorInd: # pragma: no cover
  159. raise ValueError(errorIndication)
  160. elif errorStatus: # pragma: no cover
  161. raise ValueError('%s at %s' %
  162. (errorStatus.prettyPrint(), errorIndex and
  163. varBinds[int(errorIndex)-1][0] or '?'))
  164. else:
  165. if len(varBinds) != len(oids): # pragma: no cover
  166. raise ValueError('too many return values')
  167. return varBinds
  168. def _get(self, oid):
  169. varBinds = self._getmany(oid)
  170. varBind = varBinds[0]
  171. return varBind[1]
  172. def _set(self, oid, value):
  173. oid = ObjectIdentity(*oid)
  174. oid.resolveWithMib(_mvc)
  175. if isinstance(value, (int, long)):
  176. value = Integer(value)
  177. elif isinstance(value, str):
  178. value = OctetString(value)
  179. errorInd, errorStatus, errorIndex, varBinds = \
  180. next(setCmd(self._eng, self._cd, self._targ, ContextData(), ObjectType(oid, value)))
  181. if errorInd: # pragma: no cover
  182. raise ValueError(errorIndication)
  183. elif errorStatus: # pragma: no cover
  184. raise ValueError('%s at %s' %
  185. (errorStatus.prettyPrint(), errorIndex and
  186. varBinds[int(errorIndex)-1][0] or '?'))
  187. else:
  188. for varBind in varBinds:
  189. if varBind[1] != value: # pragma: no cover
  190. raise RuntimeError('failed to set: %s' % ' = '.join([x.prettyPrint() for x in varBind]))
  191. def _walk(self, *oid):
  192. oid = ObjectIdentity(*oid)
  193. # XXX - keep these, this might stop working, no clue what managed to magically make things work
  194. # ref: http://snmplabs.com/pysnmp/examples/smi/manager/browsing-mib-tree.html#mib-objects-to-pdu-var-binds
  195. # mibdump.py --mib-source '/Users/jmg/Nextcloud/Documents/user manuals/netgear/gs7xxt-v6.3.1.19-mibs' --mib-source /usr/share/snmp/mibs --rebuild rfc1212 pbridge vlan
  196. #oid.addAsn1MibSource('/usr/share/snmp/mibs', '/Users/jmg/Nextcloud/Documents/user manuals/netgear/gs7xxt-v6.3.1.19-mibs')
  197. oid.resolveWithMib(_mvc)
  198. for (errorInd, errorStatus, errorIndex, varBinds) in nextCmd(
  199. self._eng, self._cd, self._targ, ContextData(),
  200. ObjectType(oid),
  201. lexicographicMode=False):
  202. if errorInd: # pragma: no cover
  203. raise ValueError(errorInd)
  204. elif errorStatus: # pragma: no cover
  205. raise ValueError('%s at %s' % (errorStatus.prettyPrint(), errorIndex and varBinds[int(errorIndex)-1][0] or '?'))
  206. else:
  207. for varBind in varBinds:
  208. yield varBind
  209. def getportmapping(self):
  210. '''Return a port name mapping. Keys are the port index
  211. and the value is the name from the ifName entry.'''
  212. return { x[0][-1]: str(x[1]) for x in self._walk('IF-MIB', 'ifName') }
  213. def findport(self, name):
  214. '''Look up a port name and return it's port index. This
  215. looks up via the ifName table in IF-MIB.'''
  216. return [ x[0][-1] for x in self._walk('IF-MIB', 'ifName') if str(x[1]) == name ][0]
  217. def getvlanname(self, vlan):
  218. '''Return the name for the vlan.'''
  219. v = self._get(('Q-BRIDGE-MIB', 'dot1qVlanStaticName', vlan))
  220. return str(v).decode('utf-8')
  221. def createvlan(self, vlan, name):
  222. # createAndGo(4)
  223. self._set(('Q-BRIDGE-MIB', 'dot1qVlanStaticRowStatus',
  224. int(vlan)), 4)
  225. self._set(('Q-BRIDGE-MIB', 'dot1qVlanStaticName', int(vlan)),
  226. name)
  227. def deletevlan(self, vlan):
  228. self._set(('Q-BRIDGE-MIB', 'dot1qVlanStaticRowStatus',
  229. int(vlan)), 6) # destroy(6)
  230. def getvlans(self):
  231. '''Return an iterator with all the vlan ids.'''
  232. return (x[0][-1] for x in self._walk('Q-BRIDGE-MIB', 'dot1qVlanStatus'))
  233. def staticvlans(self):
  234. '''Return an iterator of the staticly defined/configured
  235. vlans. This sometimes excludes special built in vlans,
  236. like vlan 1.'''
  237. return (x[0][-1] for x in self._walk('Q-BRIDGE-MIB', 'dot1qVlanStaticName'))
  238. def getpvid(self):
  239. '''Returns a dictionary w/ the interface index as the key,
  240. and the pvid of the interface.'''
  241. return { x[0][-1]: int(x[1]) for x in self._walk('Q-BRIDGE-MIB', 'dot1qPvid') }
  242. def getegress(self, *vlans):
  243. r = { x[-1]: _octstrtobits(y) for x, y in
  244. self._getmany(*(('Q-BRIDGE-MIB',
  245. 'dot1qVlanStaticEgressPorts', x) for x in vlans)) }
  246. return r
  247. def getuntagged(self, *vlans):
  248. r = { x[-1]: _octstrtobits(y) for x, y in
  249. self._getmany(*(('Q-BRIDGE-MIB',
  250. 'dot1qVlanStaticUntaggedPorts', x) for x in vlans)) }
  251. return r
  252. if __name__ == '__main__': # pragma: no cover
  253. print `checkchanges('data')`
class _TestMisc(unittest.TestCase):
    '''Offline tests of the bit helpers, the config-to-table helpers and
    checkchanges.  The SNMPSwitch query methods are mocked in
    test_checkchanges, so no switch is queried.'''

    def setUp(self):
        # test_data: a fixture module, presumably holding SwitchConfig
        # instances; it is handed out by the mocked import_module.
        import test_data
        self._test_data = test_data

    def test_intstobits(self):
        # port i sets character i-1; output ends at the highest port
        self.assertEqual(_intstobits(1, 5, 10), '1000100001')
        self.assertEqual(_intstobits(3, 4, 9), '001100001')

    def test_octstrtobits(self):
        # eight bits per octet, most significant bit first
        self.assertEqual(_octstrtobits('\x00'), '0' * 8)
        self.assertEqual(_octstrtobits('\xff'), '1' * 8)
        self.assertEqual(_octstrtobits('\xf0'), '1' * 4 + '0' * 4)
        self.assertEqual(_octstrtobits('\x0f'), '0' * 4 + '1' * 4)

    def test_cmpbits(self):
        # trailing zeros are insignificant, leading zeros are not
        self.assertTrue(_cmpbits('111000', '111'))
        self.assertTrue(_cmpbits('000111000', '000111'))
        self.assertTrue(_cmpbits('11', '11'))
        self.assertTrue(_cmpbits('0', '000'))
        self.assertFalse(_cmpbits('0011', '11'))
        self.assertFalse(_cmpbits('11', '0011'))
        self.assertFalse(_cmpbits('10', '000'))
        self.assertFalse(_cmpbits('0', '1000'))
        self.assertFalse(_cmpbits('00010', '000'))
        self.assertFalse(_cmpbits('0', '001000'))

    def test_pvidegressuntagged(self):
        # vlan id -> 'u' (untagged) / 't' (tagged) member ports; ports may
        # be indexes or names resolved through lufun below
        data = {
            1: {
                'u': [ 1, 5, 10 ] + range(13, 20),
                't': [ 'lag2', 6, 7 ],
            },
            10: {
                'u': [ 2, 3, 6, 7, 8, 'lag2' ],
            },
            13: {
                'u': [ 4, 9 ],
                't': [ 'lag2', 6, 7 ],
            },
            14: {
                't': [ 'lag2' ],
            },
        }
        swconf = SwitchConfig('', '', data, [ 'lag3' ])
        lookup = {
            'lag2': 30,
            'lag3': 31,
        }
        lufun = lookup.__getitem__
        # expected pvids: ports 1-10 listed in order, 13-19 on vlan 1,
        # and lag2 (30) on vlan 10
        check = dict(itertools.chain(enumerate([ 1, 10, 10, 13, 1, 10,
            10, 10, 13, 1 ], 1), enumerate([ 1 ] * 7, 13),
            [ (30, 10) ]))
        # that the pvid mapping
        res = getpvidmapping(data, lufun)
        # is correct
        self.assertEqual(res, check)
        # the port list includes the ignore port lag3 (31)
        self.assertEqual(swconf.getportlist(lufun),
            set(xrange(1, 11)) | set(xrange(13, 20)) | set(lookup.values()))
        # egress bits cover tagged and untagged members; lag2 is port 30
        checkegress = {
            1: '1000111001001111111' + '0' * (30 - 20) + '1',
            10: '01100111' + '0' * (30 - 9) + '1',
            13: '000101101' + '0' * (30 - 10) + '1',
            14: '0' * (30 - 1) + '1',
        }
        self.assertEqual(getegress(data, lufun), checkegress)
        # untagged bits cover only the 'u' members
        checkuntagged = {
            1: '1000100001001111111',
            10: '01100111' + '0' * (30 - 9) + '1',
            13: '000100001',
            14: '',
        }
        self.assertEqual(getuntagged(data, lufun), checkuntagged)

    #@unittest.skip('foo')
    # mock.patch decorators apply bottom-up, so the mocks arrive as
    # arguments in the reverse of the order listed here.
    @mock.patch('vlanmang.SNMPSwitch.getuntagged')
    @mock.patch('vlanmang.SNMPSwitch.getegress')
    @mock.patch('vlanmang.SNMPSwitch.getpvid')
    @mock.patch('vlanmang.SNMPSwitch.getportmapping')
    @mock.patch('importlib.import_module')
    def test_checkchanges(self, imprt, portmapping, gpvid, gegress, guntagged):
        # that import returns the test data
        imprt.side_effect = itertools.repeat(self._test_data)
        # that getportmapping returns the following dict
        ports = { x: 'g%d' % x for x in xrange(1, 24) }
        ports[30] = 'lag1'
        ports[31] = 'lag2'
        ports[32] = 'lag3'
        portmapping.side_effect = itertools.repeat(ports)
        # that the switch's pvid returns
        spvid = { x: 283 for x in xrange(1, 24) }
        spvid[30] = 5
        gpvid.side_effect = itertools.repeat(spvid)
        # that the extra port is caught (lag3/32 is presumably not
        # referenced by test_data -- confirm against the fixture)
        self.assertRaises(ValueError, checkchanges, 'data')
        # that the functions were called
        imprt.assert_called_with('data')
        portmapping.assert_called()
        # XXX - check that an ignore statement is honored
        # delete the extra port; itertools.repeat hands back this same
        # dict object, so later getportmapping calls see the deletion
        del ports[32]
        # that the egress data provided (a list side_effect allows only
        # one call)
        gegress.side_effect = [ {
            1: '1' * 10,
            5: '1' * 10,
            283: '00000000111111111110011000000100000',
        } ]
        # that the untagged data provided
        guntagged.side_effect = [ {
            1: '1' * 10,
            5: '1' * 8 + '0' * 10,
            283: '00000000111111111110011',
        } ]
        res = checkchanges('data')
        validres = [ ('setpvid', x, 5, 283) for x in xrange(1, 9) ] + \
            [ ('setpvid', 20, 1, 283),
            ('setpvid', 21, 1, 283),
            ('setpvid', 30, 1, 5),
            ('setegress', 1, '0' * 19 + '11' + '0' * 8 + '1', '1' * 10),
            ('setuntagged', 1, '0' * 19 + '11' + '0' * 8 + '1', '1' * 10),
            ('setegress', 5, '1' * 8 + '0' * 11 + '11' + '0' * 8 + '1', '1' * 10),
            ]
        # order of the changes is not significant
        self.assertEqual(set(res), set(validres))
  372. _skipSwitchTests = True
  373. class _TestSwitch(unittest.TestCase):
  374. def setUp(self):
  375. # If we don't have it, pretend it's true for now and
  376. # we'll recheck it later
  377. model = 'GS108T smartSwitch'
  378. if getattr(self, 'switchmodel', model) != model or \
  379. _skipSwitchTests: # pragma: no cover
  380. self.skipTest('Need a GS108T switch to run these tests')
  381. args = open('test.creds').read().split()
  382. self.switch = SNMPSwitch(*args)
  383. self.switchmodel = self.switch._get(('ENTITY-MIB',
  384. 'entPhysicalModelName', 1))
  385. if self.switchmodel != model: # pragma: no cover
  386. self.skipTest('Need a GS108T switch to run these tests')
  387. def test_misc(self):
  388. switch = self.switch
  389. self.assertEqual(switch.findport('g1'), 1)
  390. self.assertEqual(switch.findport('l1'), 14)
  391. def test_portnames(self):
  392. switch = self.switch
  393. resp = dict((x, 'g%d' % x) for x in xrange(1, 9))
  394. resp.update({ 13: 'cpu' })
  395. resp.update((x, 'l%d' % (x - 13)) for x in xrange(14, 18))
  396. self.assertEqual(switch.getportmapping(), resp)
  397. def test_egress(self):
  398. switch = self.switch
  399. egress = switch.getegress(1, 2, 3)
  400. checkegress = {
  401. 1: '1' * 8 + '0' * 5 + '1' * 4 + '0' * 23,
  402. 2: '0' * 8 * 5,
  403. 3: '0' * 8 * 5,
  404. }
  405. self.assertEqual(egress, checkegress)
  406. def test_untagged(self):
  407. switch = self.switch
  408. untagged = switch.getuntagged(1, 2, 3)
  409. checkuntagged = {
  410. 1: '1' * 8 * 5,
  411. 2: '1' * 8 * 5,
  412. 3: '1' * 8 * 5,
  413. }
  414. self.assertEqual(untagged, checkuntagged)
  415. def test_vlan(self):
  416. switch = self.switch
  417. existingvlans = set(switch.getvlans())
  418. while True:
  419. testvlan = random.randint(1,4095)
  420. if testvlan not in existingvlans:
  421. break
  422. # Test that getting a non-existant vlans raises an exception
  423. self.assertRaises(ValueError, switch.getvlanname, testvlan)
  424. self.assertTrue(set(switch.staticvlans()).issubset(existingvlans))
  425. pvidres = { x: 1 for x in xrange(1, 9) }
  426. pvidres.update({ x: 1 for x in xrange(14, 18) })
  427. self.assertEqual(switch.getpvid(), pvidres)
  428. testname = 'Sometestname'
  429. # Create test vlan
  430. switch.createvlan(testvlan, testname)
  431. try:
  432. # make sure the test vlan was created
  433. self.assertIn(testvlan, set(switch.staticvlans()))
  434. self.assertEqual(testname, switch.getvlanname(testvlan))
  435. finally:
  436. switch.deletevlan(testvlan)