VLAN Manager tool
You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
 
 

977 lines
28 KiB

  1. #!/usr/bin/env python
  2. # -*- coding: utf-8 -*-
  3. from pysnmp.hlapi import *
  4. from pysnmp.smi.builder import MibBuilder
  5. from pysnmp.smi.view import MibViewController
  6. import importlib
  7. import itertools
  8. import mock
  9. import random
  10. import unittest
# Module-wide MIB builder and view controller. Every ObjectIdentity built
# by SNMPSwitch is resolved against _mvc (resolveWithMib) so symbolic
# names such as ('Q-BRIDGE-MIB', 'dot1qPvid', port) can be used.
_mbuilder = MibBuilder()
_mvc = MibViewController(_mbuilder)
#import data
# Notes on the Q-BRIDGE-MIB objects used by this tool.
#
# received packets:
# pvid: dot1qPvid
#
# transmitted packets:
# dot1qVlanStaticEgressPorts
# dot1qVlanStaticUntaggedPorts
#
# vlans:
# dot1qVlanCurrentTable
# lists ALL vlans, including built-in ones
#
# note that even though an snmpwalk of dot1qVlanStaticEgressPorts
# skips over other vlans (only shows statics), the other vlans (1,2,3)
# are still accessible via that oid
#
# LLDP:
# 1.0.8802.1.1.2.1.4.1.1 aka LLDP-MIB, lldpRemTable
  31. class SwitchConfig(object):
  32. '''This is a simple object to store switch configuration for
  33. the checkchanges() function.
  34. host -- The host of the switch you are maintaining configuration of.
  35. authargs -- This is a dictionary of kwargs to pass to SNMPSwitch.
  36. If SNMPv1 (insecure) is used, pass dict(community='communitystr').
  37. Example for SNMPv3 where the key for both authentication and
  38. encryption are the same:
  39. dict(username='username', authKey=key, privKey=key)
  40. vlanconf -- This is a dictionary w/ vlans as the key. Each value has
  41. a dictionary that contains keys, 'u' or 't', each of which
  42. contains the port that traffic should be sent untagged ('u') or
  43. tagged ('t'). Note that the Pvid (vlan of traffic that is
  44. received when untagged), is set to match the 'u' definition. The
  45. port is either an integer, which maps directly to the switch's
  46. index number, or it can be a string, which will be looked up via
  47. the IF-MIB::ifName table.
  48. Example specifies that VLANs 1 and 2 will be transmitted as tagged
  49. packets on the port named 'lag1'. That ports 1, 2, 3, 4 and 5 will
  50. be untagged on VLAN 1, and ports 6, 7, 8 and 9 will be untagged on
  51. VLAN 2:
  52. { 1: {
  53. 'u': [ 1, 2, 3, 4, 5 ],
  54. 't': [ 'lag1' ],
  55. },
  56. 2: {
  57. 'u': [ 6, 7, 8, 9 ],
  58. 't': [ 'lag1' ],
  59. },
  60. }
  61. ignports -- Ports that will be ignored and not required to be
  62. configured. List any ports that will not be active here, such as
  63. any unused lag ports.
  64. '''
  65. def __init__(self, host, authargs, vlanconf, ignports):
  66. self._host = host
  67. self._authargs = authargs
  68. self._vlanconf = vlanconf
  69. self._ignports = ignports
  70. @property
  71. def host(self):
  72. return self._host
  73. @property
  74. def authargs(self):
  75. return self._authargs
  76. @property
  77. def vlanconf(self):
  78. return self._vlanconf
  79. @property
  80. def ignports(self):
  81. return self._ignports
  82. def getportlist(self, lookupfun):
  83. '''Return a set of all the ports indexes in data. This
  84. includes, both vlanconf and ignports. Any ports using names
  85. will be resolved by being passed to the provided lookupfun.'''
  86. res = []
  87. for id in self._vlanconf:
  88. res.extend(self._vlanconf[id].get('u', []))
  89. res.extend(self._vlanconf[id].get('t', []))
  90. # add in the ignore ports
  91. res.extend(self.ignports)
  92. # eliminate dups so that lookupfun isn't called as often
  93. res = set(res)
  94. return set(getidxs(res, lookupfun))
  95. def _octstrtobits(os):
  96. '''Convert a string into a list of bits. Easier to figure out what
  97. ports are set.'''
  98. num = 1 # leading 1 to make sure leading zeros are not stripped
  99. for i in str(os):
  100. num = (num << 8) | ord(i)
  101. return bin(num)[3:]
  102. def _intstobits(*ints):
  103. '''Convert the int args to a string of bits in the expected format
  104. that SNMP expects for them. The results will be a string of '1's
  105. and '0's where the first one represents 1, and second one
  106. representing 2 and so on.'''
  107. v = 0
  108. for i in ints:
  109. v |= 1 << i
  110. r = list(bin(v)[2:-1])
  111. r.reverse()
  112. return ''.join(r)
  113. def _cmpbits(a, b):
  114. '''Compare two strings of bits to make sure they are equal.
  115. Trailing 0's are ignored.'''
  116. try:
  117. last1a = a.rindex('1')
  118. except ValueError:
  119. last1a = -1
  120. a = ''
  121. try:
  122. last1b = b.rindex('1')
  123. except ValueError:
  124. last1b = -1
  125. b = ''
  126. if last1a != -1:
  127. a = a[:last1a + 1]
  128. if last1b != -1:
  129. b = b[:last1b + 1]
  130. return a == b
  131. import vlanmang
  132. def checkchanges(module):
  133. '''Function to check for any differences between the switch, and the
  134. configured state.
  135. The parameter module is a string to the name of a python module. It
  136. will be imported, and any names that reference a vlanmang.SwitchConfig
  137. class will be validate that the configuration matches. If it does not,
  138. the returned list will contain a set of tuples, each one containing
  139. (verb, arg1, arg2, switcharg2). verb is what needs to be changed.
  140. arg1 is either the port (for setting Pvid), or the VLAN that needs to
  141. be configured. arg2 is what it needs to be set to. switcharg2 is
  142. what the switch is currently configured to, so that you can easily
  143. see what the effect of the configuration change is.
  144. '''
  145. mod = importlib.import_module(module)
  146. mods = [ (k, v) for k, v in mod.__dict__.iteritems() if isinstance(v, vlanmang.SwitchConfig) ]
  147. res = []
  148. for name, i in mods:
  149. #print 'probing %s' % `name`
  150. vlans = i.vlanconf.keys()
  151. switch = SNMPSwitch(i.host, **i.authargs)
  152. switchpvid = switch.getpvid()
  153. portmapping = switch.getportmapping()
  154. invportmap = { y: x for x, y in portmapping.iteritems() }
  155. lufun = invportmap.__getitem__
  156. # get complete set of ports
  157. portlist = i.getportlist(lufun)
  158. ports = set(portmapping.iterkeys())
  159. # make sure switch agrees w/ them all
  160. if ports != portlist:
  161. raise ValueError('missing or extra ports found: %s' %
  162. `ports.symmetric_difference(portlist)`)
  163. # compare pvid
  164. pvidmap = getpvidmapping(i.vlanconf, lufun)
  165. switchpvid = switch.getpvid()
  166. res.extend((switch, name, 'setpvid', idx, vlan, switchpvid[idx]) for idx, vlan in
  167. pvidmap.iteritems() if switchpvid[idx] != vlan)
  168. # compare egress & untagged
  169. switchegress = switch.getegress(*vlans)
  170. egress = getegress(i.vlanconf, lufun)
  171. switchuntagged = switch.getuntagged(*vlans)
  172. untagged = getuntagged(i.vlanconf, lufun)
  173. for i in vlans:
  174. if not _cmpbits(switchegress[i], egress[i]):
  175. res.append((switch, name, 'setegress', i, egress[i], switchegress[i]))
  176. if not _cmpbits(switchuntagged[i], untagged[i]):
  177. res.append((switch, name, 'setuntagged', i, untagged[i], switchuntagged[i]))
  178. return res
  179. def getidxs(lst, lookupfun):
  180. '''Take a list of ports, and if any are a string, replace them w/
  181. the value returned by lookupfun(s).
  182. Note that duplicates are not detected or removed, both in the
  183. original list, and the values returned by the lookup function
  184. may duplicate other values in the list.'''
  185. return [ lookupfun(i) if isinstance(i, str) else i for i in lst ]
  186. def getpvidmapping(data, lookupfun):
  187. '''Return a mapping from vlan based table to a port: vlan
  188. dictionary. This only looks at that untagged part of the vlan
  189. configuration, and is used for finding what a port's Pvid should
  190. be.'''
  191. res = []
  192. for id in data:
  193. for i in data[id].get('u', []):
  194. if isinstance(i, str):
  195. i = lookupfun(i)
  196. res.append((i, id))
  197. return dict(res)
  198. def getegress(data, lookupfun):
  199. '''Return a dictionary, keyed by VLAN id with a bit string of ports
  200. that need to be enabled for egress. This include both tagged and
  201. untagged traffic.'''
  202. r = {}
  203. for id in data:
  204. r[id] = _intstobits(*(getidxs(data[id].get('u', []),
  205. lookupfun) + getidxs(data[id].get('t', []), lookupfun)))
  206. return r
  207. def getuntagged(data, lookupfun):
  208. '''Return a dictionary, keyed by VLAN id with a bit string of ports
  209. that need to be enabled for untagged egress.'''
  210. r = {}
  211. for id in data:
  212. r[id] = _intstobits(*getidxs(data[id].get('u', []), lookupfun))
  213. return r
  214. class SNMPSwitch(object):
  215. '''A class for manipulating switches via standard SNMP MIBs.'''
  216. def __init__(self, host, community=None, username=None, authKey=None, authProtocol=usmHMACSHAAuthProtocol, privKey=None, privProtocol=None):
  217. '''Create a instance to read data and program a switch via
  218. SNMP.
  219. Args:
  220. host -- Host name or IP address of the switch.
  221. community -- If using SNMPv1 (not recommended, insecure), this
  222. is the community name to authenticate.
  223. username -- The username to authenticate when using SNMPv3.
  224. This varies, some cases it can be programmed and a
  225. specific user is created, in other cases, it is hard coded
  226. to a user like 'admin'.
  227. authKey -- This is the key string used to authenticate the
  228. SNMP requests.
  229. authProtocol -- This is protocol used to authenticate the
  230. SNMP requests. It is one of the values passed to
  231. authProtocol of pysnmp's UsmUserData as documented at:
  232. http://snmplabs.com/pysnmp/docs/api-reference.html#pysnmp.hlapi.UsmUserData
  233. privKey -- This is the key string used to encrypt the SNMP
  234. requests.
  235. privProtocol -- This is protocol used to encrypt the
  236. SNMP requests. It is one of the values passed to
  237. privProtocol of pysnmp's UsmUserData as documented at:
  238. http://snmplabs.com/pysnmp/docs/api-reference.html#pysnmp.hlapi.UsmUserData
  239. '''
  240. if community is not None and username is not None:
  241. raise ValueError('only one of community and username is allowed to be specified')
  242. self._eng = SnmpEngine()
  243. if community is not None:
  244. self._auth = CommunityData(community, mpModel=0)
  245. else:
  246. args = (username, authKey, )
  247. kwargs = { 'authProtocol': authProtocol }
  248. if privKey is not None:
  249. args += (privKey,)
  250. kwargs['privProtocol'] = \
  251. usmAesCfb256Protocol if privProtocol is \
  252. None else privProtocol
  253. self._auth = UsmUserData(*args, **kwargs)
  254. self._targ = UdpTransportTarget((host, 161))
  255. def __repr__(self): # pragma: no cover
  256. return '<SNMPSwitch: auth=%s, targ=%s>' % (`self._auth`, `self._targ`)
  257. def _getmany(self, *oids):
  258. woids = [ ObjectIdentity(*oid) for oid in oids ]
  259. [ oid.resolveWithMib(_mvc) for oid in woids ]
  260. errorInd, errorStatus, errorIndex, varBinds = \
  261. next(getCmd(self._eng, self._auth, self._targ,
  262. ContextData(), *(ObjectType(oid) for oid in woids)))
  263. if errorInd: # pragma: no cover
  264. raise ValueError(errorInd)
  265. elif errorStatus:
  266. if str(errorStatus) == 'tooBig' and len(oids) > 1:
  267. # split the request in two
  268. pivot = len(oids) / 2
  269. a = self._getmany(*oids[:pivot])
  270. b = self._getmany(*oids[pivot:])
  271. return a + b
  272. raise ValueError('%s at %s' %
  273. (errorStatus.prettyPrint(), errorIndex and
  274. varBinds[int(errorIndex)-1][0] or '?'))
  275. else:
  276. if len(varBinds) != len(oids): # pragma: no cover
  277. raise ValueError('too many return values')
  278. return varBinds
  279. def _get(self, oid):
  280. varBinds = self._getmany(oid)
  281. varBind = varBinds[0]
  282. return varBind[1]
  283. def _set(self, oid, value):
  284. oid = ObjectIdentity(*oid)
  285. oid.resolveWithMib(_mvc)
  286. if isinstance(value, (int, long)):
  287. value = Integer(value)
  288. elif isinstance(value, str):
  289. value = OctetString(value)
  290. errorInd, errorStatus, errorIndex, varBinds = \
  291. next(setCmd(self._eng, self._auth, self._targ,
  292. ContextData(), ObjectType(oid, value)))
  293. if errorInd: # pragma: no cover
  294. raise ValueError(errorInd)
  295. elif errorStatus: # pragma: no cover
  296. raise ValueError('%s at %s' %
  297. (errorStatus.prettyPrint(), errorIndex and
  298. varBinds[int(errorIndex)-1][0] or '?'))
  299. else:
  300. for varBind in varBinds:
  301. if varBind[1] != value: # pragma: no cover
  302. raise RuntimeError('failed to set: %s' % ' = '.join([x.prettyPrint() for x in varBind]))
  303. def _walk(self, *oid):
  304. oid = ObjectIdentity(*oid)
  305. # XXX - keep these, this might stop working, no clue what managed to magically make things work
  306. # ref: http://snmplabs.com/pysnmp/examples/smi/manager/browsing-mib-tree.html#mib-objects-to-pdu-var-binds
  307. # mibdump.py --mib-source '/Users/jmg/Nextcloud/Documents/user manuals/netgear/gs7xxt-v6.3.1.19-mibs' --mib-source /usr/share/snmp/mibs --rebuild rfc1212 pbridge vlan
  308. #oid.addAsn1MibSource('/usr/share/snmp/mibs', '/Users/jmg/Nextcloud/Documents/user manuals/netgear/gs7xxt-v6.3.1.19-mibs')
  309. oid.resolveWithMib(_mvc)
  310. for (errorInd, errorStatus, errorIndex, varBinds) in nextCmd(
  311. self._eng, self._auth, self._targ, ContextData(),
  312. ObjectType(oid),
  313. lexicographicMode=False):
  314. if errorInd: # pragma: no cover
  315. raise ValueError(errorInd)
  316. elif errorStatus: # pragma: no cover
  317. raise ValueError('%s at %s' %
  318. (errorStatus.prettyPrint(), errorIndex and
  319. varBinds[int(errorIndex)-1][0] or '?'))
  320. else:
  321. for varBind in varBinds:
  322. yield varBind
  323. def getportmapping(self):
  324. '''Return a port name mapping. Keys are the port index
  325. and the value is the name from the IF-MIB::ifName entry.'''
  326. return { x[0][-1]: str(x[1]) for x in self._walk('IF-MIB',
  327. 'ifName') }
  328. def findport(self, name):
  329. '''Look up a port name and return it's port index. This
  330. looks up via the ifName table in IF-MIB.'''
  331. return [ x[0][-1] for x in self._walk('IF-MIB', 'ifName') if
  332. str(x[1]) == name ][0]
  333. def getvlanname(self, vlan):
  334. '''Return the name for the vlan. This returns the value in
  335. Q-BRIDGE-MIB:dot1qVlanStaticName.'''
  336. v = self._get(('Q-BRIDGE-MIB', 'dot1qVlanStaticName', vlan))
  337. return str(v).decode('utf-8')
  338. def createvlan(self, vlan, name):
  339. # createAndGo(4)
  340. self._set(('Q-BRIDGE-MIB', 'dot1qVlanStaticRowStatus',
  341. int(vlan)), 4)
  342. self._set(('Q-BRIDGE-MIB', 'dot1qVlanStaticName', int(vlan)),
  343. name)
  344. def deletevlan(self, vlan):
  345. self._set(('Q-BRIDGE-MIB', 'dot1qVlanStaticRowStatus',
  346. int(vlan)), 6) # destroy(6)
  347. def getvlans(self):
  348. '''Return an iterator with all the vlan ids.'''
  349. return (x[0][-1] for x in self._walk('Q-BRIDGE-MIB',
  350. 'dot1qVlanStatus'))
  351. def staticvlans(self):
  352. '''Return an iterator of the staticly defined/configured
  353. vlans. This sometimes excludes special built in vlans,
  354. like vlan 1.'''
  355. return (x[0][-1] for x in self._walk('Q-BRIDGE-MIB',
  356. 'dot1qVlanStaticName'))
  357. def getpvid(self):
  358. '''Returns a dictionary w/ the interface index as the key,
  359. and the pvid of the interface.'''
  360. return { x[0][-1]: int(x[1]) for x in self._walk('Q-BRIDGE-MIB',
  361. 'dot1qPvid') }
  362. def setpvid(self, port, vlan):
  363. '''Set the port's Pvid to vlan. This means that any packet
  364. received by the port that is untagged, will be routed the
  365. the vlan.'''
  366. self._set(('Q-BRIDGE-MIB', 'dot1qPvid', int(port)), Gauge32(vlan))
  367. def getegress(self, *vlans):
  368. '''Get a dictionary keyed by the specified VLANs, where each
  369. value is a bit string that preresents what ports that
  370. particular VLAN will be transmitted on.'''
  371. r = { x[-1]: _octstrtobits(y) for x, y in
  372. self._getmany(*(('Q-BRIDGE-MIB',
  373. 'dot1qVlanStaticEgressPorts', x) for x in vlans)) }
  374. return r
  375. def setegress(self, vlan, ports):
  376. '''Set the ports which the specified VLAN will have packets
  377. transmitted as either tagged, if unset in untagged, or
  378. untagged, if set in untagged, to bit bit string specified
  379. by ports.'''
  380. value = OctetString.fromBinaryString(ports)
  381. self._set(('Q-BRIDGE-MIB', 'dot1qVlanStaticEgressPorts',
  382. int(vlan)), value)
  383. def getuntagged(self, *vlans):
  384. '''Get a dictionary keyed by the specified VLANs, where each
  385. value is a bit string that preresents what ports that
  386. particular VLAN will be transmitted on as an untagged
  387. packet.'''
  388. r = { x[-1]: _octstrtobits(y) for x, y in
  389. self._getmany(*(('Q-BRIDGE-MIB',
  390. 'dot1qVlanStaticUntaggedPorts', x) for x in vlans)) }
  391. return r
  392. def setuntagged(self, vlan, ports):
  393. '''Set the ports which the specified VLAN will have packets
  394. transmitted as untagged to the bit string specified by ports.'''
  395. value = OctetString.fromBinaryString(ports)
  396. self._set(('Q-BRIDGE-MIB', 'dot1qVlanStaticUntaggedPorts',
  397. int(vlan)), value)
  398. def main():
  399. import pprint
  400. import sys
  401. changes = checkchanges('data')
  402. if not changes:
  403. print 'No changes to apply.'
  404. sys.exit(0)
  405. pprint.pprint([ x[1:] for x in changes ])
  406. res = raw_input('Apply the changes? (type yes to apply): ')
  407. if res != 'yes':
  408. print 'not applying changes.'
  409. sys.exit(1)
  410. print 'applying...'
  411. failed = []
  412. prevname = None
  413. for switch, name, verb, arg1, arg2, oldarg in changes:
  414. if prevname != name:
  415. print 'Configuring switch %s...' % `name`
  416. prevname = name
  417. print '%s: %s %s' % (verb, arg1, `arg2`)
  418. try:
  419. fun = getattr(switch, verb)
  420. fun(arg1, arg2)
  421. pass
  422. except Exception as e:
  423. print 'failed'
  424. failed.append((verb, arg1, arg2, e))
  425. if failed:
  426. print '%d failed to apply, they are:' % len(failed)
  427. for verb, arg1, arg2, e in failed:
  428. print '%s: %s %s: %s' % (verb, arg1, arg2, `e`)
if __name__ == '__main__': # pragma: no cover
    # Run the interactive tool when executed as a script. This guard sits
    # above the test classes, so main() runs before they are defined
    # (and main() may call sys.exit() first).
    main()
class _TestMisc(unittest.TestCase):
    '''Unit tests for the pure helper functions, SNMPSwitch auth setup
    and checkchanges() (with the switch access mocked out).'''
    def setUp(self):
        # NOTE(review): requires a test_data module (not in view), and
        # skipTest('foo') runs unconditionally, so every test in this
        # class is currently skipped, including the pure helper tests
        # that do not use test_data.
        import test_data
        self.skipTest('foo')
        self._test_data = test_data
    def test_intstobits(self):
        self.assertEqual(_intstobits(1, 5, 10), '1000100001')
        self.assertEqual(_intstobits(3, 4, 9), '001100001')
    def test_octstrtobits(self):
        self.assertEqual(_octstrtobits('\x00'), '0' * 8)
        self.assertEqual(_octstrtobits('\xff'), '1' * 8)
        self.assertEqual(_octstrtobits('\xf0'), '1' * 4 + '0' * 4)
        self.assertEqual(_octstrtobits('\x0f'), '0' * 4 + '1' * 4)
    def test_cmpbits(self):
        self.assertTrue(_cmpbits('111000', '111'))
        self.assertTrue(_cmpbits('000111000', '000111'))
        self.assertTrue(_cmpbits('11', '11'))
        self.assertTrue(_cmpbits('0', '000'))
        self.assertFalse(_cmpbits('0011', '11'))
        self.assertFalse(_cmpbits('11', '0011'))
        self.assertFalse(_cmpbits('10', '000'))
        self.assertFalse(_cmpbits('0', '1000'))
        self.assertFalse(_cmpbits('00010', '000'))
        self.assertFalse(_cmpbits('0', '001000'))
    def test_pvidegressuntagged(self):
        # Exercises getpvidmapping, SwitchConfig.getportlist, getegress
        # and getuntagged against one sample configuration; port names
        # 'lag2'/'lag3' resolve to 30/31 through the lookup dict below.
        data = {
            1: {
                'u': [ 1, 5, 10 ] + range(13, 20),
                't': [ 'lag2', 6, 7 ],
            },
            10: {
                'u': [ 2, 3, 6, 7, 8, 'lag2' ],
            },
            13: {
                'u': [ 4, 9 ],
                't': [ 'lag2', 6, 7 ],
            },
            14: {
                't': [ 'lag2' ],
            },
        }
        swconf = SwitchConfig('', {}, data, [ 'lag3' ])
        lookup = {
            'lag2': 30,
            'lag3': 31,
        }
        lufun = lookup.__getitem__
        # expected pvids: ports 1-10, then 13-19, then lag2 (30)
        check = dict(itertools.chain(enumerate([ 1, 10, 10, 13, 1, 10,
            10, 10, 13, 1 ], 1), enumerate([ 1 ] * 7, 13),
            [ (30, 10) ]))
        # That a pvid mapping
        res = getpvidmapping(data, lufun)
        # is correct
        self.assertEqual(res, check)
        # the port list includes ignports (lag3 -> 31)
        self.assertEqual(swconf.getportlist(lufun),
            set(xrange(1, 11)) | set(xrange(13, 20)) |
            set(lookup.values()))
        checkegress = {
            1: '1000111001001111111' + '0' * (30 - 20) + '1',
            10: '01100111' + '0' * (30 - 9) + '1',
            13: '000101101' + '0' * (30 - 10) + '1',
            14: '0' * (30 - 1) + '1',
        }
        self.assertEqual(getegress(data, lufun), checkegress)
        checkuntagged = {
            1: '1000100001001111111',
            10: '01100111' + '0' * (30 - 9) + '1',
            13: '000100001',
            14: '',
        }
        self.assertEqual(getuntagged(data, lufun), checkuntagged)
    @mock.patch('vlanmang.CommunityData')
    @mock.patch('vlanmang.getCmd')
    def test_v1auth(self, gc, cd):
        # That the CommunityData class returns an object
        cdobj = object()
        cd.side_effect = [ cdobj ]
        # That a switch passed a community string
        commstr = 'foobar'
        switch = SNMPSwitch(None, community=commstr)
        # That getCmd returns a valid object
        vb = [ [ None, None ] ]
        gc.side_effect = [ iter([[ None ] * 3 + [ vb ] ]) ]
        r = switch.getvlanname(1)
        # That getCmd was called
        gc.assert_called()
        # with the correct auth object
        calledcd = gc.call_args.args[1]
        self.assertIs(calledcd, cdobj)
        # and that CommunityData was called w/ the correct args
        cd.assert_called_with(commstr, mpModel=0)
    def test_badauth(self):
        # that when both community and username are provided
        # it raises a ValueError
        self.assertRaises(ValueError, SNMPSwitch, 'somehost',
            community='foo', username='bar')
    @mock.patch('vlanmang.UsmUserData')
    @mock.patch('vlanmang.getCmd')
    def test_v3auth(self, gc, uud):
        # That the UsmUserData class returns an object
        uudobj = object()
        uud.side_effect = [ uudobj ] * 5
        # That a switch passed v3 auth data
        username = 'someuser'
        authKey = 'authKey'
        switch = SNMPSwitch(None, username=username, authKey=authKey)
        # That getCmd returns a valid object
        vb = [ [ None, None ] ]
        gc.side_effect = [ iter([[ None ] * 3 + [ vb ] ]) ] * 10
        r = switch.getvlanname(1)
        # That getCmd was called
        gc.assert_called()
        # with the correct auth object
        calleduud = gc.call_args.args[1]
        self.assertIs(calleduud, uudobj)
        # and that UsmUserData was called w/ the correct args
        uud.assert_called_with(username, authKey,
            authProtocol=usmHMACSHAAuthProtocol)
        # Reset the usm data
        uud.reset_mock()
        # that it can be called with a privKey
        privKey = 'privKey'
        switch = SNMPSwitch(None, username=username, authKey=authKey,
            privKey=privKey)
        # and that UsmUserData was called w/ the correct args
        uud.assert_called_with(username, authKey, privKey,
            authProtocol=usmHMACSHAAuthProtocol,
            privProtocol=usmAesCfb256Protocol)
        # Reset the usm data
        uud.reset_mock()
        # that it can be called with an alternate privProtocol
        switch = SNMPSwitch(None, username=username, authKey=authKey,
            privKey=privKey, privProtocol=usmDESPrivProtocol)
        # and that UsmUserData was called w/ the correct args
        uud.assert_called_with(username, authKey, privKey,
            authProtocol=usmHMACSHAAuthProtocol,
            privProtocol=usmDESPrivProtocol)
        # Reset the usm data
        uud.reset_mock()
        # that it can be called with an alternate authProtocol
        switch = SNMPSwitch(None, username=username, authKey=authKey,
            authProtocol=usmHMACMD5AuthProtocol, privKey=privKey,
            privProtocol=usmDESPrivProtocol)
        # and that UsmUserData was called w/ the correct args
        uud.assert_called_with(username, authKey, privKey,
            authProtocol=usmHMACMD5AuthProtocol,
            privProtocol=usmDESPrivProtocol)
    #@unittest.skip('foo')
    @mock.patch('vlanmang.SNMPSwitch.getuntagged')
    @mock.patch('vlanmang.SNMPSwitch.getegress')
    @mock.patch('vlanmang.SNMPSwitch.getpvid')
    @mock.patch('vlanmang.SNMPSwitch.getportmapping')
    @mock.patch('importlib.import_module')
    def test_checkchanges(self, imprt, portmapping, gpvid, gegress, guntagged):
        # NOTE(review): the expected results below depend on the contents
        # of test_data (not in view), e.g. a SwitchConfig named
        # 'distswitch'.
        # that import returns the test data
        imprt.side_effect = itertools.repeat(self._test_data)
        # that getportmapping returns the following dict
        ports = { x: 'g%d' % x for x in xrange(1, 24) }
        ports[30] = 'lag1'
        ports[31] = 'lag2'
        ports[32] = 'lag3'
        portmapping.side_effect = itertools.repeat(ports)
        # that the switch's pvid returns
        spvid = { x: 283 for x in xrange(1, 24) }
        spvid[30] = 5
        gpvid.side_effect = itertools.repeat(spvid)
        # that the extra port is caught
        self.assertRaises(ValueError, checkchanges, 'data')
        # that the functions were called
        imprt.assert_called_with('data')
        portmapping.assert_called()
        # XXX - check that an ignore statement is honored
        # delete the extra port
        del ports[32]
        # that the egress data provided
        gegress.side_effect = [ {
            1: '1' * 10,
            5: '1' * 10,
            283: '00000000111111111110011000000100000',
        } ]
        # that the untagged data provided
        guntagged.side_effect = [ {
            1: '1' * 10,
            5: '1' * 8 + '0' * 10,
            283: '00000000111111111110011',
        } ]
        res = checkchanges('data')
        # Make sure that the first one are all instances of SNMPSwitch
        # XXX make sure args for them are correct.
        self.assertTrue(all(isinstance(x[0], SNMPSwitch) for x in res))
        # Make sure that the name provided is correct
        self.assertTrue(all(x[1] == 'distswitch' for x in res))
        # drop (switch, name) and compare the remaining change tuples
        res = [ x[2:] for x in res ]
        validres = [ ('setpvid', x, 5, 283) for x in xrange(1, 9) ] + \
            [ ('setpvid', 20, 1, 283),
            ('setpvid', 21, 1, 283),
            ('setpvid', 30, 1, 5),
            ('setegress', 1, '0' * 19 + '11' + '0' * 8 + '1',
                '1' * 10),
            ('setuntagged', 1, '0' * 19 + '11' + '0' * 8 + '1',
                '1' * 10),
            ('setegress', 5, '1' * 8 + '0' * 11 + '11' + '0' * 8 +
                '1', '1' * 10),
            ]
        self.assertEqual(set(res), set(validres))
class _TestSNMPSwitch(unittest.TestCase):
    '''Unit tests for SNMPSwitch's low level get helpers, with the SNMP
    calls mocked out.'''
    def setUp(self):
        # NOTE(review): unconditional skip; every test in this class is
        # currently disabled.
        self.skipTest('foo')
    @mock.patch('vlanmang.SNMPSwitch._getmany')
    def test_get(self, gm):
        # that a switch
        switch = SNMPSwitch(None, community=None)
        # when _getmany returns this structure
        retval = object()
        gm.side_effect = [[[ None, retval ]]]
        arg = object()
        # will return the correct value
        self.assertIs(switch._get(arg), retval)
        # and call _getmany w/ the correct arg
        gm.assert_called_with(arg)
    @mock.patch('pysnmp.hlapi.ContextData')
    @mock.patch('vlanmang.getCmd')
    def test_getmany(self, gc, cd):
        # that a switch
        switch = SNMPSwitch(None, community=None)
        # vlan -> single octet returned as its egress value
        lookup = { x: chr(x) for x in xrange(1, 10) }
        # when getCmd returns tooBig when too many oids are asked for
        def custgetcmd(eng, cd, targ, contextdata, *oids):
            # induce a too big error
            if len(oids) > 3:
                res = ( None, 'tooBig', None, None )
            else:
                #import pdb; pdb.set_trace()
                [ oid.resolveWithMib(_mvc) for oid in oids ]
                oids = [ ObjectType(x[0],
                    OctetString(lookup[x[0][-1]])) for x in oids ]
                [ oid.resolveWithMib(_mvc) for oid in oids ]
                res = ( None, None, None, oids )
            return iter([res])
        gc.side_effect = custgetcmd
        #import pdb; pdb.set_trace()
        res = switch.getegress(*xrange(1, 10))
        # will still return the complete set of results
        self.assertEqual(res, { x: _octstrtobits(lookup[x]) for x in
            xrange(1, 10) })
  676. _skipSwitchTests = False
  677. class _TestSwitch(unittest.TestCase):
  678. def setUp(self):
  679. # If we don't have it, pretend it's true for now and
  680. # we'll recheck it later
  681. model = 'GS108T smartSwitch'
  682. if getattr(self, 'switchmodel', model) != model or \
  683. _skipSwitchTests: # pragma: no cover
  684. self.skipTest('Need a GS108T switch to run these tests')
  685. host, authkey, privkey = open('test.creds').read().split()
  686. self.switch = SNMPSwitch(host, authKey=authkey,
  687. privKey=privkey, privProtocol=usmDESPrivProtocol)
  688. self.switchmodel = self.switch._get(('ENTITY-MIB',
  689. 'entPhysicalModelName', 1))
  690. if self.switchmodel != model: # pragma: no cover
  691. self.skipTest('Need a GS108T switch to run these tests')
  692. def test_misc(self):
  693. switch = self.switch
  694. self.assertEqual(switch.findport('g1'), 1)
  695. self.assertEqual(switch.findport('l1'), 14)
  696. def test_portnames(self):
  697. switch = self.switch
  698. resp = dict((x, 'g%d' % x) for x in xrange(1, 9))
  699. resp.update({ 13: 'cpu' })
  700. resp.update((x, 'l%d' % (x - 13)) for x in xrange(14, 18))
  701. self.assertEqual(switch.getportmapping(), resp)
  702. def test_egress(self):
  703. switch = self.switch
  704. egress = switch.getegress(1, 2, 3)
  705. checkegress = {
  706. 1: '1' * 8 + '0' * 5 + '1' * 4 + '0' * 23,
  707. 2: '0' * 8 * 5,
  708. 3: '0' * 8 * 5,
  709. }
  710. self.assertEqual(egress, checkegress)
  711. def test_untagged(self):
  712. switch = self.switch
  713. untagged = switch.getuntagged(1, 2, 3)
  714. checkuntagged = {
  715. 1: '1' * 8 * 5,
  716. 2: '1' * 8 * 5,
  717. 3: '1' * 8 * 5,
  718. }
  719. self.assertEqual(untagged, checkuntagged)
  720. def test_vlan(self):
  721. switch = self.switch
  722. existingvlans = set(switch.getvlans())
  723. while True:
  724. testvlan = random.randint(1,4095)
  725. if testvlan not in existingvlans:
  726. break
  727. # Test that getting a non-existant vlans raises an exception
  728. self.assertRaises(ValueError, switch.getvlanname, testvlan)
  729. self.assertTrue(set(switch.staticvlans()).issubset(existingvlans))
  730. pvidres = { x: 1 for x in xrange(1, 9) }
  731. pvidres.update({ x: 1 for x in xrange(14, 18) })
  732. self.assertEqual(switch.getpvid(), pvidres)
  733. testname = 'Sometestname'
  734. # Create test vlan
  735. switch.createvlan(testvlan, testname)
  736. testport = None
  737. try:
  738. # make sure the test vlan was created
  739. self.assertIn(testvlan, set(switch.staticvlans()))
  740. self.assertEqual(testname, switch.getvlanname(testvlan))
  741. switch.setegress(testvlan, '00100')
  742. pvidmap = switch.getpvid()
  743. testport = 3
  744. egressports = switch.getegress(testvlan)
  745. self.assertEqual(egressports[testvlan], '00100000' +
  746. '0' * 8 * 4)
  747. switch.setuntagged(testvlan, '00100')
  748. untaggedports = switch.getuntagged(testvlan)
  749. self.assertEqual(untaggedports[testvlan], '00100000' +
  750. '0' * 8 * 4)
  751. switch.setpvid(testport, testvlan)
  752. self.assertEqual(switch.getpvid()[testport], testvlan)
  753. finally:
  754. if testport:
  755. switch.setpvid(testport, pvidmap[3])
  756. switch.deletevlan(testvlan)