A pure Python ASN.1 library. Supports dicts and sets.
#!/usr/bin/env python

'''A Pure Python ASN.1 encoder/decoder w/ a calling interface in the spirit
of pickle.

It uses a profile of ASN.1.

All lengths must be specified. That is, End-of-contents octets MUST NOT
be used. The shortest form of length encoding MUST be used. A longer
length encoding MUST be rejected.'''

__author__ = 'John-Mark Gurney'
__copyright__ = 'Copyright 2016 John-Mark Gurney. All rights reserved.'
__license__ = '2-clause BSD license'

# Copyright 2016, John-Mark Gurney
# All rights reserved.
#
# Redistribution and use in source and binary forms, with or without
# modification, are permitted provided that the following conditions are met:
#
# 1. Redistributions of source code must retain the above copyright notice, this
#    list of conditions and the following disclaimer.
# 2. Redistributions in binary form must reproduce the above copyright notice,
#    this list of conditions and the following disclaimer in the documentation
#    and/or other materials provided with the distribution.
#
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND
# ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED
# WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR CONTRIBUTORS BE LIABLE FOR
# ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES
# (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES;
# LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND
# ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
# (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
# SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
#
# The views and conclusions contained in the software and documentation are those
# of the authors and should not be interpreted as representing official policies,
# either expressed or implied, of the Project.

import datetime
import math
import mock
import os
import pdb
import sys
import unittest

__all__ = [ 'dumps', 'loads', 'ASN1Coder' ]

def _numtostr(n):
    hs = '%x' % n
    if len(hs) & 1 == 1:
        hs = '0' + hs
    bs = hs.decode('hex')
    return bs

def _encodelen(l):
    '''Takes l as a length value, and returns a byte string that
    represents l per ASN.1 rules.'''

    if l < 128:
        return chr(l)

    bs = _numtostr(l)
    return chr(len(bs) | 0x80) + bs

def _decodelen(d, pos=0):
    '''Returns the length, and number of bytes required.'''

    odp = ord(d[pos])
    if odp < 128:
        return ord(d[pos]), 1
    else:
        l = odp & 0x7f
        return int(d[pos + 1:pos + 1 + l].encode('hex'), 16), l + 1
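
# Worked examples of the length encoding (values match Test_codelen below):
# lengths below 128 use the single-octet short form; anything larger uses the
# long form, 0x80 | <number of length octets>, followed by the big-endian length.
#
#     _encodelen(127) == '\x7f'
#     _encodelen(128) == '\x81\x80'
#     _encodelen(256) == '\x82\x01\x00'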

class Test_codelen(unittest.TestCase):
    _testdata = [
        (2, '\x02'),
        (127, '\x7f'),
        (128, '\x81\x80'),
        (255, '\x81\xff'),
        (256, '\x82\x01\x00'),
        (65536-1, '\x82\xff\xff'),
        (65536, '\x83\x01\x00\x00'),
    ]

    def test_el(self):
        for i, j in self._testdata:
            self.assertEqual(_encodelen(i), j)
            self.assertEqual(_decodelen(j), (i, len(j)))

def _splitfloat(f):
    m, e = math.frexp(f)

    # XXX - less than ideal
    while m != math.trunc(m):
        m *= 2
        e -= 1

    return m, e
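
# _splitfloat returns an integral mantissa m and exponent e with f == m * 2**e,
# which is the form the binary REAL encoding below needs. For example:
#
#     _splitfloat(0.15625) == (5.0, -5)    # 0.15625 == 5 * 2**-5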

class TestSplitFloat(unittest.TestCase):
    def test_sf(self):
        for a, b in [ (0x2421, -32), (0x5382f, 238),
                      (0x1fa8c3b094adf1, 971) ]:
            self.assertEqual(_splitfloat(a * 2**b), (a, b))

class ASN1Coder(object):
    '''A class that contains a PASN.1 encoder/decoder.

    Exports two methods, loads and dumps.'''

    def __init__(self, coerce=None):
        '''If the arg coerce is provided and, when dumping an object,
        its type is not found, the coerce function will be called
        with the obj. It is expected to return a tuple of a string
        and an object that has the method w/ the string as defined:
            'bool': __nonzero__
            'dict': iteritems
            'float': compatible w/ float
            'int': compatible w/ int
            'list': __iter__
            'set': __iter__
            'bytes': __str__
            'null': no method needed
            'unicode': encode method returns UTF-8 encoded bytes
            'datetime': strftime and microsecond
        '''

        self.coerce = coerce
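
    # A minimal sketch of a coerce function (hypothetical, not part of the
    # module): encode tuples by reusing the 'list' encoder, which only needs
    # __iter__. Note that the round trip comes back as a list.
    #
    #     def coerce_tuple(obj):
    #         if isinstance(obj, tuple):
    #             return 'list', obj
    #         raise TypeError('unknown type')
    #
    #     coder = ASN1Coder(coerce_tuple)
    #     coder.loads(coder.dumps((1, 2, 3))) == [1, 2, 3]
    #
    # See test_coerce below for a fuller example.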

    _typemap = {
        bool: 'bool',
        dict: 'dict',
        float: 'float',
        int: 'int',
        list: 'list',
        long: 'int',
        set: 'set',
        str: 'bytes',
        type(None): 'null',
        unicode: 'unicode',
        #decimal.Decimal: 'float',
        datetime.datetime: 'datetime',
        #datetime.timedelta: 'timedelta',
    }

    _tagmap = {
        '\x01': 'bool',
        '\x02': 'int',
        '\x04': 'bytes',
        '\x05': 'null',
        '\x09': 'float',
        '\x0c': 'unicode',
        '\x18': 'datetime',
        '\x30': 'list',
        '\x31': 'set',
        '\xe0': 'dict',
    }

    _typetag = dict((v, k) for k, v in _tagmap.iteritems())

    @staticmethod
    def enc_int(obj):
        l = obj.bit_length()
        l += 1  # space for sign bit
        l = (l + 7) // 8

        if obj < 0:
            obj += 1 << (l * 8)  # twos-complement conversion

        v = _numtostr(obj)
        if len(v) != l:
            # XXX - is this a problem for signed values?
            v = '\x00' + v  # add sign octet

        return _encodelen(l) + v

    @staticmethod
    def dec_int(d, pos, end):
        if pos == end:
            return 0, end

        v = int(d[pos:end].encode('hex'), 16)
        av = 1 << ((end - pos) * 8 - 1)  # sign bit
        if v > av:
            v -= av * 2  # twos-complement conversion

        return v, end
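
    # Worked INTEGER encodings, matching test_primv below (tag 0x02, then
    # length, then the big-endian two's-complement value):
    #
    #     dumps(5)   == '\x02\x01\x05'
    #     dumps(-1)  == '\x02\x01\xff'
    #     dumps(128) == '\x02\x02\x00\x80'    # sign octet keeps it positive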

    @staticmethod
    def enc_bool(obj):
        return '\x01' + ('\xff' if obj else '\x00')

    def dec_bool(self, d, pos, end):
        v = self.dec_int(d, pos, end)[0]
        if v not in (-1, 0):
            raise ValueError('invalid bool value: %d' % v)

        return bool(v), end

    @staticmethod
    def enc_null(obj):
        return '\x00'

    @staticmethod
    def dec_null(d, pos, end):
        return None, end

    def enc_dict(self, obj):
        #it = list(obj.iteritems())
        #it.sort()
        r = ''.join(self.dumps(k) + self.dumps(v) for k, v in
                    obj.iteritems())

        return _encodelen(len(r)) + r

    def dec_dict(self, d, pos, end):
        r = {}
        vend = pos
        while pos < end:
            k, kend = self._loads(d, pos, end)
            #if kend > end:
            #    raise ValueError('key past end')
            v, vend = self._loads(d, kend, end)
            if vend > end:
                raise ValueError('value past end')

            r[k] = v
            pos = vend

        return r, vend
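
    # Dicts use the private-class constructed tag 0xe0 (not a universal ASN.1
    # type); the contents are simply the concatenated encodings of each key
    # followed by its value. For example (derived from enc_dict/enc_int):
    #
    #     dumps({ 1: 2 }) == '\xe0\x06\x02\x01\x01\x02\x01\x02'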

    def enc_list(self, obj):
        r = ''.join(self.dumps(x) for x in obj)

        return _encodelen(len(r)) + r

    def dec_list(self, d, pos, end):
        r = []
        vend = pos
        while pos < end:
            v, vend = self._loads(d, pos, end)
            if vend > end:
                raise ValueError('load past end')
            r.append(v)
            pos = vend

        return r, vend

    enc_set = enc_list

    def dec_set(self, d, pos, end):
        r, end = self.dec_list(d, pos, end)
        return set(r), end

    @staticmethod
    def enc_bytes(obj):
        return _encodelen(len(obj)) + bytes(obj)

    @staticmethod
    def dec_bytes(d, pos, end):
        return d[pos:end], end

    @staticmethod
    def enc_unicode(obj):
        encobj = obj.encode('utf-8')

        return _encodelen(len(encobj)) + encobj

    def dec_unicode(self, d, pos, end):
        return d[pos:end].decode('utf-8'), end

    @staticmethod
    def enc_float(obj):
        s = math.copysign(1, obj)
        if math.isnan(obj):
            return _encodelen(1) + chr(0b01000010)
        elif math.isinf(obj):
            if s == 1:
                return _encodelen(1) + chr(0b01000000)
            else:
                return _encodelen(1) + chr(0b01000001)
        elif obj == 0:
            if s == 1:
                return _encodelen(0)
            else:
                return _encodelen(1) + chr(0b01000011)

        m, e = _splitfloat(obj)

        # Binary encoding
        val = 0x80
        if m < 0:
            val |= 0x40
            m = -m

        # Base 2
        el = (e.bit_length() + 7 + 1) // 8  # + 1 is sign bit
        if el > 2:
            raise ValueError('exponent too large')

        if e < 0:
            e += 256**el  # convert negative to twos-complement

        v = el - 1
        encexp = _numtostr(e)

        val |= v
        r = chr(val) + encexp + _numtostr(m)
        return _encodelen(len(r)) + r

    def dec_float(self, d, pos, end):
        if pos == end:
            return float(0), end

        v = ord(d[pos])
        if v == 0b01000000:
            return float('inf'), end
        elif v == 0b01000001:
            return float('-inf'), end
        elif v == 0b01000010:
            return float('nan'), end
        elif v == 0b01000011:
            return float('-0'), end
        elif v & 0b110000:
            raise ValueError('base must be 2')
        elif v & 0b1100:
            raise ValueError('scaling factor must be 0')
        elif v & 0b11000000 == 0:
            raise ValueError('decimal encoding not supported')
        #elif v & 0b11000000 == 0b01000000:
        #    raise ValueError('invalid encoding')

        if (v & 3) >= 2:
            raise ValueError('large exponents not supported')

        pexp = pos + 1
        eexp = pos + 1 + (v & 3) + 1

        exp = self.dec_int(d, pexp, eexp)[0]
        n = float(int(d[eexp:end].encode('hex'), 16))
        r = n * 2 ** exp
        if v & 0b1000000:
            r = -r

        return r, end
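
    # A worked REAL encoding, matching test_primv below: 0.15625 == 5 * 2**-5
    # dumps as '\x09\x03\x80\xfb\x05', that is tag 0x09, length 3, info octet
    # 0x80 (binary, base 2, positive, one exponent octet), exponent 0xfb (-5
    # in two's complement), mantissa 0x05.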

    def dumps(self, obj):
        '''Convert obj into an array of bytes.'''

        try:
            tf = self._typemap[type(obj)]
        except KeyError:
            if self.coerce is None:
                raise TypeError('unhandled object: %s' % `obj`)

            tf, obj = self.coerce(obj)

        fun = getattr(self, 'enc_%s' % tf)
        return self._typetag[tf] + fun(obj)

    def _loads(self, data, pos, end):
        tag = data[pos]
        l, b = _decodelen(data, pos + 1)
        if len(data) < pos + 1 + b + l:
            raise ValueError('string not long enough')

        # XXX - enforce that len(data) == end?
        end = pos + 1 + b + l

        t = self._tagmap[tag]
        fun = getattr(self, 'dec_%s' % t)
        return fun(data, pos + 1 + b, end)

    def enc_datetime(self, obj):
        ts = obj.strftime('%Y%m%d%H%M%S')
        if obj.microsecond:
            ts += ('.%06d' % obj.microsecond).rstrip('0')

        ts += 'Z'
        return _encodelen(len(ts)) + ts

    def dec_datetime(self, data, pos, end):
        ts = data[pos:end]
        if ts[-1] != 'Z':
            raise ValueError('last character must be Z')

        if '.' in ts:
            fstr = '%Y%m%d%H%M%S.%fZ'
            if ts.endswith('0Z'):
                raise ValueError('invalid trailing zeros')
        else:
            fstr = '%Y%m%d%H%M%SZ'

        return datetime.datetime.strptime(ts, fstr), end
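
    # The timestamp format used above is 'YYYYMMDDHHMMSS[.ffffff]Z' (UTC,
    # GeneralizedTime-style), with trailing zeros stripped from the
    # fractional part. For example:
    #
    #     datetime.datetime(2016, 2, 15, 8, 40, 16, 590890)
    #     encodes (after the tag and length octets) as '20160215084016.59089Z'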

    def loads(self, data, pos=0, end=None, consume=False):
        '''Load from data, starting at pos (optional), and ending
        at end (optional). If it is required to consume the
        whole string (not the default), set consume to True, and
        a ValueError will be raised if the string is not
        completely consumed. The second item in the ValueError will
        be the position of the detected end.'''

        if end is None:
            end = len(data)

        r, e = self._loads(data, pos, end)

        if consume and e != end:
            raise ValueError('entire string not consumed', e)

        return r

_coder = ASN1Coder()
dumps = _coder.dumps
loads = _coder.loads
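
# Typical round-trip usage of the module-level helpers:
#
#     blob = dumps({ 1: 'one', u'two': [ 2, 2.0 ] })
#     loads(blob) == { 1: 'one', u'two': [ 2, 2.0 ] }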

def deeptypecmp(obj, o):
    #print 'dtc:', `obj`, `o`
    if type(obj) != type(o):
        return False

    if type(obj) in (str, unicode):
        return True

    if type(obj) in (list, set):
        for i, j in zip(obj, o):
            if not deeptypecmp(i, j):
                return False

    if type(obj) in (dict,):
        itms = obj.items()
        itms.sort()

        nitms = o.items()
        nitms.sort()

        for (k, v), (nk, nv) in zip(itms, nitms):
            if not deeptypecmp(k, nk):
                return False
            if not deeptypecmp(v, nv):
                return False

    return True

class Test_deeptypecmp(unittest.TestCase):
    def test_true(self):
        for i in ((1, 1), ('sldkfj', 'sldkfj')
                  ):
            self.assertTrue(deeptypecmp(*i))

    def test_false(self):
        for i in (([[]], [{}]), ([1], ['str']), ([], set()),
                  ({1: 2, 5: u'sdlkfj'}, {1: 2, 5: 'sdlkfj'}),
                  ({1: 2, u'sdlkfj': 5}, {1: 2, 'sdlkfj': 5}),
                  ):
            self.assertFalse(deeptypecmp(*i))

def genfailures(obj):
    s = dumps(obj)
    for i in xrange(len(s)):
        for j in (chr(x) for x in xrange(256)):
            ts = s[:i] + j + s[i + 1:]
            if ts == s:
                continue

            try:
                o = loads(ts, consume=True)
                if o != obj or not deeptypecmp(o, obj):
                    raise ValueError
            except (ValueError, KeyError, IndexError, TypeError):
                pass
            else:
                raise AssertionError('uncaught modification: %s, byte %d, orig: %02x' %
                    (ts.encode('hex'), i, ord(s[i])))

class TestCode(unittest.TestCase):
    def test_primv(self):
        self.assertEqual(dumps(-257), '0202feff'.decode('hex'))
        self.assertEqual(dumps(-256), '0202ff00'.decode('hex'))
        self.assertEqual(dumps(-255), '0202ff01'.decode('hex'))
        self.assertEqual(dumps(-1), '0201ff'.decode('hex'))
        self.assertEqual(dumps(5), '020105'.decode('hex'))
        self.assertEqual(dumps(128), '02020080'.decode('hex'))
        self.assertEqual(dumps(256), '02020100'.decode('hex'))

        self.assertEqual(dumps(False), '010100'.decode('hex'))
        self.assertEqual(dumps(True), '0101ff'.decode('hex'))

        self.assertEqual(dumps(None), '0500'.decode('hex'))

        self.assertEqual(dumps(.15625), '090380fb05'.decode('hex'))

    def test_fuzzing(self):
        # Make sure that when a failure is detected here, it gets added
        # to test_invalids, so that this function may be disabled.
        genfailures(float(1))
        genfailures([ 1, 2, 'sdlkfj' ])
        genfailures({ 1: 2, 5: 'sdlkfj' })
        genfailures(set([ 1, 2, 'sdlkfj' ]))
        genfailures(True)
        genfailures(datetime.datetime.utcnow())

    def test_invalids(self):
        # Add tests for base 8, 16 floats among others
        for v in [ '010101',
                   '0903040001',  # float scaling factor
                   '0903840001',  # float scaling factor
                   '0903100001',  # float base
                   '0903900001',  # float base
                   '0903000001',  # float decimal encoding
                   '0903830001',  # float exponent encoding
                   '090b827fffcc0df505d0fa58f7',  # float large exponent
                   '3007020101020102040673646c6b666a',  # list short string still valid
                   'e007020101020102020105040673646c6b666a',  # dict short value still valid
                   '181632303136303231353038343031362e3539303839305a',  # datetime w/ trailing zero
                   '181632303136303231373136343034372e3035343433367a',  # datetime w/ lower z
                 ]:
            self.assertRaises(ValueError, loads, v.decode('hex'))

    def test_invalid_floats(self):
        with mock.patch('math.frexp', return_value=(.87232, 1 << 23)):
            self.assertRaises(ValueError, dumps, 1.1)

    def test_consume(self):
        b = dumps(5)
        self.assertRaises(ValueError, loads, b + '398473',
                          consume=True)

        # XXX - still possible that an internal data member
        # doesn't consume all

    # XXX - test that sets are ordered properly
    # XXX - test that dicts are ordered properly..

    def test_nan(self):
        s = dumps(float('nan'))
        v = loads(s)
        self.assertTrue(math.isnan(v))

    def test_cryptoutilasn1(self):
        '''Test DER sequences generated by Crypto.Util.asn1.'''

        for s, v in [ ('\x02\x03$\x8a\xf9', 2394873),
                      ('\x05\x00', None),
                      ('\x02\x03\x00\x96I', 38473),
                      ('\x04\x81\xc8' + '\x00' * 200, '\x00' * 200),
                    ]:
            self.assertEqual(loads(s), v)

    def test_longstrings(self):
        for i in (203, 65484):
            s = os.urandom(i)
            v = dumps(s)
            self.assertEqual(loads(v), s)

    def test_invaliddate(self):
        pass
        # XXX - add test to reject datetime w/ tzinfo, or that it
        # handles it properly

    def test_dumps(self):
        for i in [ None,
                   True, False,
                   -1, 0, 1, 255, 256, -255, -256,
                   23498732498723, -2398729387234,
                   (1<<2383) + 23984734, (-1<<1983) + 23984723984,
                   float(0), float('-0'), float('inf'), float('-inf'),
                   float(1.0), float(-1.0), float('353.3487'),
                   float('2.38723873e+307'), float('2.387349e-317'),
                   sys.float_info.max, sys.float_info.min,
                   float('.15625'),
                   'weoifjwef',
                   u'\U0001f4a9',
                   [], [ 1, 2, 3 ],
                   {}, { 5: 10, 'adfkj': 34 },
                   set(), set((1, 2, 3)),
                   set((1, 'sjlfdkj', None, float('inf'))),
                   datetime.datetime.utcnow(),
                   datetime.datetime.utcnow().replace(microsecond=0),
                   datetime.datetime.utcnow().replace(microsecond=1000),
                 ]:
            s = dumps(i)
            o = loads(s)
            self.assertEqual(i, o)

        tobj = { 1: 'dflkj', 5: u'sdlkfj', 'float': 1,
                 'largeint': 1<<342, 'list': [ 1, 2, u'str', 'str' ] }
        out = dumps(tobj)
        self.assertEqual(tobj, loads(out))

    def test_coerce(self):
        class Foo:
            pass

        class Bar:
            pass

        class Baz:
            pass

        def coerce(obj):
            if isinstance(obj, Foo):
                return 'list', obj.lst
            elif isinstance(obj, Baz):
                return 'bytes', obj.s

            raise TypeError('unknown type')

        ac = ASN1Coder(coerce)

        v = [1, 2, 3]
        o = Foo()
        o.lst = v
        self.assertEqual(ac.loads(ac.dumps(o)), v)

        self.assertRaises(TypeError, ac.dumps, Bar())

        v = u'oiejfd'
        o = Baz()
        o.s = v
        es = ac.dumps(o)
        self.assertEqual(ac.loads(es), v)
        self.assertIsInstance(es, bytes)

        self.assertRaises(TypeError, dumps, o)

    def test_loads(self):
        self.assertRaises(ValueError, loads, '\x00\x02\x00')