A pure Python ASN.1 library. Supports dicts and sets.
You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.
 
 

393 lines
8.7 KiB

#!/usr/bin/env python
# A pure Python ASN.1 encoder/decoder with a calling interface in the spirit
# of pickle. It will automatically do the correct thing if possible.
#
# This uses a profile of ASN.1.
#
# All lengths must be specified. That is, End-of-contents octets
# MUST NOT be used. The shortest form of length encoding MUST be used.
# A longer length encoding MUST be rejected.
import binascii
import math
import os
import pdb
import sys
import unittest
  15. def _numtostr(n):
  16. hs = '%x' % n
  17. if len(hs) & 1 == 1:
  18. hs = '0' + hs
  19. bs = hs.decode('hex')
  20. return bs
  21. def _encodelen(l):
  22. '''Takes l as a length value, and returns a byte string that
  23. represents l per ASN.1 rules.'''
  24. if l < 128:
  25. return chr(l)
  26. bs = _numtostr(l)
  27. return chr(len(bs) | 0x80) + bs
  28. def _decodelen(d, pos=0):
  29. '''Returns the length, and number of bytes required.'''
  30. odp = ord(d[pos])
  31. if odp < 128:
  32. return ord(d[pos]), 1
  33. else:
  34. l = odp & 0x7f
  35. return int(d[pos + 1:pos + 1 + l].encode('hex'), 16), l + 1
  36. class Test_codelen(unittest.TestCase):
  37. _testdata = [
  38. (2, '\x02'),
  39. (127, '\x7f'),
  40. (128, '\x81\x80'),
  41. (255, '\x81\xff'),
  42. (256, '\x82\x01\x00'),
  43. (65536-1, '\x82\xff\xff'),
  44. (65536, '\x83\x01\x00\x00'),
  45. ]
  46. def test_el(self):
  47. for i, j in self._testdata:
  48. self.assertEqual(_encodelen(i), j)
  49. self.assertEqual(_decodelen(j), (i, len(j)))
  50. def _splitfloat(f):
  51. m, e = math.frexp(f)
  52. # XXX - less than ideal
  53. while m != math.trunc(m):
  54. m *= 2
  55. e -= 1
  56. return m, e
  57. class TestSplitFloat(unittest.TestCase):
  58. def test_sf(self):
  59. for a, b in [ (0x2421, -32), (0x5382f, 238),
  60. (0x1fa8c3b094adf1, 971) ]:
  61. self.assertEqual(_splitfloat(a * 2**b), (a, b))
  62. class ASN1Object:
  63. def __init__(self, tag):
  64. self._tag = tag
  65. class ASN1Coder(object):
  66. def __init__(self):
  67. pass
  68. _typemap = {
  69. bool: 'bool',
  70. dict: 'dict',
  71. float: 'float',
  72. int: 'int',
  73. list: 'list',
  74. long: 'int',
  75. set: 'set',
  76. str: 'bytes',
  77. type(None): 'null',
  78. unicode: 'unicode',
  79. }
  80. _tagmap = {
  81. '\x01': 'bool',
  82. '\x02': 'int',
  83. '\x04': 'bytes',
  84. '\x05': 'null',
  85. '\x09': 'float',
  86. '\x0c': 'unicode',
  87. '\x30': 'list',
  88. '\x31': 'set',
  89. '\xc0': 'dict',
  90. #'xxx': 'datetime',
  91. }
  92. _typetag = dict((v, k) for k, v in _tagmap.iteritems())
  93. @staticmethod
  94. def enc_int(obj):
  95. l = obj.bit_length()
  96. l += 1 # space for sign bit
  97. l = (l + 7) // 8
  98. if obj < 0:
  99. obj += 1 << (l * 8) # twos-complement conversion
  100. v = _numtostr(obj)
  101. if len(v) != l:
  102. # XXX - is this a problem for signed values?
  103. v = '\x00' + v # add sign octect
  104. return _encodelen(l) + v
  105. @staticmethod
  106. def dec_int(d, pos, end):
  107. if pos == end:
  108. return 0, end
  109. v = int(d[pos:end].encode('hex'), 16)
  110. av = 1 << ((end - pos) * 8 - 1) # sign bit
  111. if v > av:
  112. v -= av * 2 # twos-complement conversion
  113. return v, end
  114. @staticmethod
  115. def enc_bool(obj):
  116. return '\x01' + ('\xff' if obj else '\x00')
  117. def dec_bool(self, d, pos, end):
  118. v = self.dec_int(d, pos, end)[0]
  119. if v not in (-1, 0):
  120. raise ValueError('invalid bool value: %d' % v)
  121. return bool(v), end
  122. @staticmethod
  123. def enc_null(obj):
  124. return '\x00'
  125. @staticmethod
  126. def dec_null(d, pos, end):
  127. return None, end
  128. def enc_dict(self, obj):
  129. #it = list(obj.iteritems())
  130. #it.sort()
  131. r = ''.join(self.dumps(k) + self.dumps(v) for k, v in obj.iteritems())
  132. return _encodelen(len(r)) + r
  133. def dec_dict(self, d, pos, end):
  134. r = {}
  135. while pos < end:
  136. k, kend = self._loads(d, pos, end)
  137. v, vend = self._loads(d, kend, end)
  138. r[k] = v
  139. pos = vend
  140. return r, vend
  141. def enc_set(self, obj):
  142. r = ''.join(self.dumps(x) for x in obj)
  143. return _encodelen(len(r)) + r
  144. def dec_set(self, d, pos, end):
  145. r, end = self.dec_list(d, pos, end)
  146. return set(r), end
  147. def enc_list(self, obj):
  148. r = ''.join(self.dumps(x) for x in obj)
  149. return _encodelen(len(r)) + r
  150. def dec_list(self, d, pos, end):
  151. r = []
  152. while pos < end:
  153. v, vend = self._loads(d, pos, end)
  154. r.append(v)
  155. pos = vend
  156. return r, vend
  157. @staticmethod
  158. def enc_bytes(obj):
  159. return _encodelen(len(obj)) + obj
  160. @staticmethod
  161. def dec_bytes(d, pos, end):
  162. return d[pos:end], end
  163. @staticmethod
  164. def enc_unicode(obj):
  165. encobj = obj.encode('utf-8')
  166. return _encodelen(len(encobj)) + encobj
  167. def dec_unicode(self, d, pos, end):
  168. return d[pos:end].decode('utf-8'), end
  169. @staticmethod
  170. def enc_float(obj):
  171. s = math.copysign(1, obj)
  172. if math.isnan(obj):
  173. return _encodelen(1) + chr(0b01000010)
  174. elif math.isinf(obj):
  175. if s == 1:
  176. return _encodelen(1) + chr(0b01000000)
  177. else:
  178. return _encodelen(1) + chr(0b01000001)
  179. elif obj == 0:
  180. if s == 1:
  181. return _encodelen(0)
  182. else:
  183. return _encodelen(1) + chr(0b01000011)
  184. m, e = _splitfloat(obj)
  185. # Binary encoding
  186. val = 0x80
  187. if m < 0:
  188. val |= 0x40
  189. m = -m
  190. # Base 2
  191. el = (e.bit_length() + 7 + 1) // 8 # + 1 is sign bit
  192. if e < 0:
  193. e += 256**el # convert negative to twos-complement
  194. if el > 3:
  195. v = 0x3
  196. encexp = _encodelen(el) + _numtostr(e)
  197. else:
  198. v = el - 1
  199. encexp = _numtostr(e)
  200. r = chr(val) + encexp + _numtostr(m)
  201. return _encodelen(len(r)) + r
  202. def dec_float(self, d, pos, end):
  203. if pos == end:
  204. return float(0), end
  205. v = ord(d[pos])
  206. if v == 0b01000000:
  207. return float('inf'), end
  208. elif v == 0b01000001:
  209. return float('-inf'), end
  210. elif v == 0b01000010:
  211. return float('nan'), end
  212. elif v == 0b01000011:
  213. return float('-0'), end
  214. #elif v & 0b11000000 == 0b01000000:
  215. # raise ValueError('invalid encoding')
  216. if not (v & 0b10000000):
  217. raise NotImplementedError
  218. if v & 3 == 3:
  219. pexp = pos + 2
  220. eexp = pos + 2 + ord(d[pos + 1])
  221. else:
  222. pexp = pos + 1
  223. eexp = pos + 1 + (v & 3) + 1
  224. exp = self.dec_int(d, pexp, eexp)[0]
  225. n = float(int(d[eexp:end].encode('hex'), 16))
  226. r = n * 2 ** exp
  227. if v & 0b1000000:
  228. r = -r
  229. return r, end
  230. def dumps(self, obj):
  231. tf = self._typemap[type(obj)]
  232. fun = getattr(self, 'enc_%s' % tf)
  233. return self._typetag[tf] + fun(obj)
  234. def _loads(self, data, pos, end):
  235. tag = data[pos]
  236. l, b = _decodelen(data, pos + 1)
  237. if len(data) < pos + 1 + b + l:
  238. raise ValueError('string not long enough')
  239. # XXX - enforce that len(data) == end?
  240. end = pos + 1 + b + l
  241. t = self._tagmap[tag]
  242. fun = getattr(self, 'dec_%s' % t)
  243. return fun(data, pos + 1 + b, end)
  244. def loads(self, data, pos=0, end=None, consume=False):
  245. if end is None:
  246. end = len(data)
  247. r, e = self._loads(data, pos, end)
  248. if consume and e != end:
  249. raise ValueError('entire string not consumed')
  250. return r
  251. _coder = ASN1Coder()
  252. dumps = _coder.dumps
  253. loads = _coder.loads
  254. class TestCode(unittest.TestCase):
  255. def test_primv(self):
  256. self.assertEqual(dumps(-257), '0202feff'.decode('hex'))
  257. self.assertEqual(dumps(-256), '0202ff00'.decode('hex'))
  258. self.assertEqual(dumps(-255), '0202ff01'.decode('hex'))
  259. self.assertEqual(dumps(-1), '0201ff'.decode('hex'))
  260. self.assertEqual(dumps(5), '020105'.decode('hex'))
  261. self.assertEqual(dumps(128), '02020080'.decode('hex'))
  262. self.assertEqual(dumps(256), '02020100'.decode('hex'))
  263. self.assertEqual(dumps(False), '010100'.decode('hex'))
  264. self.assertEqual(dumps(True), '0101ff'.decode('hex'))
  265. self.assertEqual(dumps(None), '0500'.decode('hex'))
  266. self.assertEqual(dumps(.15625), '090380fb05'.decode('hex'))
  267. def test_consume(self):
  268. b = dumps(5)
  269. self.assertRaises(ValueError, loads, b + '398473', consume=True)
  270. # XXX - still possible that an internal data member
  271. # doesn't consume all
  272. # XXX - test that sets are ordered properly
  273. # XXX - test that dicts are ordered properly..
  274. def test_nan(self):
  275. s = dumps(float('nan'))
  276. v = loads(s)
  277. self.assertTrue(math.isnan(v))
  278. def test_invalids(self):
  279. # Add tests for base 8, 16 floats among others
  280. for v in [ '010101', ]:
  281. self.assertRaises(ValueError, loads, v.decode('hex'))
  282. def test_cryptoutilasn1(self):
  283. '''Test DER sequences generated by Crypto.Util.asn1.'''
  284. for s, v in [ ('\x02\x03$\x8a\xf9', 2394873),
  285. ('\x05\x00', None),
  286. ('\x02\x03\x00\x96I', 38473),
  287. ('\x04\x81\xc8' + '\x00' * 200, '\x00' * 200),
  288. ]:
  289. self.assertEqual(loads(s), v)
  290. def test_longstrings(self):
  291. for i in (203, 65484):
  292. s = os.urandom(i)
  293. v = dumps(s)
  294. self.assertEqual(loads(v), s)
  295. def test_dumps(self):
  296. for i in [ None,
  297. True, False,
  298. -1, 0, 1, 255, 256, -255, -256, 23498732498723, -2398729387234, (1<<2383) + 23984734, (-1<<1983) + 23984723984,
  299. float(0), float('-0'), float('inf'), float('-inf'), float(1.0), float(-1.0),
  300. float('353.3487'), float('2387.23873e492'), float('2387.348732e-392'),
  301. float('.15625'),
  302. 'weoifjwef',
  303. u'\U0001f4a9',
  304. set((1,2,3)), set((1,'sjlfdkj', None, float('inf'))),
  305. ]:
  306. s = dumps(i)
  307. o = loads(s)
  308. self.assertEqual(i, o)
  309. tobj = { 1: 'dflkj', 5: u'sdlkfj', 'float': 1, 'largeint': 1<<342, 'list': [ 1, 2, u'str', 'str' ] }
  310. out = dumps(tobj)
  311. self.assertEqual(tobj, loads(out))
  312. def test_loads(self):
  313. self.assertRaises(ValueError, loads, '\x00\x02\x00')