You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.

354 lines
12 KiB

  1. #
  2. # This file is part of usb_protocol.
  3. #
  4. """ Convenience emitters for simple, standard descriptors. """
  5. import unittest
  6. from contextlib import contextmanager
  7. from .. import emitter_for_format
  8. from ..descriptor import ComplexDescriptorEmitter
  9. from ...types.descriptors.standard import \
  10. DeviceDescriptor, StringDescriptor, EndpointDescriptor, DeviceQualifierDescriptor, \
  11. ConfigurationDescriptor, InterfaceDescriptor, StandardDescriptorNumbers
  12. # Create our basic emitters...
  13. DeviceDescriptorEmitter = emitter_for_format(DeviceDescriptor)
  14. StringDescriptorEmitter = emitter_for_format(StringDescriptor)
  15. EndpointDescriptorEmitter = emitter_for_format(EndpointDescriptor)
  16. DeviceQualifierDescriptor = emitter_for_format(DeviceQualifierDescriptor)
  17. # ... convenience functions ...
  18. def get_string_descriptor(string):
  19. """ Generates a string descriptor for the relevant string. """
  20. emitter = StringDescriptorEmitter()
  21. emitter.bString = string
  22. return emitter.emit()
  23. # ... and complex emitters.
  24. class InterfaceDescriptorEmitter(ComplexDescriptorEmitter):
  25. """ Emitter that creates an InterfaceDescriptor. """
  26. DESCRIPTOR_FORMAT = InterfaceDescriptor
  27. @contextmanager
  28. def EndpointDescriptor(self):
  29. """ Context manager that allows addition of a subordinate endpoint descriptor.
  30. It can be used with a `with` statement; and yields an EndpointDesriptorEmitter
  31. that can be populated:
  32. with interface.EndpointDescriptor() as d:
  33. d.bEndpointAddress = 0x01
  34. d.bmAttributes = 0x80
  35. d.wMaxPacketSize = 64
  36. d.bInterval = 0
  37. This adds the relevant descriptor, automatically.
  38. """
  39. descriptor = EndpointDescriptorEmitter()
  40. yield descriptor
  41. self.add_subordinate_descriptor(descriptor)
  42. def _pre_emit(self):
  43. # Count our endpoints, and update our internal count.
  44. self.bNumEndpoints = self._type_counts[StandardDescriptorNumbers.ENDPOINT]
  45. # Ensure that our interface string is an index, if we can.
  46. if self._collection and hasattr(self, 'iInterface'):
  47. self.iInterface = self._collection.ensure_string_field_is_index(self.iInterface)
  48. class ConfigurationDescriptorEmitter(ComplexDescriptorEmitter):
  49. """ Emitter that creates a configuration descriptor. """
  50. DESCRIPTOR_FORMAT = ConfigurationDescriptor
  51. @contextmanager
  52. def InterfaceDescriptor(self):
  53. """ Context manager that allows addition of a subordinate interface descriptor.
  54. It can be used with a `with` statement; and yields an InterfaceDescriptorEmitter
  55. that can be populated:
  56. with interface.InterfaceDescriptor() as d:
  57. d.bInterfaceNumber = 0x01
  58. [snip]
  59. This adds the relevant descriptor, automatically. Note that populating derived
  60. fields such as bNumEndpoints aren't necessary; they'll be populated automatically.
  61. """
  62. descriptor = InterfaceDescriptorEmitter(collection=self._collection)
  63. yield descriptor
  64. self.add_subordinate_descriptor(descriptor)
  65. def _pre_emit(self):
  66. # Count our interfaces.
  67. self.bNumInterfaces = self._type_counts[StandardDescriptorNumbers.INTERFACE]
  68. # Figure out our total length.
  69. subordinate_length = sum(len(sub) for sub in self._subordinates)
  70. self.wTotalLength = subordinate_length + self.DESCRIPTOR_FORMAT.sizeof()
  71. # Ensure that our configuration string is an index, if we can.
  72. if self._collection and hasattr(self, 'iConfiguration'):
  73. self.iConfiguration = self._collection.ensure_string_field_is_index(self.iConfiguration)
  74. class DeviceDescriptorCollection:
  75. """ Object that builds a full collection of descriptors related to a given USB device. """
  76. def __init__(self):
  77. # Create our internal descriptor tracker.
  78. # Keys are a tuple of (type, index).
  79. self._descriptors = {}
  80. # Track string descriptors as they're created.
  81. self._next_string_index = 1
  82. self._index_for_string = {}
  83. def ensure_string_field_is_index(self, field_value):
  84. """ Processes the given field value; if it's not an string index, converts it to one.
  85. Non-index-fields are converted to indices using `get_index_for_string`, which automatically
  86. adds the relevant fields to our string descriptor collection.
  87. """
  88. if isinstance(field_value, str):
  89. return self.get_index_for_string(field_value)
  90. else:
  91. return field_value
  92. def get_index_for_string(self, string):
  93. """ Returns an string descriptor index for the given string.
  94. If a string descriptor already exists for the given string, its index is
  95. returned. Otherwise, a string descriptor is created.
  96. """
  97. # If we already have a descriptor for this string, return it.
  98. if string in self._index_for_string:
  99. return self._index_for_string[string]
  100. # Otherwise, create one:
  101. # Allocate an index...
  102. index = self._next_string_index
  103. self._index_for_string[string] = index
  104. self._next_string_index += 1
  105. # ... store our string descriptor with it ...
  106. identifier = StandardDescriptorNumbers.STRING, index
  107. self._descriptors[identifier] = get_string_descriptor(string)
  108. # ... and return our index.
  109. return index
  110. def add_descriptor(self, descriptor, index=0):
  111. """ Adds a descriptor to our collection.
  112. Parameters:
  113. descriptor -- The descriptor to be added.
  114. index -- The index of the relevant descriptor. Defaults to 0.
  115. """
  116. # If this is an emitter rather than a descriptor itself, convert it.
  117. if hasattr(descriptor, 'emit'):
  118. descriptor = descriptor.emit()
  119. # Figure out the identifier (type + index) for this descriptor...
  120. descriptor_type = descriptor[1]
  121. identifier = descriptor_type, index
  122. # ... and store it.
  123. self._descriptors[identifier] = descriptor
  124. @contextmanager
  125. def DeviceDescriptor(self):
  126. """ Context manager that allows addition of a device descriptor.
  127. It can be used with a `with` statement; and yields an DeviceDescriptorEmitter
  128. that can be populated:
  129. with collection.DeviceDescriptor() as d:
  130. d.idVendor = 0xabcd
  131. d.idProduct = 0x1234
  132. [snip]
  133. This adds the relevant descriptor, automatically.
  134. """
  135. descriptor = DeviceDescriptorEmitter()
  136. yield descriptor
  137. # If we have any string fields, ensure that they're indices before continuing.
  138. for field in ('iManufacturer', 'iProduct', 'iSerialNumber'):
  139. if hasattr(descriptor, field):
  140. value = getattr(descriptor, field)
  141. index = self.ensure_string_field_is_index(value)
  142. setattr(descriptor, field, index)
  143. self.add_descriptor(descriptor)
  144. @contextmanager
  145. def ConfigurationDescriptor(self):
  146. """ Context manager that allows addition of a configuration descriptor.
  147. It can be used with a `with` statement; and yields an ConfigurationDescriptorEmitter
  148. that can be populated:
  149. with collection.ConfigurationDescriptor() as d:
  150. d.bConfigurationValue = 1
  151. [snip]
  152. This adds the relevant descriptor, automatically. Note that populating derived
  153. fields such as bNumInterfaces aren't necessary; they'll be populated automatically.
  154. """
  155. descriptor = ConfigurationDescriptorEmitter()
  156. yield descriptor
  157. self.add_descriptor(descriptor)
  158. def get_descriptor_bytes(self, type_number: int, index: int = 0):
  159. """ Returns the raw, binary descriptor for a given descriptor type/index.
  160. Parmeters:
  161. type_number -- The descriptor type number.
  162. index -- The index of the relevant descriptor, if relevant.
  163. """
  164. return self._descriptors[(type_number, index)]
  165. def __iter__(self):
  166. """ Allow iterating over each of our descriptors; yields (index, value, descriptor). """
  167. return ((number, index, desc) for ((number, index), desc) in self._descriptors.items())
  168. class EmitterTests(unittest.TestCase):
  169. def test_string_emitter(self):
  170. emitter = StringDescriptorEmitter()
  171. emitter.bString = "Hello"
  172. self.assertEqual(emitter.emit(), b"\x0C\x03H\0e\0l\0l\0o\0")
  173. def test_string_emitter_function(self):
  174. self.assertEqual(get_string_descriptor("Hello"), b"\x0C\x03H\0e\0l\0l\0o\0")
  175. def test_configuration_emitter(self):
  176. descriptor = bytes([
  177. # config descriptor
  178. 12, # length
  179. 2, # type
  180. 25, 00, # total length
  181. 1, # num interfaces
  182. 1, # configuration number
  183. 0, # config string
  184. 0x80, # attributes
  185. 250, # max power
  186. # interface descriptor
  187. 9, # length
  188. 4, # type
  189. 0, # number
  190. 0, # alternate
  191. 1, # num endpoints
  192. 0xff, # class
  193. 0xff, # subclass
  194. 0xff, # protocol
  195. 0, # string
  196. # endpoint descriptor
  197. 7, # length
  198. 5, # type
  199. 0x01, # address
  200. 2, # attributes
  201. 64, 0, # max packet size
  202. 255, # interval
  203. ])
  204. # Create a trivial configuration descriptor...
  205. emitter = ConfigurationDescriptorEmitter()
  206. with emitter.InterfaceDescriptor() as interface:
  207. interface.bInterfaceNumber = 0
  208. with interface.EndpointDescriptor() as endpoint:
  209. endpoint.bEndpointAddress = 1
  210. # ... and validate that it maches our reference descriptor.
  211. binary = emitter.emit()
  212. self.assertEqual(len(binary), len(descriptor))
  213. def test_descriptor_collection(self):
  214. collection = DeviceDescriptorCollection()
  215. with collection.DeviceDescriptor() as d:
  216. d.idVendor = 0xdead
  217. d.idProduct = 0xbeef
  218. d.bNumConfigurations = 1
  219. d.iManufacturer = "Test Company"
  220. d.iProduct = "Test Product"
  221. with collection.ConfigurationDescriptor() as c:
  222. c.bConfigurationValue = 1
  223. with c.InterfaceDescriptor() as i:
  224. i.bInterfaceNumber = 1
  225. with i.EndpointDescriptor() as e:
  226. e.bEndpointAddress = 0x81
  227. with i.EndpointDescriptor() as e:
  228. e.bEndpointAddress = 0x01
  229. results = list(collection)
  230. # We should wind up with four descriptor entries, as our endpoint/interface descriptors are
  231. # included in our configuration descriptor.
  232. self.assertEqual(len(results), 4)
  233. # Manufacturer / product string.
  234. self.assertIn((3, 1, b'\x1a\x03T\x00e\x00s\x00t\x00 \x00C\x00o\x00m\x00p\x00a\x00n\x00y\x00'), results)
  235. self.assertIn((3, 2, b'\x1a\x03T\x00e\x00s\x00t\x00 \x00P\x00r\x00o\x00d\x00u\x00c\x00t\x00'), results)
  236. # Device descriptor.
  237. self.assertIn((1, 0, b'\x0f\x01\x00\x02\x00\x00\x00@\xad\xde\xef\xbe\x00\x00\x01\x02\x00\x01'), results)
  238. # Configuration descriptor, with subordinates.
  239. self.assertIn((2, 0, b'\r\x02 \x00\x01\x01\x00\x80\xfa\t\x04\x01\x00\x02\xff\xff\xff\x00\x07\x05\x81\x02@\x00\xff\x07\x05\x01\x02@\x00\xff'), results)
  240. if __name__ == "__main__":
  241. unittest.main()