forked from auracaster/bumble_mirror
Handle ISO data path race condition
This commit is contained in:
@@ -33,7 +33,6 @@ from bumble.hci import (
|
||||
HCI_COMMAND_DISALLOWED_ERROR,
|
||||
HCI_COMMAND_PACKET,
|
||||
HCI_COMMAND_STATUS_PENDING,
|
||||
HCI_CONNECTION_TIMEOUT_ERROR,
|
||||
HCI_CONTROLLER_BUSY_ERROR,
|
||||
HCI_EVENT_PACKET,
|
||||
HCI_INVALID_HCI_COMMAND_PARAMETERS_ERROR,
|
||||
@@ -88,6 +87,7 @@ class CisLink:
|
||||
cis_id: int
|
||||
cig_id: int
|
||||
acl_connection: Optional[Connection] = None
|
||||
data_paths: set[int] = dataclasses.field(default_factory=set)
|
||||
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
@@ -381,6 +381,11 @@ class Controller:
|
||||
return connection
|
||||
return None
|
||||
|
||||
def find_iso_link_by_handle(self, handle: int) -> Optional[CisLink]:
|
||||
return self.central_cis_links.get(handle) or self.peripheral_cis_links.get(
|
||||
handle
|
||||
)
|
||||
|
||||
def on_link_central_connected(self, central_address):
|
||||
'''
|
||||
Called when an incoming connection occurs from a central on the link
|
||||
@@ -1853,16 +1858,51 @@ class Controller:
|
||||
)
|
||||
)
|
||||
|
||||
def on_hci_le_setup_iso_data_path_command(self, command):
|
||||
def on_hci_le_setup_iso_data_path_command(
|
||||
self, command: hci.HCI_LE_Setup_ISO_Data_Path_Command
|
||||
) -> bytes:
|
||||
'''
|
||||
See Bluetooth spec Vol 4, Part E - 7.8.109 LE Setup ISO Data Path Command
|
||||
'''
|
||||
if not (iso_link := self.find_iso_link_by_handle(command.connection_handle)):
|
||||
return struct.pack(
|
||||
'<BH',
|
||||
HCI_UNKNOWN_CONNECTION_IDENTIFIER_ERROR,
|
||||
command.connection_handle,
|
||||
)
|
||||
if command.data_path_direction in iso_link.data_paths:
|
||||
return struct.pack(
|
||||
'<BH',
|
||||
HCI_COMMAND_DISALLOWED_ERROR,
|
||||
command.connection_handle,
|
||||
)
|
||||
iso_link.data_paths.add(command.data_path_direction)
|
||||
return struct.pack('<BH', HCI_SUCCESS, command.connection_handle)
|
||||
|
||||
def on_hci_le_remove_iso_data_path_command(self, command):
|
||||
def on_hci_le_remove_iso_data_path_command(
|
||||
self, command: hci.HCI_LE_Remove_ISO_Data_Path_Command
|
||||
) -> bytes:
|
||||
'''
|
||||
See Bluetooth spec Vol 4, Part E - 7.8.110 LE Remove ISO Data Path Command
|
||||
'''
|
||||
if not (iso_link := self.find_iso_link_by_handle(command.connection_handle)):
|
||||
return struct.pack(
|
||||
'<BH',
|
||||
HCI_UNKNOWN_CONNECTION_IDENTIFIER_ERROR,
|
||||
command.connection_handle,
|
||||
)
|
||||
data_paths: set[int] = set(
|
||||
direction
|
||||
for direction in hci.HCI_LE_Setup_ISO_Data_Path_Command.Direction
|
||||
if (1 << direction) & command.data_path_direction
|
||||
)
|
||||
if not data_paths.issubset(iso_link.data_paths):
|
||||
return struct.pack(
|
||||
'<BH',
|
||||
HCI_COMMAND_DISALLOWED_ERROR,
|
||||
command.connection_handle,
|
||||
)
|
||||
iso_link.data_paths.difference_update(data_paths)
|
||||
return struct.pack('<BH', HCI_SUCCESS, command.connection_handle)
|
||||
|
||||
def on_hci_le_set_host_feature_command(
|
||||
|
||||
@@ -1453,6 +1453,8 @@ class _IsoLink:
|
||||
handle: int
|
||||
device: Device
|
||||
sink: Callable[[hci.HCI_IsoDataPacket], Any] | None = None
|
||||
data_paths: set[_IsoLink.Direction]
|
||||
_data_path_lock: asyncio.Lock
|
||||
|
||||
class Direction(IntEnum):
|
||||
HOST_TO_CONTROLLER = (
|
||||
@@ -1462,6 +1464,10 @@ class _IsoLink:
|
||||
hci.HCI_LE_Setup_ISO_Data_Path_Command.Direction.CONTROLLER_TO_HOST
|
||||
)
|
||||
|
||||
def __init__(self) -> None:
|
||||
self._data_path_lock = asyncio.Lock()
|
||||
self.data_paths = set()
|
||||
|
||||
async def setup_data_path(
|
||||
self,
|
||||
direction: _IsoLink.Direction,
|
||||
@@ -1482,37 +1488,45 @@ class _IsoLink:
|
||||
Raises:
|
||||
HCI_Error: When command complete status is not HCI_SUCCESS.
|
||||
"""
|
||||
await self.device.send_command(
|
||||
hci.HCI_LE_Setup_ISO_Data_Path_Command(
|
||||
connection_handle=self.handle,
|
||||
data_path_direction=direction,
|
||||
data_path_id=data_path_id,
|
||||
codec_id=codec_id or hci.CodingFormat(hci.CodecID.TRANSPARENT),
|
||||
controller_delay=controller_delay,
|
||||
codec_configuration=codec_configuration,
|
||||
),
|
||||
check_result=True,
|
||||
)
|
||||
async with self._data_path_lock:
|
||||
if direction in self.data_paths:
|
||||
return
|
||||
await self.device.send_command(
|
||||
hci.HCI_LE_Setup_ISO_Data_Path_Command(
|
||||
connection_handle=self.handle,
|
||||
data_path_direction=direction,
|
||||
data_path_id=data_path_id,
|
||||
codec_id=codec_id or hci.CodingFormat(hci.CodecID.TRANSPARENT),
|
||||
controller_delay=controller_delay,
|
||||
codec_configuration=codec_configuration,
|
||||
),
|
||||
check_result=True,
|
||||
)
|
||||
self.data_paths.add(direction)
|
||||
|
||||
async def remove_data_path(self, directions: Iterable[_IsoLink.Direction]) -> int:
|
||||
async def remove_data_path(self, directions: Iterable[_IsoLink.Direction]) -> None:
|
||||
"""Remove a data path with controller on given direction.
|
||||
|
||||
Args:
|
||||
direction: Direction of data path.
|
||||
|
||||
Returns:
|
||||
Command status.
|
||||
Raises:
|
||||
HCI_Error: When command complete status is not HCI_SUCCESS.
|
||||
"""
|
||||
response = await self.device.send_command(
|
||||
hci.HCI_LE_Remove_ISO_Data_Path_Command(
|
||||
connection_handle=self.handle,
|
||||
data_path_direction=sum(
|
||||
1 << direction for direction in set(directions)
|
||||
async with self._data_path_lock:
|
||||
directions_to_remove = set(directions).intersection(self.data_paths)
|
||||
if not directions_to_remove:
|
||||
return
|
||||
await self.device.send_command(
|
||||
hci.HCI_LE_Remove_ISO_Data_Path_Command(
|
||||
connection_handle=self.handle,
|
||||
data_path_direction=sum(
|
||||
1 << direction for direction in directions_to_remove
|
||||
),
|
||||
),
|
||||
),
|
||||
check_result=False,
|
||||
)
|
||||
return response.return_parameters.status
|
||||
check_result=True,
|
||||
)
|
||||
self.data_paths.difference_update(directions_to_remove)
|
||||
|
||||
def write(self, sdu: bytes) -> None:
|
||||
"""Write an ISO SDU."""
|
||||
@@ -1622,7 +1636,8 @@ class CisLink(utils.EventEmitter, _IsoLink):
|
||||
EVENT_ESTABLISHMENT_FAILURE: ClassVar[str] = "establishment_failure"
|
||||
|
||||
def __post_init__(self) -> None:
|
||||
super().__init__()
|
||||
utils.EventEmitter.__init__(self)
|
||||
_IsoLink.__init__(self)
|
||||
|
||||
async def disconnect(
|
||||
self, reason: int = hci.HCI_REMOTE_USER_TERMINATED_CONNECTION_ERROR
|
||||
@@ -1638,6 +1653,7 @@ class BisLink(_IsoLink):
|
||||
sink: Callable[[hci.HCI_IsoDataPacket], Any] | None = None
|
||||
|
||||
def __post_init__(self) -> None:
|
||||
super().__init__()
|
||||
self.device = self.big.device
|
||||
|
||||
|
||||
|
||||
@@ -310,12 +310,12 @@ async def test_pacs():
|
||||
@pytest.mark.asyncio
|
||||
async def test_ascs():
|
||||
devices = TwoDevices()
|
||||
devices[0].add_service(
|
||||
AudioStreamControlService(device=devices[0], sink_ase_id=[1, 2])
|
||||
devices[1].add_service(
|
||||
AudioStreamControlService(device=devices[1], sink_ase_id=[1, 2])
|
||||
)
|
||||
|
||||
await devices.setup_connection()
|
||||
peer = device.Peer(devices.connections[1])
|
||||
peer = device.Peer(devices.connections[0])
|
||||
ascs_client = await peer.discover_service_and_create_proxy(
|
||||
AudioStreamControlServiceProxy
|
||||
)
|
||||
@@ -369,7 +369,7 @@ async def test_ascs():
|
||||
await ascs_client.ase_control_point.write_value(
|
||||
ASE_Config_QOS(
|
||||
ase_id=[1, 2],
|
||||
cig_id=[1, 2],
|
||||
cig_id=[1, 1],
|
||||
cis_id=[3, 4],
|
||||
sdu_interval=[5, 6],
|
||||
framing=[0, 1],
|
||||
@@ -402,25 +402,19 @@ async def test_ascs():
|
||||
)
|
||||
|
||||
# CIS establishment
|
||||
devices[0].emit(
|
||||
'cis_establishment',
|
||||
device.CisLink(
|
||||
device=devices[0],
|
||||
acl_connection=devices.connections[0],
|
||||
handle=5,
|
||||
cis_id=3,
|
||||
cis_handles = await devices[0].setup_cig(
|
||||
device.CigParameters(
|
||||
cig_id=1,
|
||||
),
|
||||
cis_parameters=[
|
||||
device.CigParameters.CisParameters(cis_id=3),
|
||||
device.CigParameters.CisParameters(cis_id=4),
|
||||
],
|
||||
sdu_interval_c_to_p=0,
|
||||
sdu_interval_p_to_c=0,
|
||||
)
|
||||
)
|
||||
devices[0].emit(
|
||||
'cis_establishment',
|
||||
device.CisLink(
|
||||
device=devices[0],
|
||||
acl_connection=devices.connections[0],
|
||||
handle=6,
|
||||
cis_id=4,
|
||||
cig_id=2,
|
||||
),
|
||||
await devices[0].create_cis(
|
||||
[(cis_handle, devices.connections[0]) for cis_handle in cis_handles]
|
||||
)
|
||||
assert (await notifications[1].get())[:2] == bytes(
|
||||
[1, AseStateMachine.State.STREAMING]
|
||||
|
||||
Reference in New Issue
Block a user