forked from auracaster/bumble_mirror
Basic LMP implementation
This commit is contained in:
@@ -23,9 +23,9 @@ import itertools
|
||||
import logging
|
||||
import random
|
||||
import struct
|
||||
from typing import TYPE_CHECKING, Any, Optional, Union
|
||||
from typing import TYPE_CHECKING, Any, Optional, Union, cast
|
||||
|
||||
from bumble import hci
|
||||
from bumble import hci, lmp
|
||||
from bumble.colors import color
|
||||
from bumble.core import PhysicalTransport
|
||||
|
||||
@@ -56,6 +56,14 @@ class CisLink:
|
||||
data_paths: set[int] = dataclasses.field(default_factory=set)
|
||||
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
@dataclasses.dataclass
|
||||
class ScoLink:
|
||||
handle: int
|
||||
link_type: int
|
||||
peer_address: hci.Address
|
||||
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
@dataclasses.dataclass
|
||||
class Connection:
|
||||
@@ -66,6 +74,7 @@ class Connection:
|
||||
link: Any
|
||||
transport: int
|
||||
link_type: int
|
||||
classic_allow_role_switch: bool = False
|
||||
|
||||
def __post_init__(self) -> None:
|
||||
self.assembler = hci.HCI_AclDataPacketAssembler(self.on_acl_pdu)
|
||||
@@ -96,6 +105,8 @@ class Controller:
|
||||
hci.Address, Connection
|
||||
] # Connections where this controller is the peripheral
|
||||
classic_connections: dict[hci.Address, Connection] # Connections in BR/EDR
|
||||
classic_pending_commands: dict[hci.Address, dict[lmp.Opcode, asyncio.Future[int]]]
|
||||
sco_links: dict[hci.Address, ScoLink] # SCO links by address
|
||||
central_cis_links: dict[int, CisLink] # CIS links by handle
|
||||
peripheral_cis_links: dict[int, CisLink] # CIS links by handle
|
||||
|
||||
@@ -151,6 +162,7 @@ class Controller:
|
||||
advertising_data: Optional[bytes] = None
|
||||
advertising_timer_handle: Optional[asyncio.Handle] = None
|
||||
classic_scan_enable: int = 0
|
||||
classic_allow_role_switch: bool = True
|
||||
|
||||
_random_address: hci.Address = hci.Address('00:00:00:00:00:00')
|
||||
|
||||
@@ -167,6 +179,8 @@ class Controller:
|
||||
self.central_connections = {}
|
||||
self.peripheral_connections = {}
|
||||
self.classic_connections = {}
|
||||
self.sco_links = {}
|
||||
self.classic_pending_commands = {}
|
||||
self.central_cis_links = {}
|
||||
self.peripheral_cis_links = {}
|
||||
self.default_phy = {
|
||||
@@ -293,7 +307,7 @@ class Controller:
|
||||
f'{color("CONTROLLER -> HOST", "green")}: {packet}'
|
||||
)
|
||||
if self.host:
|
||||
self.host.on_packet(bytes(packet))
|
||||
asyncio.get_running_loop().call_soon(self.host.on_packet, bytes(packet))
|
||||
|
||||
# This method allows the controller to emulate the same API as a transport source
|
||||
async def wait_for_termination(self) -> None:
|
||||
@@ -303,25 +317,20 @@ class Controller:
|
||||
# Link connections
|
||||
############################################################
|
||||
def allocate_connection_handle(self) -> int:
|
||||
handle = 0
|
||||
max_handle = 0
|
||||
for connection in itertools.chain(
|
||||
self.central_connections.values(),
|
||||
self.peripheral_connections.values(),
|
||||
self.classic_connections.values(),
|
||||
):
|
||||
max_handle = max(max_handle, connection.handle)
|
||||
if connection.handle == handle:
|
||||
# Already used, continue searching after the current max
|
||||
handle = max_handle + 1
|
||||
for cis_handle in itertools.chain(
|
||||
self.central_cis_links.keys(), self.peripheral_cis_links.keys()
|
||||
):
|
||||
max_handle = max(max_handle, cis_handle)
|
||||
if cis_handle == handle:
|
||||
# Already used, continue searching after the current max
|
||||
handle = max_handle + 1
|
||||
return handle
|
||||
current_handles = set(
|
||||
cast(Connection | CisLink | ScoLink, link).handle
|
||||
for link in itertools.chain(
|
||||
self.central_connections.values(),
|
||||
self.peripheral_connections.values(),
|
||||
self.classic_connections.values(),
|
||||
self.sco_links.values(),
|
||||
self.central_cis_links.values(),
|
||||
self.peripheral_cis_links.values(),
|
||||
)
|
||||
)
|
||||
return next(
|
||||
handle for handle in range(0xEFF + 1) if handle not in current_handles
|
||||
)
|
||||
|
||||
def find_le_connection_by_address(
|
||||
self, address: hci.Address
|
||||
@@ -363,6 +372,12 @@ class Controller:
|
||||
return connection
|
||||
return None
|
||||
|
||||
def find_classic_sco_link_by_handle(self, handle: int) -> Optional[ScoLink]:
|
||||
for connection in self.sco_links.values():
|
||||
if connection.handle == handle:
|
||||
return connection
|
||||
return None
|
||||
|
||||
def find_iso_link_by_handle(self, handle: int) -> Optional[CisLink]:
|
||||
return self.central_cis_links.get(handle) or self.peripheral_cis_links.get(
|
||||
handle
|
||||
@@ -669,9 +684,75 @@ class Controller:
|
||||
# Classic link connections
|
||||
############################################################
|
||||
|
||||
def send_lmp_packet(
|
||||
self, receiver_address: hci.Address, packet: lmp.Packet
|
||||
) -> asyncio.Future[int]:
|
||||
loop = asyncio.get_running_loop()
|
||||
assert self.link
|
||||
self.link.send_lmp_packet(self, receiver_address, packet)
|
||||
future = self.classic_pending_commands.setdefault(receiver_address, {})[
|
||||
packet.opcode
|
||||
] = loop.create_future()
|
||||
return future
|
||||
|
||||
def on_lmp_packet(self, sender_address: hci.Address, packet: lmp.Packet):
|
||||
if isinstance(packet, (lmp.LmpAccepted, lmp.LmpAcceptedExt)):
|
||||
if future := self.classic_pending_commands.setdefault(
|
||||
sender_address, {}
|
||||
).get(packet.response_opcode):
|
||||
future.set_result(hci.HCI_SUCCESS)
|
||||
else:
|
||||
logger.error("!!! Unhandled packet: %s", packet)
|
||||
elif isinstance(packet, (lmp.LmpNotAccepted, lmp.LmpNotAcceptedExt)):
|
||||
if future := self.classic_pending_commands.setdefault(
|
||||
sender_address, {}
|
||||
).get(packet.response_opcode):
|
||||
future.set_result(packet.error_code)
|
||||
else:
|
||||
logger.error("!!! Unhandled packet: %s", packet)
|
||||
elif isinstance(packet, (lmp.LmpHostConnectionReq)):
|
||||
self.on_classic_connection_request(
|
||||
sender_address, hci.HCI_Connection_Complete_Event.LinkType.ACL
|
||||
)
|
||||
elif isinstance(packet, (lmp.LmpScoLinkReq)):
|
||||
self.on_classic_connection_request(
|
||||
sender_address, hci.HCI_Connection_Complete_Event.LinkType.SCO
|
||||
)
|
||||
elif isinstance(packet, (lmp.LmpEscoLinkReq)):
|
||||
self.on_classic_connection_request(
|
||||
sender_address, hci.HCI_Connection_Complete_Event.LinkType.ESCO
|
||||
)
|
||||
elif isinstance(packet, (lmp.LmpDetach)):
|
||||
self.on_classic_disconnected(
|
||||
sender_address, hci.HCI_REMOTE_USER_TERMINATED_CONNECTION_ERROR
|
||||
)
|
||||
elif isinstance(packet, (lmp.LmpSwitchReq)):
|
||||
self.on_classic_role_change_request(sender_address)
|
||||
elif isinstance(packet, (lmp.LmpRemoveScoLinkReq, lmp.LmpRemoveEscoLinkReq)):
|
||||
self.on_classic_sco_disconnected(sender_address, packet.error_code)
|
||||
else:
|
||||
logger.error("!!! Unhandled packet: %s", packet)
|
||||
|
||||
def on_classic_connection_request(
|
||||
self, peer_address: hci.Address, link_type: int
|
||||
) -> None:
|
||||
if link_type == hci.HCI_Connection_Complete_Event.LinkType.ACL:
|
||||
self.classic_connections[peer_address] = Connection(
|
||||
controller=self,
|
||||
handle=0,
|
||||
role=hci.Role.PERIPHERAL,
|
||||
peer_address=peer_address,
|
||||
link=self.link,
|
||||
transport=PhysicalTransport.BR_EDR,
|
||||
link_type=link_type,
|
||||
classic_allow_role_switch=self.classic_allow_role_switch,
|
||||
)
|
||||
else:
|
||||
self.sco_links[peer_address] = ScoLink(
|
||||
handle=0,
|
||||
link_type=link_type,
|
||||
peer_address=peer_address,
|
||||
)
|
||||
self.send_hci_packet(
|
||||
hci.HCI_Connection_Request_Event(
|
||||
bd_addr=peer_address,
|
||||
@@ -686,13 +767,13 @@ class Controller:
|
||||
if status == hci.HCI_SUCCESS:
|
||||
# Allocate (or reuse) a connection handle
|
||||
peer_address = peer_address
|
||||
connection = self.classic_connections.get(peer_address)
|
||||
if connection is None:
|
||||
connection_handle = self.allocate_connection_handle()
|
||||
connection_handle = self.allocate_connection_handle()
|
||||
if connection := self.classic_connections.get(peer_address):
|
||||
connection.handle = connection_handle
|
||||
else:
|
||||
connection = Connection(
|
||||
controller=self,
|
||||
handle=connection_handle,
|
||||
# hci.Role doesn't matter in Classic because they are managed by hci.HCI_Role_Change and hci.HCI_Role_Discovery
|
||||
role=hci.Role.CENTRAL,
|
||||
peer_address=peer_address,
|
||||
link=self.link,
|
||||
@@ -703,8 +784,6 @@ class Controller:
|
||||
logger.debug(
|
||||
f'New CLASSIC connection handle: 0x{connection_handle:04X}'
|
||||
)
|
||||
else:
|
||||
connection_handle = connection.handle
|
||||
self.send_hci_packet(
|
||||
hci.HCI_Connection_Complete_Event(
|
||||
status=status,
|
||||
@@ -728,7 +807,7 @@ class Controller:
|
||||
|
||||
def on_classic_disconnected(self, peer_address: hci.Address, reason: int) -> None:
|
||||
# Send a disconnection complete event
|
||||
if connection := self.classic_connections.get(peer_address):
|
||||
if connection := self.classic_connections.pop(peer_address, None):
|
||||
self.send_hci_packet(
|
||||
hci.HCI_Disconnection_Complete_Event(
|
||||
status=hci.HCI_SUCCESS,
|
||||
@@ -736,17 +815,51 @@ class Controller:
|
||||
reason=reason,
|
||||
)
|
||||
)
|
||||
|
||||
# Remove the connection
|
||||
del self.classic_connections[peer_address]
|
||||
else:
|
||||
logger.warning(f'!!! No classic connection found for {peer_address}')
|
||||
|
||||
def on_classic_role_change(self, peer_address: hci.Address, new_role: int) -> None:
|
||||
def on_classic_sco_disconnected(
|
||||
self, peer_address: hci.Address, reason: int
|
||||
) -> None:
|
||||
# Send a disconnection complete event
|
||||
if sco_link := self.sco_links.pop(peer_address, None):
|
||||
self.send_hci_packet(
|
||||
hci.HCI_Disconnection_Complete_Event(
|
||||
status=hci.HCI_SUCCESS,
|
||||
connection_handle=sco_link.handle,
|
||||
reason=reason,
|
||||
)
|
||||
)
|
||||
else:
|
||||
logger.warning(f'!!! No classic connection found for {peer_address}')
|
||||
|
||||
def on_classic_role_change_request(self, peer_address: hci.Address) -> None:
|
||||
assert (connection := self.classic_connections.get(peer_address))
|
||||
if not connection.classic_allow_role_switch:
|
||||
self.send_lmp_packet(
|
||||
peer_address,
|
||||
lmp.LmpNotAccepted(
|
||||
lmp.Opcode.LMP_SWITCH_REQ, hci.HCI_ROLE_CHANGE_NOT_ALLOWED_ERROR
|
||||
),
|
||||
)
|
||||
else:
|
||||
self.send_lmp_packet(
|
||||
peer_address,
|
||||
lmp.LmpAccepted(lmp.Opcode.LMP_SWITCH_REQ),
|
||||
)
|
||||
self.classic_role_change(connection)
|
||||
|
||||
def classic_role_change(self, connection: Connection) -> None:
|
||||
new_role = (
|
||||
hci.Role.CENTRAL
|
||||
if connection.role == hci.Role.PERIPHERAL
|
||||
else hci.Role.PERIPHERAL
|
||||
)
|
||||
connection.role = new_role
|
||||
self.send_hci_packet(
|
||||
hci.HCI_Role_Change_Event(
|
||||
status=hci.HCI_SUCCESS,
|
||||
bd_addr=peer_address,
|
||||
bd_addr=connection.peer_address,
|
||||
new_role=new_role,
|
||||
)
|
||||
)
|
||||
@@ -757,17 +870,12 @@ class Controller:
|
||||
if status == hci.HCI_SUCCESS:
|
||||
# Allocate (or reuse) a connection handle
|
||||
connection_handle = self.allocate_connection_handle()
|
||||
connection = Connection(
|
||||
controller=self,
|
||||
sco_link = ScoLink(
|
||||
handle=connection_handle,
|
||||
# hci.Role doesn't matter in SCO.
|
||||
role=hci.Role.CENTRAL,
|
||||
peer_address=peer_address,
|
||||
link=self.link,
|
||||
transport=PhysicalTransport.BR_EDR,
|
||||
link_type=link_type,
|
||||
peer_address=peer_address,
|
||||
)
|
||||
self.classic_connections[peer_address] = connection
|
||||
self.sco_links[peer_address] = sco_link
|
||||
logger.debug(f'New SCO connection handle: 0x{connection_handle:04X}')
|
||||
else:
|
||||
connection_handle = 0
|
||||
@@ -847,7 +955,16 @@ class Controller:
|
||||
)
|
||||
return None
|
||||
|
||||
self.link.classic_connect(self, command.bd_addr)
|
||||
self.classic_connections[command.bd_addr] = Connection(
|
||||
controller=self,
|
||||
handle=0,
|
||||
role=hci.Role.CENTRAL,
|
||||
peer_address=command.bd_addr,
|
||||
link=self.link,
|
||||
transport=PhysicalTransport.BR_EDR,
|
||||
link_type=hci.HCI_Connection_Complete_Event.LinkType.ACL,
|
||||
classic_allow_role_switch=bool(command.allow_role_switch),
|
||||
)
|
||||
|
||||
# Say that the connection is pending
|
||||
self.send_hci_packet(
|
||||
@@ -857,6 +974,12 @@ class Controller:
|
||||
command_opcode=command.op_code,
|
||||
)
|
||||
)
|
||||
future = self.send_lmp_packet(command.bd_addr, lmp.LmpHostConnectionReq())
|
||||
|
||||
def on_response(future: asyncio.Future[int]):
|
||||
self.on_classic_connection_complete(command.bd_addr, future.result())
|
||||
|
||||
future.add_done_callback(on_response)
|
||||
return None
|
||||
|
||||
def on_hci_disconnect_command(
|
||||
@@ -894,14 +1017,37 @@ class Controller:
|
||||
del self.peripheral_connections[connection.peer_address]
|
||||
elif connection := self.find_classic_connection_by_handle(handle):
|
||||
if self.link:
|
||||
self.link.classic_disconnect(
|
||||
self,
|
||||
self.send_lmp_packet(
|
||||
connection.peer_address,
|
||||
hci.HCI_REMOTE_USER_TERMINATED_CONNECTION_ERROR,
|
||||
lmp.LmpDetach(command.reason),
|
||||
)
|
||||
self.on_classic_disconnected(connection.peer_address, command.reason)
|
||||
else:
|
||||
# Remove the connection
|
||||
del self.classic_connections[connection.peer_address]
|
||||
elif sco_link := self.find_classic_sco_link_by_handle(handle):
|
||||
if self.link:
|
||||
if (
|
||||
sco_link.link_type
|
||||
== hci.HCI_Connection_Complete_Event.LinkType.ESCO
|
||||
):
|
||||
self.send_lmp_packet(
|
||||
sco_link.peer_address,
|
||||
lmp.LmpRemoveScoLinkReq(
|
||||
sco_handle=0, error_code=command.reason
|
||||
),
|
||||
)
|
||||
else:
|
||||
self.send_lmp_packet(
|
||||
sco_link.peer_address,
|
||||
lmp.LmpRemoveEscoLinkReq(
|
||||
esco_handle=0, error_code=command.reason
|
||||
),
|
||||
)
|
||||
self.on_classic_sco_disconnected(sco_link.peer_address, command.reason)
|
||||
else:
|
||||
# Remove the connection
|
||||
del self.sco_links[sco_link.peer_address]
|
||||
elif cis_link := (
|
||||
self.central_cis_links.get(handle) or self.peripheral_cis_links.get(handle)
|
||||
):
|
||||
@@ -925,6 +1071,16 @@ class Controller:
|
||||
|
||||
if self.link is None:
|
||||
return None
|
||||
|
||||
if not (connection := self.classic_connections.get(command.bd_addr)):
|
||||
self.send_hci_packet(
|
||||
hci.HCI_Command_Status_Event(
|
||||
status=hci.HCI_UNKNOWN_CONNECTION_IDENTIFIER_ERROR,
|
||||
num_hci_command_packets=1,
|
||||
command_opcode=command.op_code,
|
||||
)
|
||||
)
|
||||
return None
|
||||
self.send_hci_packet(
|
||||
hci.HCI_Command_Status_Event(
|
||||
status=hci.HCI_SUCCESS,
|
||||
@@ -932,7 +1088,36 @@ class Controller:
|
||||
command_opcode=command.op_code,
|
||||
)
|
||||
)
|
||||
self.link.classic_accept_connection(self, command.bd_addr, command.role)
|
||||
|
||||
if command.role == hci.Role.CENTRAL:
|
||||
# Perform role switching before accept.
|
||||
future = self.send_lmp_packet(command.bd_addr, lmp.LmpSwitchReq())
|
||||
|
||||
def on_response(future: asyncio.Future[int]):
|
||||
if (status := future.result()) == hci.HCI_SUCCESS:
|
||||
self.classic_role_change(connection)
|
||||
# Continue connection setup.
|
||||
self.send_lmp_packet(
|
||||
command.bd_addr,
|
||||
lmp.LmpAccepted(lmp.Opcode.LMP_HOST_CONNECTION_REQ),
|
||||
)
|
||||
else:
|
||||
# Abort connection setup.
|
||||
self.send_lmp_packet(
|
||||
command.bd_addr,
|
||||
lmp.LmpNotAccepted(lmp.Opcode.LMP_HOST_CONNECTION_REQ, status),
|
||||
)
|
||||
self.on_classic_connection_complete(command.bd_addr, status)
|
||||
|
||||
future.add_done_callback(on_response)
|
||||
|
||||
else:
|
||||
# Simply accept connection.
|
||||
self.send_lmp_packet(
|
||||
command.bd_addr,
|
||||
lmp.LmpAccepted(lmp.Opcode.LMP_HOST_CONNECTION_REQ),
|
||||
)
|
||||
self.on_classic_connection_complete(command.bd_addr, hci.HCI_SUCCESS)
|
||||
return None
|
||||
|
||||
def on_hci_enhanced_setup_synchronous_connection_command(
|
||||
@@ -966,11 +1151,32 @@ class Controller:
|
||||
command_opcode=command.op_code,
|
||||
)
|
||||
)
|
||||
self.link.classic_sco_connect(
|
||||
self,
|
||||
future = self.send_lmp_packet(
|
||||
connection.peer_address,
|
||||
hci.HCI_Connection_Complete_Event.LinkType.ESCO,
|
||||
lmp.LmpEscoLinkReq(
|
||||
esco_handle=0,
|
||||
esco_lt_addr=0,
|
||||
timing_control_flags=0,
|
||||
d_esco=0,
|
||||
t_esco=0,
|
||||
w_esco=0,
|
||||
esco_packet_type_c_to_p=0,
|
||||
esco_packet_type_p_to_c=0,
|
||||
packet_length_c_to_p=0,
|
||||
packet_length_p_to_c=0,
|
||||
air_mode=0,
|
||||
negotiation_state=0,
|
||||
),
|
||||
)
|
||||
|
||||
def on_response(future: asyncio.Future[int]):
|
||||
self.on_classic_sco_connection_complete(
|
||||
connection.peer_address,
|
||||
future.result(),
|
||||
hci.HCI_Connection_Complete_Event.LinkType.ESCO,
|
||||
)
|
||||
|
||||
future.add_done_callback(on_response)
|
||||
return None
|
||||
|
||||
def on_hci_enhanced_accept_synchronous_connection_request_command(
|
||||
@@ -1000,9 +1206,13 @@ class Controller:
|
||||
command_opcode=command.op_code,
|
||||
)
|
||||
)
|
||||
self.link.classic_accept_sco_connection(
|
||||
self,
|
||||
self.send_lmp_packet(
|
||||
connection.peer_address,
|
||||
lmp.LmpAcceptedExt(lmp.Opcode.LMP_ESCO_LINK_REQ),
|
||||
)
|
||||
self.on_classic_sco_connection_complete(
|
||||
connection.peer_address,
|
||||
hci.HCI_SUCCESS,
|
||||
hci.HCI_Connection_Complete_Event.LinkType.ESCO,
|
||||
)
|
||||
return None
|
||||
@@ -1083,14 +1293,52 @@ class Controller:
|
||||
|
||||
if self.link is None:
|
||||
return None
|
||||
self.send_hci_packet(
|
||||
hci.HCI_Command_Status_Event(
|
||||
status=hci.HCI_SUCCESS,
|
||||
num_hci_command_packets=1,
|
||||
command_opcode=command.op_code,
|
||||
|
||||
if connection := self.classic_connections.get(command.bd_addr):
|
||||
current_role = connection.role
|
||||
self.send_hci_packet(
|
||||
hci.HCI_Command_Status_Event(
|
||||
status=hci.HCI_SUCCESS,
|
||||
num_hci_command_packets=1,
|
||||
command_opcode=command.op_code,
|
||||
)
|
||||
)
|
||||
)
|
||||
self.link.classic_switch_role(self, command.bd_addr, command.role)
|
||||
else:
|
||||
# Connection doesn't exist, reject.
|
||||
self.send_hci_packet(
|
||||
hci.HCI_Command_Status_Event(
|
||||
status=hci.HCI_COMMAND_DISALLOWED_ERROR,
|
||||
num_hci_command_packets=1,
|
||||
command_opcode=command.op_code,
|
||||
)
|
||||
)
|
||||
return None
|
||||
|
||||
# If role doesn't change, only send event to local host.
|
||||
if current_role == command.role:
|
||||
self.send_hci_packet(
|
||||
hci.HCI_Role_Change_Event(
|
||||
status=hci.HCI_SUCCESS,
|
||||
bd_addr=command.bd_addr,
|
||||
new_role=current_role,
|
||||
)
|
||||
)
|
||||
else:
|
||||
future = self.send_lmp_packet(command.bd_addr, lmp.LmpSwitchReq())
|
||||
|
||||
def on_response(future: asyncio.Future[int]):
|
||||
if (status := future.result()) == hci.HCI_SUCCESS:
|
||||
connection.role = hci.Role(command.role)
|
||||
self.send_hci_packet(
|
||||
hci.HCI_Role_Change_Event(
|
||||
status=status,
|
||||
bd_addr=command.bd_addr,
|
||||
new_role=connection.role,
|
||||
)
|
||||
)
|
||||
|
||||
future.add_done_callback(on_response)
|
||||
|
||||
return None
|
||||
|
||||
def on_hci_set_event_mask_command(
|
||||
|
||||
Reference in New Issue
Block a user