A pure Python ASN.1 library. Supports dicts and sets.

#!/usr/bin/env python
# A Pure Python ASN.1 encoder/decoder w/ a calling interface in the spirit
# of pickle.  It will automatically do the correct thing if possible.
#
# This uses a profile of ASN.1.
#
# All lengths must be specified.  That is, End-of-contents octets MUST NOT
# be used.  The shortest form of length encoding MUST be used.  A longer
# length encoding MUST be rejected.

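# A minimal usage sketch (dumps and loads are module-level helpers bound to
# a shared ASN1Coder instance near the end of this file):
#
#     buf = dumps({ 1: 'a', u'key': [ 1, 2.5, None ] })
#     assert loads(buf) == { 1: 'a', u'key': [ 1, 2.5, None ] }
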
import datetime
import math
import os
import pdb
import sys
import unittest

__all__ = [ 'dumps', 'loads', 'ASN1Coder' ]

def _numtostr(n):
    hs = '%x' % n
    if len(hs) & 1 == 1:
        hs = '0' + hs
    bs = hs.decode('hex')
    return bs

def _encodelen(l):
    '''Takes l as a length value, and returns a byte string that
    represents l per ASN.1 rules.'''

    if l < 128:
        return chr(l)
    bs = _numtostr(l)
    return chr(len(bs) | 0x80) + bs

def _decodelen(d, pos=0):
    '''Returns the length, and number of bytes required.'''

    odp = ord(d[pos])
    if odp < 128:
        return ord(d[pos]), 1
    else:
        l = odp & 0x7f
        return int(d[pos + 1:pos + 1 + l].encode('hex'), 16), l + 1

class Test_codelen(unittest.TestCase):
    _testdata = [
        (2, '\x02'),
        (127, '\x7f'),
        (128, '\x81\x80'),
        (255, '\x81\xff'),
        (256, '\x82\x01\x00'),
        (65536-1, '\x82\xff\xff'),
        (65536, '\x83\x01\x00\x00'),
    ]

    def test_el(self):
        for i, j in self._testdata:
            self.assertEqual(_encodelen(i), j)
            self.assertEqual(_decodelen(j), (i, len(j)))

def _splitfloat(f):
    m, e = math.frexp(f)
    # XXX - less than ideal
    while m != math.trunc(m):
        m *= 2
        e -= 1
    return m, e

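# e.g. _splitfloat(0.15625) == (5.0, -5), i.e. 0.15625 == 5 * 2**-5.
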
class TestSplitFloat(unittest.TestCase):
    def test_sf(self):
        for a, b in [ (0x2421, -32), (0x5382f, 238),
                (0x1fa8c3b094adf1, 971) ]:
            self.assertEqual(_splitfloat(a * 2**b), (a, b))

class ASN1Object:
    def __init__(self, tag):
        self._tag = tag

class ASN1Coder(object):
    def __init__(self):
        pass

    _typemap = {
        bool: 'bool',
        dict: 'dict',
        float: 'float',
        int: 'int',
        list: 'list',
        long: 'int',
        set: 'set',
        str: 'bytes',
        type(None): 'null',
        unicode: 'unicode',
        #decimal.Decimal: 'float',
        datetime.datetime: 'datetime',
        #datetime.timedelta: 'timedelta',
    }

    _tagmap = {
        '\x01': 'bool',
        '\x02': 'int',
        '\x04': 'bytes',
        '\x05': 'null',
        '\x09': 'float',
        '\x0c': 'unicode',
        '\x18': 'datetime',
        '\x30': 'list',
        '\x31': 'set',
        '\xc0': 'dict',
    }

    _typetag = dict((v, k) for k, v in _tagmap.iteritems())

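    # The tag octets above are the usual universal-class ASN.1 tags (BOOLEAN,
    # INTEGER, OCTET STRING, NULL, REAL, UTF8String, GeneralizedTime, and the
    # constructed SEQUENCE/SET tags 0x30/0x31); 0xc0 is a private-class tag
    # used here for dicts.
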
    @staticmethod
    def enc_int(obj):
        l = obj.bit_length()
        l += 1  # space for sign bit
        l = (l + 7) // 8

        if obj < 0:
            obj += 1 << (l * 8)  # twos-complement conversion

        v = _numtostr(obj)
        if len(v) != l:
            # XXX - is this a problem for signed values?
            v = '\x00' + v  # add sign octet
        return _encodelen(l) + v

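    # e.g. enc_int(-257): 9 value bits + 1 sign bit -> 2 octets, and
    # -257 + 2**16 == 0xfeff, so this returns '\x02\xfe\xff'; dumps() then
    # prepends the INTEGER tag, giving '\x02\x02\xfe\xff' (cf. test_primv).
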
    @staticmethod
    def dec_int(d, pos, end):
        if pos == end:
            return 0, end

        v = int(d[pos:end].encode('hex'), 16)
        av = 1 << ((end - pos) * 8 - 1)  # sign bit
        if v >= av:  # >= so the most negative value (only the sign bit set) decodes correctly
            v -= av * 2  # twos-complement conversion
        return v, end

    @staticmethod
    def enc_bool(obj):
        return '\x01' + ('\xff' if obj else '\x00')

    def dec_bool(self, d, pos, end):
        v = self.dec_int(d, pos, end)[0]
        if v not in (-1, 0):
            raise ValueError('invalid bool value: %d' % v)
        return bool(v), end

    @staticmethod
    def enc_null(obj):
        return '\x00'

    @staticmethod
    def dec_null(d, pos, end):
        return None, end

    def enc_dict(self, obj):
        #it = list(obj.iteritems())
        #it.sort()
        r = ''.join(self.dumps(k) + self.dumps(v) for k, v in obj.iteritems())
        return _encodelen(len(r)) + r

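    # Note: enc_dict emits key/value pairs in whatever order iteritems()
    # yields them; the ordering is not canonicalized (see the XXX notes in
    # TestCode below).
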
    def dec_dict(self, d, pos, end):
        r = {}
        vend = pos
        while pos < end:
            k, kend = self._loads(d, pos, end)
            #if kend > end:
            #    raise ValueError('key past end')
            v, vend = self._loads(d, kend, end)
            if vend > end:
                raise ValueError('value past end')
            r[k] = v
            pos = vend
        return r, vend

    def enc_set(self, obj):
        r = ''.join(self.dumps(x) for x in obj)
        return _encodelen(len(r)) + r

    def dec_set(self, d, pos, end):
        r, end = self.dec_list(d, pos, end)
        return set(r), end

    def enc_list(self, obj):
        r = ''.join(self.dumps(x) for x in obj)
        return _encodelen(len(r)) + r

    def dec_list(self, d, pos, end):
        r = []
        vend = pos
        while pos < end:
            v, vend = self._loads(d, pos, end)
            if vend > end:
                raise ValueError('load past end')
            r.append(v)
            pos = vend
        return r, vend

    @staticmethod
    def enc_bytes(obj):
        return _encodelen(len(obj)) + obj

    @staticmethod
    def dec_bytes(d, pos, end):
        return d[pos:end], end

    @staticmethod
    def enc_unicode(obj):
        encobj = obj.encode('utf-8')
        return _encodelen(len(encobj)) + encobj

    def dec_unicode(self, d, pos, end):
        return d[pos:end].decode('utf-8'), end

    @staticmethod
    def enc_float(obj):
        s = math.copysign(1, obj)
        if math.isnan(obj):
            return _encodelen(1) + chr(0b01000010)
        elif math.isinf(obj):
            if s == 1:
                return _encodelen(1) + chr(0b01000000)
            else:
                return _encodelen(1) + chr(0b01000001)
        elif obj == 0:
            if s == 1:
                return _encodelen(0)
            else:
                return _encodelen(1) + chr(0b01000011)

        m, e = _splitfloat(obj)

        # Binary encoding
        val = 0x80
        if m < 0:
            val |= 0x40
            m = -m

        # Base 2
        el = (e.bit_length() + 7 + 1) // 8  # + 1 is sign bit
        if e < 0:
            e += 256**el  # convert negative to twos-complement
        if el > 3:
            v = 0x3
            encexp = _encodelen(el) + _numtostr(e)
        else:
            v = el - 1
            encexp = _numtostr(e)

        # m is an integral float at this point; pass an int to _numtostr
        r = chr(val) + encexp + _numtostr(int(m))
        return _encodelen(len(r)) + r

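    # First content octet of a binary REAL: bit 8 = binary encoding, bit 7 =
    # sign, bits 6-5 = base (00 -> base 2), bits 4-3 = scaling factor, bits
    # 2-1 = exponent length encoding.  e.g. 0.15625 == 5 * 2**-5, so
    # enc_float(0.15625) returns '\x03\x80\xfb\x05' (info octet 0x80,
    # exponent octet 0xfb == -5, mantissa octet 0x05); cf. dumps(.15625)
    # in test_primv.
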
    def dec_float(self, d, pos, end):
        if pos == end:
            return float(0), end

        v = ord(d[pos])
        if v == 0b01000000:
            return float('inf'), end
        elif v == 0b01000001:
            return float('-inf'), end
        elif v == 0b01000010:
            return float('nan'), end
        elif v == 0b01000011:
            return float('-0'), end
        elif v & 0b110000:
            raise ValueError('base must be 2')
        elif v & 0b1100:
            raise ValueError('scaling factor must be 0')
        elif v & 0b11000000 == 0:
            raise ValueError('decimal encoding not supported')
        #elif v & 0b11000000 == 0b01000000:
        #    raise ValueError('invalid encoding')

        if v & 3 == 3:
            pexp = pos + 2
            explen = ord(d[pos + 1])
            if explen <= 3:
                raise ValueError('must use other length encoding')
            eexp = pos + 2 + explen
        else:
            pexp = pos + 1
            eexp = pos + 1 + (v & 3) + 1

        exp = self.dec_int(d, pexp, eexp)[0]
        n = float(int(d[eexp:end].encode('hex'), 16))
        r = n * 2 ** exp
        if v & 0b1000000:
            r = -r
        return r, end

    def dumps(self, obj):
        tf = self._typemap[type(obj)]
        fun = getattr(self, 'enc_%s' % tf)
        return self._typetag[tf] + fun(obj)

    def _loads(self, data, pos, end):
        tag = data[pos]
        l, b = _decodelen(data, pos + 1)
        if len(data) < pos + 1 + b + l:
            raise ValueError('string not long enough')
        # XXX - enforce that len(data) == end?
        end = pos + 1 + b + l
        t = self._tagmap[tag]
        fun = getattr(self, 'dec_%s' % t)
        return fun(data, pos + 1 + b, end)

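    # Wire format used by dumps()/_loads(): one tag octet, a definite length
    # as produced by _encodelen(), then exactly that many content octets.
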
    def enc_datetime(self, obj):
        ts = obj.strftime('%Y%m%d%H%M%S')
        if obj.microsecond:
            ts += ('.%06d' % obj.microsecond).rstrip('0')
        ts += 'Z'
        return _encodelen(len(ts)) + ts

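    # e.g. datetime(2016, 2, 15, 8, 40, 16, 590890) yields the time string
    # '20160215084016.59089Z' (trailing zeros stripped); dec_datetime rejects
    # the unstripped '...590890Z' form (see the GeneralizedTime case in
    # test_invalids below).
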
    def dec_datetime(self, data, pos, end):
        ts = data[pos:end]
        if '.' in ts:
            fstr = '%Y%m%d%H%M%S.%fZ'
            if ts.endswith('0Z'):
                raise ValueError('invalid trailing zeros')
        else:
            fstr = '%Y%m%d%H%M%SZ'
        return datetime.datetime.strptime(ts, fstr), end

    def loads(self, data, pos=0, end=None, consume=False):
        if end is None:
            end = len(data)
        r, e = self._loads(data, pos, end)
        if consume and e != end:
            raise ValueError('entire string not consumed')
        return r

def deeptypecmp(obj, o):
    #print 'dtc:', `obj`, `o`
    if type(obj) != type(o):
        return False
    if type(obj) in (str, unicode):
        return True
    if type(obj) in (list, set):
        for i, j in zip(obj, o):
            if not deeptypecmp(i, j):
                return False
    if type(obj) in (dict,):
        itms = obj.items()
        itms.sort()
        nitms = o.items()
        nitms.sort()
        for (k, v), (nk, nv) in zip(itms, nitms):
            if not deeptypecmp(k, nk):
                return False
            if not deeptypecmp(v, nv):
                return False
    return True

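# deeptypecmp exists because == alone is too weak for these tests: in
# Python 2, u'sdlkfj' == 'sdlkfj' is True, so assertEqual would not notice a
# round trip turning unicode into str (cf. Test_deeptypecmp.test_false).
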
class Test_deeptypecmp(unittest.TestCase):
    def test_true(self):
        for i in ((1,1), ('sldkfj', 'sldkfj')
                ):
            self.assertTrue(deeptypecmp(*i))

    def test_false(self):
        for i in (([[]], [{}]), ([1], ['str']), ([], set()),
                ({1: 2, 5: u'sdlkfj'}, {1: 2, 5: 'sdlkfj'}),
                ({1: 2, u'sdlkfj': 5}, {1: 2, 'sdlkfj': 5}),
                ):
            self.assertFalse(deeptypecmp(*i))

def genfailures(obj):
    s = dumps(obj)
    for i in xrange(len(s)):
        for j in (chr(x) for x in xrange(256)):
            ts = s[:i] + j + s[i + 1:]
            if ts == s:
                continue
            try:
                o = loads(ts, consume=True)
                if o != obj or not deeptypecmp(o, obj):
                    raise ValueError
            except (ValueError, KeyError, IndexError, TypeError):
                pass
            except Exception:
                raise
            else:
                raise AssertionError('uncaught modification: %s, byte %d, orig: %02x' % (ts.encode('hex'), i, ord(s[i])))

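# genfailures rewrites each byte of an encoding to every other possible
# value and requires that each mutation is either rejected with an error or
# decodes to something differing in value or type from the original object.
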
_coder = ASN1Coder()
dumps = _coder.dumps
loads = _coder.loads

class TestCode(unittest.TestCase):
    def test_primv(self):
        self.assertEqual(dumps(-257), '0202feff'.decode('hex'))
        self.assertEqual(dumps(-256), '0202ff00'.decode('hex'))
        self.assertEqual(dumps(-255), '0202ff01'.decode('hex'))
        self.assertEqual(dumps(-1), '0201ff'.decode('hex'))
        self.assertEqual(dumps(5), '020105'.decode('hex'))
        self.assertEqual(dumps(128), '02020080'.decode('hex'))
        self.assertEqual(dumps(256), '02020100'.decode('hex'))
        self.assertEqual(dumps(False), '010100'.decode('hex'))
        self.assertEqual(dumps(True), '0101ff'.decode('hex'))
        self.assertEqual(dumps(None), '0500'.decode('hex'))
        self.assertEqual(dumps(.15625), '090380fb05'.decode('hex'))

    def test_fuzzing(self):
        genfailures(float(1))
        genfailures([ 1, 2, 'sdlkfj' ])
        genfailures({ 1: 2, 5: 'sdlkfj' })
        genfailures(set([ 1, 2, 'sdlkfj' ]))

    def test_consume(self):
        b = dumps(5)
        self.assertRaises(ValueError, loads, b + '398473', consume=True)
        # XXX - still possible that an internal data member
        # doesn't consume all
        # XXX - test that sets are ordered properly
        # XXX - test that dicts are ordered properly..

    def test_nan(self):
        s = dumps(float('nan'))
        v = loads(s)
        self.assertTrue(math.isnan(v))

    def test_invalids(self):
        # Add tests for base 8, 16 floats among others
        for v in [ '010101',
                '0903040001',  # float scaling factor
                '0903840001',  # float scaling factor
                '0903100001',  # float base
                '0903900001',  # float base
                '0903000001',  # float decimal encoding
                '0903830001',  # float exponent encoding
                '3007020101020102040673646c6b666a',  # list short string still valid
                'c007020101020102020105040673646c6b666a',  # dict short value still valid
                '181632303136303231353038343031362e3539303839305a',  # datetime w/ trailing zero
                ]:
            self.assertRaises(ValueError, loads, v.decode('hex'))

    def test_cryptoutilasn1(self):
        '''Test DER sequences generated by Crypto.Util.asn1.'''

        for s, v in [ ('\x02\x03$\x8a\xf9', 2394873),
                ('\x05\x00', None),
                ('\x02\x03\x00\x96I', 38473),
                ('\x04\x81\xc8' + '\x00' * 200, '\x00' * 200),
                ]:
            self.assertEqual(loads(s), v)

    def test_longstrings(self):
        for i in (203, 65484):
            s = os.urandom(i)
            v = dumps(s)
            self.assertEqual(loads(v), s)

    def test_invaliddate(self):
        pass
        # XXX - add test to reject datetime w/ tzinfo, or that it handles it
        # properly

    def test_dumps(self):
        for i in [ None,
                True, False,
                -1, 0, 1, 255, 256, -255, -256, 23498732498723, -2398729387234, (1<<2383) + 23984734, (-1<<1983) + 23984723984,
                float(0), float('-0'), float('inf'), float('-inf'), float(1.0), float(-1.0),
                float('353.3487'), float('2387.23873e492'), float('2387.348732e-392'),
                float('.15625'),
                'weoifjwef',
                u'\U0001f4a9',
                [], [ 1,2,3 ],
                {}, { 5: 10, 'adfkj': 34 },
                set(), set((1,2,3)), set((1,'sjlfdkj', None, float('inf'))),
                datetime.datetime.utcnow(), datetime.datetime.utcnow().replace(microsecond=0),
                datetime.datetime.utcnow().replace(microsecond=1000),
                ]:
            s = dumps(i)
            o = loads(s)
            self.assertEqual(i, o)

        tobj = { 1: 'dflkj', 5: u'sdlkfj', 'float': 1, 'largeint': 1<<342, 'list': [ 1, 2, u'str', 'str' ] }
        out = dumps(tobj)
        self.assertEqual(tobj, loads(out))

    def test_loads(self):
        self.assertRaises(ValueError, loads, '\x00\x02\x00')