From 83c5061700ee8e2a965e8e7903f7234c6fb4515f Mon Sep 17 00:00:00 2001 From: Josh Wu Date: Fri, 12 Sep 2025 14:55:45 +0800 Subject: [PATCH] Handle ISO data path race condition --- bumble/controller.py | 46 ++++++++++++++++++++++++++++--- bumble/device.py | 64 +++++++++++++++++++++++++++----------------- tests/bap_test.py | 36 +++++++++++-------------- 3 files changed, 98 insertions(+), 48 deletions(-) diff --git a/bumble/controller.py b/bumble/controller.py index d148495..616ee1b 100644 --- a/bumble/controller.py +++ b/bumble/controller.py @@ -33,7 +33,6 @@ from bumble.hci import ( HCI_COMMAND_DISALLOWED_ERROR, HCI_COMMAND_PACKET, HCI_COMMAND_STATUS_PENDING, - HCI_CONNECTION_TIMEOUT_ERROR, HCI_CONTROLLER_BUSY_ERROR, HCI_EVENT_PACKET, HCI_INVALID_HCI_COMMAND_PARAMETERS_ERROR, @@ -88,6 +87,7 @@ class CisLink: cis_id: int cig_id: int acl_connection: Optional[Connection] = None + data_paths: set[int] = dataclasses.field(default_factory=set) # ----------------------------------------------------------------------------- @@ -381,6 +381,11 @@ class Controller: return connection return None + def find_iso_link_by_handle(self, handle: int) -> Optional[CisLink]: + return self.central_cis_links.get(handle) or self.peripheral_cis_links.get( + handle + ) + def on_link_central_connected(self, central_address): ''' Called when an incoming connection occurs from a central on the link @@ -1853,16 +1858,51 @@ class Controller: ) ) - def on_hci_le_setup_iso_data_path_command(self, command): + def on_hci_le_setup_iso_data_path_command( + self, command: hci.HCI_LE_Setup_ISO_Data_Path_Command + ) -> bytes: ''' See Bluetooth spec Vol 4, Part E - 7.8.109 LE Setup ISO Data Path Command ''' + if not (iso_link := self.find_iso_link_by_handle(command.connection_handle)): + return struct.pack( + ' bytes: ''' See Bluetooth spec Vol 4, Part E - 7.8.110 LE Remove ISO Data Path Command ''' + if not (iso_link := self.find_iso_link_by_handle(command.connection_handle)): + return struct.pack( + ' None: + self._data_path_lock = asyncio.Lock() + self.data_paths = set() + async def setup_data_path( self, direction: _IsoLink.Direction, @@ -1482,37 +1488,45 @@ class _IsoLink: Raises: HCI_Error: When command complete status is not HCI_SUCCESS. """ - await self.device.send_command( - hci.HCI_LE_Setup_ISO_Data_Path_Command( - connection_handle=self.handle, - data_path_direction=direction, - data_path_id=data_path_id, - codec_id=codec_id or hci.CodingFormat(hci.CodecID.TRANSPARENT), - controller_delay=controller_delay, - codec_configuration=codec_configuration, - ), - check_result=True, - ) + async with self._data_path_lock: + if direction in self.data_paths: + return + await self.device.send_command( + hci.HCI_LE_Setup_ISO_Data_Path_Command( + connection_handle=self.handle, + data_path_direction=direction, + data_path_id=data_path_id, + codec_id=codec_id or hci.CodingFormat(hci.CodecID.TRANSPARENT), + controller_delay=controller_delay, + codec_configuration=codec_configuration, + ), + check_result=True, + ) + self.data_paths.add(direction) - async def remove_data_path(self, directions: Iterable[_IsoLink.Direction]) -> int: + async def remove_data_path(self, directions: Iterable[_IsoLink.Direction]) -> None: """Remove a data path with controller on given direction. Args: direction: Direction of data path. - Returns: - Command status. + Raises: + HCI_Error: When command complete status is not HCI_SUCCESS. """ - response = await self.device.send_command( - hci.HCI_LE_Remove_ISO_Data_Path_Command( - connection_handle=self.handle, - data_path_direction=sum( - 1 << direction for direction in set(directions) + async with self._data_path_lock: + directions_to_remove = set(directions).intersection(self.data_paths) + if not directions_to_remove: + return + await self.device.send_command( + hci.HCI_LE_Remove_ISO_Data_Path_Command( + connection_handle=self.handle, + data_path_direction=sum( + 1 << direction for direction in directions_to_remove + ), ), - ), - check_result=False, - ) - return response.return_parameters.status + check_result=True, + ) + self.data_paths.difference_update(directions_to_remove) def write(self, sdu: bytes) -> None: """Write an ISO SDU.""" @@ -1622,7 +1636,8 @@ class CisLink(utils.EventEmitter, _IsoLink): EVENT_ESTABLISHMENT_FAILURE: ClassVar[str] = "establishment_failure" def __post_init__(self) -> None: - super().__init__() + utils.EventEmitter.__init__(self) + _IsoLink.__init__(self) async def disconnect( self, reason: int = hci.HCI_REMOTE_USER_TERMINATED_CONNECTION_ERROR @@ -1638,6 +1653,7 @@ class BisLink(_IsoLink): sink: Callable[[hci.HCI_IsoDataPacket], Any] | None = None def __post_init__(self) -> None: + super().__init__() self.device = self.big.device diff --git a/tests/bap_test.py b/tests/bap_test.py index 2e78102..03bff72 100644 --- a/tests/bap_test.py +++ b/tests/bap_test.py @@ -310,12 +310,12 @@ async def test_pacs(): @pytest.mark.asyncio async def test_ascs(): devices = TwoDevices() - devices[0].add_service( - AudioStreamControlService(device=devices[0], sink_ase_id=[1, 2]) + devices[1].add_service( + AudioStreamControlService(device=devices[1], sink_ase_id=[1, 2]) ) await devices.setup_connection() - peer = device.Peer(devices.connections[1]) + peer = device.Peer(devices.connections[0]) ascs_client = await peer.discover_service_and_create_proxy( AudioStreamControlServiceProxy ) @@ -369,7 +369,7 @@ async def test_ascs(): await ascs_client.ase_control_point.write_value( ASE_Config_QOS( ase_id=[1, 2], - cig_id=[1, 2], + cig_id=[1, 1], cis_id=[3, 4], sdu_interval=[5, 6], framing=[0, 1], @@ -402,25 +402,19 @@ async def test_ascs(): ) # CIS establishment - devices[0].emit( - 'cis_establishment', - device.CisLink( - device=devices[0], - acl_connection=devices.connections[0], - handle=5, - cis_id=3, + cis_handles = await devices[0].setup_cig( + device.CigParameters( cig_id=1, - ), + cis_parameters=[ + device.CigParameters.CisParameters(cis_id=3), + device.CigParameters.CisParameters(cis_id=4), + ], + sdu_interval_c_to_p=0, + sdu_interval_p_to_c=0, + ) ) - devices[0].emit( - 'cis_establishment', - device.CisLink( - device=devices[0], - acl_connection=devices.connections[0], - handle=6, - cis_id=4, - cig_id=2, - ), + await devices[0].create_cis( + [(cis_handle, devices.connections[0]) for cis_handle in cis_handles] ) assert (await notifications[1].get())[:2] == bytes( [1, AseStateMachine.State.STREAMING]