You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.

571 lines
19 KiB

  1. #
  2. # This file is part of usb_protocol.
  3. #
  4. """ Convenience emitters for simple, standard descriptors. """
  5. import unittest
  6. from contextlib import contextmanager
  7. from .. import emitter_for_format
  8. from ..descriptor import ComplexDescriptorEmitter
  9. from ...types import LanguageIDs
  10. from ...types.descriptors.standard import *
  11. # Create our basic emitters...
  12. DeviceDescriptorEmitter = emitter_for_format(DeviceDescriptor)
  13. StringDescriptorEmitter = emitter_for_format(StringDescriptor)
  14. StringLanguageDescriptorEmitter = emitter_for_format(StringLanguageDescriptor)
  15. DeviceQualifierDescriptor = emitter_for_format(DeviceQualifierDescriptor)
  16. # ... our basic superspeed emitters ...
  17. USB2ExtensionDescriptorEmitter = emitter_for_format(USB2ExtensionDescriptor)
  18. SuperSpeedUSBDeviceCapabilityDescriptorEmitter = emitter_for_format(SuperSpeedUSBDeviceCapabilityDescriptor)
  19. SuperSpeedEndpointCompanionDescriptorEmitter = emitter_for_format(SuperSpeedEndpointCompanionDescriptor)
  20. # ... convenience functions ...
  21. def get_string_descriptor(string):
  22. """ Generates a string descriptor for the relevant string. """
  23. emitter = StringDescriptorEmitter()
  24. emitter.bString = string
  25. return emitter.emit()
  26. # ... and complex emitters.
  27. class EndpointDescriptorEmitter(ComplexDescriptorEmitter):
  28. """ Emitter that creates an InterfaceDescriptor. """
  29. DESCRIPTOR_FORMAT = EndpointDescriptor
  30. @contextmanager
  31. def SuperSpeedCompanion(self):
  32. """ Context manager that allows addition of a SuperSpeed Companion to this endpoint descriptor.
  33. It can be used with a `with` statement; and yields an SuperSpeedEndpointCompanionDescriptorEmitter
  34. that can be populated:
  35. with endpoint.SuperSpeedEndpointCompanion() as d:
  36. d.bMaxBurst = 1
  37. This adds the relevant descriptor, automatically.
  38. """
  39. descriptor = SuperSpeedEndpointCompanionDescriptorEmitter()
  40. yield descriptor
  41. self.add_subordinate_descriptor(descriptor)
  42. class InterfaceDescriptorEmitter(ComplexDescriptorEmitter):
  43. """ Emitter that creates an InterfaceDescriptor. """
  44. DESCRIPTOR_FORMAT = InterfaceDescriptor
  45. @contextmanager
  46. def EndpointDescriptor(self, *, add_default_superspeed=False):
  47. """ Context manager that allows addition of a subordinate endpoint descriptor.
  48. It can be used with a `with` statement; and yields an EndpointDesriptorEmitter
  49. that can be populated:
  50. with interface.EndpointDescriptor() as d:
  51. d.bEndpointAddress = 0x01
  52. d.bmAttributes = 0x80
  53. d.wMaxPacketSize = 64
  54. d.bInterval = 0
  55. This adds the relevant descriptor, automatically.
  56. """
  57. descriptor = EndpointDescriptorEmitter()
  58. yield descriptor
  59. # If we're adding a default SuperSpeed extension, do so.
  60. if add_default_superspeed:
  61. with descriptor.SuperSpeedCompanion():
  62. pass
  63. self.add_subordinate_descriptor(descriptor)
  64. def _pre_emit(self):
  65. # Count our endpoints, and update our internal count.
  66. self.bNumEndpoints = self._type_counts[StandardDescriptorNumbers.ENDPOINT]
  67. # Ensure that our interface string is an index, if we can.
  68. if self._collection and hasattr(self, 'iInterface'):
  69. self.iInterface = self._collection.ensure_string_field_is_index(self.iInterface)
  70. class ConfigurationDescriptorEmitter(ComplexDescriptorEmitter):
  71. """ Emitter that creates a configuration descriptor. """
  72. DESCRIPTOR_FORMAT = ConfigurationDescriptor
  73. @contextmanager
  74. def InterfaceDescriptor(self):
  75. """ Context manager that allows addition of a subordinate interface descriptor.
  76. It can be used with a `with` statement; and yields an InterfaceDescriptorEmitter
  77. that can be populated:
  78. with interface.InterfaceDescriptor() as d:
  79. d.bInterfaceNumber = 0x01
  80. [snip]
  81. This adds the relevant descriptor, automatically. Note that populating derived
  82. fields such as bNumEndpoints aren't necessary; they'll be populated automatically.
  83. """
  84. descriptor = InterfaceDescriptorEmitter(collection=self._collection)
  85. yield descriptor
  86. self.add_subordinate_descriptor(descriptor)
  87. def _pre_emit(self):
  88. # Count our interfaces.
  89. self.bNumInterfaces = self._type_counts[StandardDescriptorNumbers.INTERFACE]
  90. # Figure out our total length.
  91. subordinate_length = sum(len(sub) for sub in self._subordinates)
  92. self.wTotalLength = subordinate_length + self.DESCRIPTOR_FORMAT.sizeof()
  93. # Ensure that our configuration string is an index, if we can.
  94. if self._collection and hasattr(self, 'iConfiguration'):
  95. self.iConfiguration = self._collection.ensure_string_field_is_index(self.iConfiguration)
  96. class DeviceDescriptorCollection:
  97. """ Object that builds a full collection of descriptors related to a given USB device. """
  98. # Most systems seem happiest with en_US (ugh), so default to that.
  99. DEFAULT_SUPPORTED_LANGUAGES = [LanguageIDs.ENGLISH_US]
  100. def __init__(self, automatic_language_descriptor=True):
  101. """
  102. Parameters:
  103. automatic_language_descriptor -- If set or not provided, a language descriptor will automatically
  104. be added if none exists.
  105. """
  106. self._automatic_language_descriptor = automatic_language_descriptor
  107. # Create our internal descriptor tracker.
  108. # Keys are a tuple of (type, index).
  109. self._descriptors = {}
  110. # Track string descriptors as they're created.
  111. self._next_string_index = 1
  112. self._index_for_string = {}
  113. def ensure_string_field_is_index(self, field_value):
  114. """ Processes the given field value; if it's not an string index, converts it to one.
  115. Non-index-fields are converted to indices using `get_index_for_string`, which automatically
  116. adds the relevant fields to our string descriptor collection.
  117. """
  118. if isinstance(field_value, str):
  119. return self.get_index_for_string(field_value)
  120. else:
  121. return field_value
  122. def get_index_for_string(self, string):
  123. """ Returns an string descriptor index for the given string.
  124. If a string descriptor already exists for the given string, its index is
  125. returned. Otherwise, a string descriptor is created.
  126. """
  127. # If we already have a descriptor for this string, return it.
  128. if string in self._index_for_string:
  129. return self._index_for_string[string]
  130. # Otherwise, create one:
  131. # Allocate an index...
  132. index = self._next_string_index
  133. self._index_for_string[string] = index
  134. self._next_string_index += 1
  135. # ... store our string descriptor with it ...
  136. identifier = StandardDescriptorNumbers.STRING, index
  137. self._descriptors[identifier] = get_string_descriptor(string)
  138. # ... and return our index.
  139. return index
  140. def add_descriptor(self, descriptor, index=0):
  141. """ Adds a descriptor to our collection.
  142. Parameters:
  143. descriptor -- The descriptor to be added.
  144. index -- The index of the relevant descriptor. Defaults to 0.
  145. """
  146. # If this is an emitter rather than a descriptor itself, convert it.
  147. if hasattr(descriptor, 'emit'):
  148. descriptor = descriptor.emit()
  149. # Figure out the identifier (type + index) for this descriptor...
  150. descriptor_type = descriptor[1]
  151. identifier = descriptor_type, index
  152. # ... and store it.
  153. self._descriptors[identifier] = descriptor
  154. def add_language_descriptor(self, supported_languages=None):
  155. """ Adds a language descriptor to the list of device descriptors.
  156. Parameters:
  157. supported_languages -- A list of languages supported by the device.
  158. """
  159. if supported_languages is None:
  160. supported_languages = self.DEFAULT_SUPPORTED_LANGUAGES
  161. descriptor = StringLanguageDescriptorEmitter()
  162. descriptor.wLANGID = supported_languages
  163. self.add_descriptor(descriptor)
  164. @contextmanager
  165. def DeviceDescriptor(self):
  166. """ Context manager that allows addition of a device descriptor.
  167. It can be used with a `with` statement; and yields an DeviceDescriptorEmitter
  168. that can be populated:
  169. with collection.DeviceDescriptor() as d:
  170. d.idVendor = 0xabcd
  171. d.idProduct = 0x1234
  172. [snip]
  173. This adds the relevant descriptor, automatically.
  174. """
  175. descriptor = DeviceDescriptorEmitter()
  176. yield descriptor
  177. # If we have any string fields, ensure that they're indices before continuing.
  178. for field in ('iManufacturer', 'iProduct', 'iSerialNumber'):
  179. if hasattr(descriptor, field):
  180. value = getattr(descriptor, field)
  181. index = self.ensure_string_field_is_index(value)
  182. setattr(descriptor, field, index)
  183. self.add_descriptor(descriptor)
  184. @contextmanager
  185. def ConfigurationDescriptor(self):
  186. """ Context manager that allows addition of a configuration descriptor.
  187. It can be used with a `with` statement; and yields an ConfigurationDescriptorEmitter
  188. that can be populated:
  189. with collection.ConfigurationDescriptor() as d:
  190. d.bConfigurationValue = 1
  191. [snip]
  192. This adds the relevant descriptor, automatically. Note that populating derived
  193. fields such as bNumInterfaces aren't necessary; they'll be populated automatically.
  194. """
  195. descriptor = ConfigurationDescriptorEmitter(collection=self)
  196. yield descriptor
  197. self.add_descriptor(descriptor)
  198. def _ensure_has_language_descriptor(self):
  199. """ ensures that we have a language descriptor; adding one if necessary."""
  200. # if we're not automatically adding a language descriptor, we shouldn't do anything,
  201. # and we'll just ignore this.
  202. if not self._automatic_language_descriptor:
  203. return
  204. # if we don't have a language descriptor, add our default one.
  205. if (StandardDescriptorNumbers.STRING, 0) not in self._descriptors:
  206. self.add_language_descriptor()
  207. def get_descriptor_bytes(self, type_number: int, index: int = 0):
  208. """ Returns the raw, binary descriptor for a given descriptor type/index.
  209. Parmeters:
  210. type_number -- The descriptor type number.
  211. index -- The index of the relevant descriptor, if relevant.
  212. """
  213. # If this is a request for a language descriptor, return one.
  214. if (type_number, index) == (StandardDescriptorNumbers.STRING, 0):
  215. self._ensure_has_language_descriptor()
  216. return self._descriptors[(type_number, index)]
  217. def __iter__(self):
  218. """ Allow iterating over each of our descriptors; yields (index, value, descriptor). """
  219. self._ensure_has_language_descriptor()
  220. return ((number, index, desc) for ((number, index), desc) in self._descriptors.items())
  221. class BinaryObjectStoreDescriptorEmitter(ComplexDescriptorEmitter):
  222. """ Emitter that creates a BinaryObjectStore descriptor. """
  223. DESCRIPTOR_FORMAT = BinaryObjectStoreDescriptor
  224. @contextmanager
  225. def USB2Extension(self):
  226. """ Context manager that allows addition of a USB 2.0 Extension to this Binary Object Store.
  227. It can be used with a `with` statement; and yields an USB2ExtensionDescriptorEmitter
  228. that can be populated:
  229. with bos.USB2Extension() as e:
  230. e.bmAttributes = 1
  231. This adds the relevant descriptor, automatically.
  232. """
  233. descriptor = USB2ExtensionDescriptorEmitter()
  234. yield descriptor
  235. self.add_subordinate_descriptor(descriptor)
  236. @contextmanager
  237. def SuperSpeedUSBDeviceCapability(self):
  238. """ Context manager that allows addition of a SS Device Capability to this Binary Object Store.
  239. It can be used with a `with` statement; and yields an SuperSpeedUSBDeviceCapabilityDescriptorEmitter
  240. that can be populated:
  241. with bos.SuperSpeedUSBDeviceCapability() as e:
  242. e.wSpeedSupported = 0b1110
  243. e.bFunctionalitySupport = 1
  244. This adds the relevant descriptor, automatically.
  245. """
  246. descriptor = SuperSpeedUSBDeviceCapabilityDescriptorEmitter()
  247. yield descriptor
  248. self.add_subordinate_descriptor(descriptor)
  249. def _pre_emit(self):
  250. # Figure out the total length of our descriptor, including subordinates.
  251. subordinate_length = sum(len(sub) for sub in self._subordinates)
  252. self.wTotalLength = subordinate_length + self.DESCRIPTOR_FORMAT.sizeof()
  253. # Count our subordinate descriptors, and update our internal count.
  254. self.bNumDeviceCaps = len(self._subordinates)
  255. class SuperSpeedDeviceDescriptorCollection(DeviceDescriptorCollection):
  256. """ Object that builds a full collection of descriptors related to a given USB3 device. """
  257. def __init__(self, automatic_descriptors=True):
  258. """
  259. Parameters:
  260. automatic_descriptors -- If set or not provided, certian required descriptors will be
  261. be added if none exists.
  262. """
  263. self._automatic_descriptors = automatic_descriptors
  264. super().__init__(automatic_language_descriptor=automatic_descriptors)
  265. @contextmanager
  266. def BOSDescriptor(self):
  267. """ Context manager that allows addition of a Binary Object Store descriptor.
  268. It can be used with a `with` statement; and yields an BinaryObjectStoreDescriptorEmitter
  269. that can be populated:
  270. with collection.BOSDescriptor() as d:
  271. [snip]
  272. This adds the relevant descriptor, automatically. Note that populating derived
  273. fields such as bNumDeviceCaps aren't necessary; they'll be populated automatically.
  274. """
  275. descriptor = BinaryObjectStoreDescriptorEmitter()
  276. yield descriptor
  277. self.add_descriptor(descriptor)
  278. def add_default_bos_descriptor(self):
  279. """ Adds a default, empty BOS descriptor. """
  280. # Create an empty BOS descriptor...
  281. descriptor = BinaryObjectStoreDescriptorEmitter()
  282. # ... populate our default required descriptors...
  283. descriptor.add_subordinate_descriptor(USB2ExtensionDescriptorEmitter())
  284. descriptor.add_subordinate_descriptor(SuperSpeedUSBDeviceCapabilityDescriptorEmitter())
  285. # ... and add it to our overall BOS descriptor.
  286. self.add_descriptor(descriptor)
  287. def _ensure_has_bos_descriptor(self):
  288. """ Ensures that we have a BOS descriptor; adding one if necessary."""
  289. # If we're not automatically adding a language descriptor, we shouldn't do anything,
  290. # and we'll just ignore this.
  291. if not self._automatic_descriptors:
  292. return
  293. # If we don't have a language descriptor, add our default one.
  294. if (StandardDescriptorNumbers.BOS, 0) not in self._descriptors:
  295. self.add_default_bos_descriptor()
  296. def __iter__(self):
  297. """ Allow iterating over each of our descriptors; yields (index, value, descriptor). """
  298. self._ensure_has_bos_descriptor()
  299. return super().__iter__()
  300. class EmitterTests(unittest.TestCase):
  301. def test_string_emitter(self):
  302. emitter = StringDescriptorEmitter()
  303. emitter.bString = "Hello"
  304. self.assertEqual(emitter.emit(), b"\x0C\x03H\0e\0l\0l\0o\0")
  305. def test_string_emitter_function(self):
  306. self.assertEqual(get_string_descriptor("Hello"), b"\x0C\x03H\0e\0l\0l\0o\0")
  307. def test_configuration_emitter(self):
  308. descriptor = bytes([
  309. # config descriptor
  310. 12, # length
  311. 2, # type
  312. 25, 00, # total length
  313. 1, # num interfaces
  314. 1, # configuration number
  315. 0, # config string
  316. 0x80, # attributes
  317. 250, # max power
  318. # interface descriptor
  319. 9, # length
  320. 4, # type
  321. 0, # number
  322. 0, # alternate
  323. 1, # num endpoints
  324. 0xff, # class
  325. 0xff, # subclass
  326. 0xff, # protocol
  327. 0, # string
  328. # endpoint descriptor
  329. 7, # length
  330. 5, # type
  331. 0x01, # address
  332. 2, # attributes
  333. 64, 0, # max packet size
  334. 255, # interval
  335. ])
  336. # Create a trivial configuration descriptor...
  337. emitter = ConfigurationDescriptorEmitter()
  338. with emitter.InterfaceDescriptor() as interface:
  339. interface.bInterfaceNumber = 0
  340. with interface.EndpointDescriptor() as endpoint:
  341. endpoint.bEndpointAddress = 1
  342. # ... and validate that it maches our reference descriptor.
  343. binary = emitter.emit()
  344. self.assertEqual(len(binary), len(descriptor))
  345. def test_descriptor_collection(self):
  346. collection = DeviceDescriptorCollection()
  347. with collection.DeviceDescriptor() as d:
  348. d.idVendor = 0xdead
  349. d.idProduct = 0xbeef
  350. d.bNumConfigurations = 1
  351. d.iManufacturer = "Test Company"
  352. d.iProduct = "Test Product"
  353. with collection.ConfigurationDescriptor() as c:
  354. c.bConfigurationValue = 1
  355. with c.InterfaceDescriptor() as i:
  356. i.bInterfaceNumber = 1
  357. with i.EndpointDescriptor() as e:
  358. e.bEndpointAddress = 0x81
  359. with i.EndpointDescriptor() as e:
  360. e.bEndpointAddress = 0x01
  361. results = list(collection)
  362. # We should wind up with four descriptor entries, as our endpoint/interface descriptors are
  363. # included in our configuration descriptor.
  364. self.assertEqual(len(results), 5)
  365. # Supported languages string.
  366. self.assertIn((3, 0, b'\x04\x03\x09\x04'), results)
  367. # Manufacturer / product string.
  368. self.assertIn((3, 1, b'\x1a\x03T\x00e\x00s\x00t\x00 \x00C\x00o\x00m\x00p\x00a\x00n\x00y\x00'), results)
  369. self.assertIn((3, 2, b'\x1a\x03T\x00e\x00s\x00t\x00 \x00P\x00r\x00o\x00d\x00u\x00c\x00t\x00'), results)
  370. # Device descriptor.
  371. self.assertIn((1, 0, b'\x12\x01\x00\x02\x00\x00\x00@\xad\xde\xef\xbe\x00\x00\x01\x02\x00\x01'), results)
  372. # Configuration descriptor, with subordinates.
  373. self.assertIn((2, 0, b'\t\x02 \x00\x01\x01\x00\x80\xfa\t\x04\x01\x00\x02\xff\xff\xff\x00\x07\x05\x81\x02@\x00\xff\x07\x05\x01\x02@\x00\xff'), results)
  374. def test_empty_descriptor_collection(self):
  375. collection = DeviceDescriptorCollection(automatic_language_descriptor=False)
  376. results = list(collection)
  377. self.assertEqual(len(results), 0)
  378. def test_automatic_language_descriptor(self):
  379. collection = DeviceDescriptorCollection(automatic_language_descriptor=True)
  380. results = list(collection)
  381. self.assertEqual(len(results), 1)
  382. if __name__ == "__main__":
  383. unittest.main()