A pure-Python ASN.1 library. Supports dicts and sets.
You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.
 
 

353 lines
7.7 KiB

#!/usr/bin/env python
# A pure-Python ASN.1 encoder/decoder with a calling interface in the
# spirit of pickle. It will automatically do the correct thing if possible.
#
# This uses a restricted profile of ASN.1.
#
# All lengths must be specified; that is, End-of-contents octets
# MUST NOT be used. The shortest form of length encoding MUST be used;
# a longer length encoding MUST be rejected.
  10. import pdb
  11. import math
  12. import sys
  13. import unittest
  14. def _numtostr(n):
  15. hs = '%x' % n
  16. if len(hs) & 1 == 1:
  17. hs = '0' + hs
  18. bs = hs.decode('hex')
  19. return bs
  20. def _encodelen(l):
  21. '''Takes l as a length value, and returns a byte string that
  22. represents l per ASN.1 rules.'''
  23. if l < 128:
  24. return chr(l)
  25. bs = _numtostr(l)
  26. return chr(len(bs) | 0x80) + bs
  27. def _decodelen(d, pos=0):
  28. '''Returns the length, and number of bytes required.'''
  29. odp = ord(d[pos])
  30. if odp < 128:
  31. return ord(d[pos]), 1
  32. else:
  33. l = odp & 0x7f
  34. return int(d[pos + 1:pos + 1 + l].encode('hex'), 16), l + 1
  35. class Test_codelen(unittest.TestCase):
  36. _testdata = [
  37. (2, '\x02'),
  38. (127, '\x7f'),
  39. (128, '\x81\x80'),
  40. (255, '\x81\xff'),
  41. (256, '\x82\x01\x00'),
  42. (65536-1, '\x82\xff\xff'),
  43. (65536, '\x83\x01\x00\x00'),
  44. ]
  45. def test_el(self):
  46. for i, j in self._testdata:
  47. self.assertEqual(_encodelen(i), j)
  48. self.assertEqual(_decodelen(j), (i, len(j)))
  49. def _splitfloat(f):
  50. m, e = math.frexp(f)
  51. # XXX - less than ideal
  52. while m != math.trunc(m):
  53. m *= 2
  54. e -= 1
  55. return m, e
  56. class TestSplitFloat(unittest.TestCase):
  57. def test_sf(self):
  58. for a, b in [ (0x2421, -32), (0x5382f, 238),
  59. (0x1fa8c3b094adf1, 971) ]:
  60. self.assertEqual(_splitfloat(a * 2**b), (a, b))
  61. class ASN1Object:
  62. def __init__(self, tag):
  63. self._tag = tag
  64. class ASN1Coder(object):
  65. def __init__(self):
  66. pass
  67. _typemap = {
  68. bool: 'bool',
  69. dict: 'dict',
  70. float: 'float',
  71. int: 'int',
  72. list: 'list',
  73. long: 'int',
  74. set: 'set',
  75. str: 'bytes',
  76. type(None): 'null',
  77. unicode: 'unicode',
  78. }
  79. _tagmap = {
  80. '\x01': 'bool',
  81. '\x02': 'int',
  82. '\x04': 'bytes',
  83. '\x05': 'null',
  84. '\x09': 'float',
  85. '\x0c': 'unicode',
  86. '\x30': 'list',
  87. '\x31': 'set',
  88. '\xc0': 'dict',
  89. #'xxx': 'datetime',
  90. }
  91. _typetag = dict((v, k) for k, v in _tagmap.iteritems())
  92. @staticmethod
  93. def enc_int(obj):
  94. l = obj.bit_length()
  95. l += 1 # space for sign bit
  96. l = (l + 7) // 8
  97. if obj < 0:
  98. obj += 1 << (l * 8) # twos-complement conversion
  99. v = _numtostr(obj)
  100. if len(v) != l:
  101. # XXX - is this a problem for signed values?
  102. v = '\x00' + v # add sign octect
  103. return _encodelen(l) + v
  104. @staticmethod
  105. def dec_int(d, pos, end):
  106. if pos == end:
  107. return 0, end
  108. v = int(d[pos:end].encode('hex'), 16)
  109. av = 1 << ((end - pos) * 8 - 1) # sign bit
  110. if v > av:
  111. v -= av * 2 # twos-complement conversion
  112. return v, end
  113. @staticmethod
  114. def enc_bool(obj):
  115. return '\x01' + chr(obj)
  116. def dec_bool(self, d, pos, end):
  117. return bool(self.dec_int(d, pos, end)[0]), end
  118. @staticmethod
  119. def enc_null(obj):
  120. return '\x00'
  121. @staticmethod
  122. def dec_null(d, pos, end):
  123. return None, end
  124. def enc_dict(self, obj):
  125. #it = list(obj.iteritems())
  126. #it.sort()
  127. r = ''.join(self.dumps(k) + self.dumps(v) for k, v in obj.iteritems())
  128. return _encodelen(len(r)) + r
  129. def dec_dict(self, d, pos, end):
  130. r = {}
  131. while pos < end:
  132. k, kend = self._loads(d, pos, end)
  133. v, vend = self._loads(d, kend, end)
  134. r[k] = v
  135. pos = vend
  136. return r, vend
  137. def enc_set(self, obj):
  138. r = ''.join(self.dumps(x) for x in obj)
  139. return _encodelen(len(r)) + r
  140. def dec_set(self, d, pos, end):
  141. r, end = self.dec_list(d, pos, end)
  142. return set(r), end
  143. def enc_list(self, obj):
  144. r = ''.join(self.dumps(x) for x in obj)
  145. return _encodelen(len(r)) + r
  146. def dec_list(self, d, pos, end):
  147. r = []
  148. while pos < end:
  149. v, vend = self._loads(d, pos, end)
  150. r.append(v)
  151. pos = vend
  152. return r, vend
  153. @staticmethod
  154. def enc_bytes(obj):
  155. return _encodelen(len(obj)) + obj
  156. @staticmethod
  157. def dec_bytes(d, pos, end):
  158. return d[pos:end], end
  159. @staticmethod
  160. def enc_unicode(obj):
  161. encobj = obj.encode('utf-8')
  162. return _encodelen(len(encobj)) + encobj
  163. def dec_unicode(self, d, pos, end):
  164. return d[pos:end].decode('utf-8'), end
  165. @staticmethod
  166. def enc_float(obj):
  167. s = math.copysign(1, obj)
  168. if math.isnan(obj):
  169. return _encodelen(1) + chr(0b01000010)
  170. elif math.isinf(obj):
  171. if s == 1:
  172. return _encodelen(1) + chr(0b01000000)
  173. else:
  174. return _encodelen(1) + chr(0b01000001)
  175. elif obj == 0:
  176. if s == 1:
  177. return _encodelen(0)
  178. else:
  179. return _encodelen(1) + chr(0b01000011)
  180. m, e = _splitfloat(obj)
  181. # Binary encoding
  182. val = 0x80
  183. if m < 0:
  184. val |= 0x40
  185. m = -m
  186. # Base 2
  187. # XXX - negative e
  188. el = (e.bit_length() + 7) // 8
  189. if el > 3:
  190. v = 0x3
  191. encexp = _encodelen(el) + _numtostr(e)
  192. else:
  193. v = el - 1
  194. encexp = _numtostr(e)
  195. return chr(val) + encexp + _numtostr(m)
  196. @staticmethod
  197. def dec_float(d, pos, end):
  198. if pos == end:
  199. return float(0), end
  200. v = ord(d[pos])
  201. if v == 0b01000000:
  202. return float('inf'), end
  203. elif v == 0b01000001:
  204. return float('-inf'), end
  205. elif v == 0b01000010:
  206. return float('nan'), end
  207. elif v == 0b01000011:
  208. return float('-0'), end
  209. #elif v & 0b11000000 == 0b01000000:
  210. # raise ValueError('invalid encoding')
  211. raise NotImplementedError
  212. def dumps(self, obj):
  213. tf = self._typemap[type(obj)]
  214. fun = getattr(self, 'enc_%s' % tf)
  215. return self._typetag[tf] + fun(obj)
  216. def _loads(self, data, pos, end):
  217. tag = data[pos]
  218. l, b = _decodelen(data, pos + 1)
  219. if len(data) < pos + 1 + b + l:
  220. raise ValueError('string not long enough')
  221. # XXX - enforce that len(data) == end?
  222. end = pos + 1 + b + l
  223. t = self._tagmap[tag]
  224. fun = getattr(self, 'dec_%s' % t)
  225. return fun(data, pos + 1 + b, end)
  226. def loads(self, data, pos=0, end=None, consume=False):
  227. if end is None:
  228. end = len(data)
  229. r, e = self._loads(data, pos, end)
  230. if consume and e != end:
  231. raise ValueError('entire string not consumed')
  232. return r
  233. _coder = ASN1Coder()
  234. dumps = _coder.dumps
  235. loads = _coder.loads
  236. class TestCode(unittest.TestCase):
  237. def test_primv(self):
  238. self.assertEqual(dumps(-257), '0202feff'.decode('hex'))
  239. self.assertEqual(dumps(-256), '0202ff00'.decode('hex'))
  240. self.assertEqual(dumps(-255), '0202ff01'.decode('hex'))
  241. self.assertEqual(dumps(-1), '0201ff'.decode('hex'))
  242. self.assertEqual(dumps(5), '020105'.decode('hex'))
  243. self.assertEqual(dumps(128), '02020080'.decode('hex'))
  244. self.assertEqual(dumps(256), '02020100'.decode('hex'))
  245. self.assertEqual(dumps(False), '010100'.decode('hex'))
  246. self.assertEqual(dumps(True), '010101'.decode('hex'))
  247. self.assertEqual(dumps(None), '0500'.decode('hex'))
  248. def test_consume(self):
  249. b = dumps(5)
  250. self.assertRaises(ValueError, loads, b + '398473', consume=True)
  251. # XXX - still possible that an internal data member
  252. # doesn't consume all
  253. def test_nan(self):
  254. s = dumps(float('nan'))
  255. v = loads(s)
  256. self.assertTrue(math.isnan(v))
  257. def test_cryptoutilasn1(self):
  258. '''Test DER sequences generated by Crypto.Util.asn1.'''
  259. for s, v in [ ('\x02\x03$\x8a\xf9', 2394873),
  260. ('\x05\x00', None),
  261. ('\x02\x03\x00\x96I', 38473),
  262. ]:
  263. self.assertEqual(loads(s), v)
  264. def test_dumps(self):
  265. for i in [ None,
  266. True, False,
  267. -1, 0, 1, 255, 256, -255, -256, 23498732498723, -2398729387234, (1<<2383) + 23984734, (-1<<1983) + 23984723984,
  268. float(0), float('-0'), float('inf'), float('-inf'),
  269. 'weoifjwef',
  270. u'\U0001f4a9',
  271. set((1,2,3)), set((1,'sjlfdkj', None, float('inf'))),
  272. ]:
  273. s = dumps(i)
  274. o = loads(s)
  275. self.assertEqual(i, o)
  276. tobj = { 1: 'dflkj', 5: u'sdlkfj', 'float': 1, 'largeint': 1<<342, 'list': [ 1, 2, u'str', 'str' ] }
  277. out = dumps(tobj)
  278. self.assertEqual(tobj, loads(out))
  279. def test_loads(self):
  280. self.assertRaises(ValueError, loads, '\x00\x02\x00')
  281. if __name__ == '__main__':
  282. pass