L2CAP: Import device.Connection for typing

This commit is contained in:
Josh Wu
2023-07-27 15:31:08 +08:00
committed by Lucas Abel
parent 46eb81466d
commit 190529184e
+55 -23
View File
@@ -33,6 +33,7 @@ from typing import (
Union,
Deque,
Iterable,
TYPE_CHECKING,
)
from .colors import color
@@ -44,6 +45,9 @@ from .hci import (
name_or_number,
)
if TYPE_CHECKING:
from bumble.device import Connection
# -----------------------------------------------------------------------------
# Logging
# -----------------------------------------------------------------------------
@@ -720,11 +724,12 @@ class Channel(EventEmitter):
response: Optional[asyncio.Future[bytes]]
sink: Optional[Callable[[bytes], Any]]
state: int
connection: 'Connection'
def __init__(
self,
manager: 'ChannelManager',
connection,
connection: 'Connection',
signaling_cid: int,
psm: int,
source_cid: int,
@@ -1028,6 +1033,7 @@ class LeConnectionOrientedChannel(EventEmitter):
disconnection_result: Optional[asyncio.Future[None]]
out_sdu: Optional[bytes]
state: int
connection: 'Connection'
@staticmethod
def state_name(state: int) -> str:
@@ -1036,7 +1042,7 @@ class LeConnectionOrientedChannel(EventEmitter):
def __init__(
self,
manager: 'ChannelManager',
connection,
connection: 'Connection',
le_psm: int,
source_cid: int,
destination_cid: int,
@@ -1465,7 +1471,7 @@ class ChannelManager:
):
raise ValueError('MPS out of range')
def next_identifier(self, connection) -> int:
def next_identifier(self, connection: 'Connection') -> int:
identifier = (self.identifiers.setdefault(connection.handle, 0) + 1) % 256
self.identifiers[connection.handle] = identifier
return identifier
@@ -1568,7 +1574,7 @@ class ChannelManager:
)
self.host.send_l2cap_pdu(connection.handle, cid, bytes(pdu))
def on_pdu(self, connection, cid: int, pdu) -> None:
def on_pdu(self, connection: 'Connection', cid: int, pdu) -> None:
if cid in (L2CAP_SIGNALING_CID, L2CAP_LE_SIGNALING_CID):
# Parse the L2CAP payload into a Control Frame object
control_frame = L2CAP_Control_Frame.from_bytes(pdu)
@@ -1589,7 +1595,9 @@ class ChannelManager:
channel.on_pdu(pdu)
def send_control_frame(self, connection, cid: int, control_frame) -> None:
def send_control_frame(
self, connection: 'Connection', cid: int, control_frame
) -> None:
logger.debug(
f'{color(">>> Sending L2CAP Signaling Control Frame", "blue")} '
f'on connection [0x{connection.handle:04X}] (CID={cid}) '
@@ -1597,7 +1605,9 @@ class ChannelManager:
)
self.host.send_l2cap_pdu(connection.handle, cid, bytes(control_frame))
def on_control_frame(self, connection, cid: int, control_frame) -> None:
def on_control_frame(
self, connection: 'Connection', cid: int, control_frame
) -> None:
logger.debug(
f'{color("<<< Received L2CAP Signaling Control Frame", "green")} '
f'on connection [0x{connection.handle:04X}] (CID={cid}) '
@@ -1634,10 +1644,14 @@ class ChannelManager:
),
)
def on_l2cap_command_reject(self, _connection, _cid: int, packet) -> None:
def on_l2cap_command_reject(
self, _connection: 'Connection', _cid: int, packet
) -> None:
logger.warning(f'{color("!!! Command rejected:", "red")} {packet.reason}')
def on_l2cap_connection_request(self, connection, cid: int, request) -> None:
def on_l2cap_connection_request(
self, connection: 'Connection', cid: int, request
) -> None:
# Check if there's a server for this PSM
server = self.servers.get(request.psm)
if server:
@@ -1689,7 +1703,9 @@ class ChannelManager:
),
)
def on_l2cap_connection_response(self, connection, cid: int, response) -> None:
def on_l2cap_connection_response(
self, connection: 'Connection', cid: int, response
) -> None:
if (
channel := self.find_channel(connection.handle, response.source_cid)
) is None:
@@ -1704,7 +1720,9 @@ class ChannelManager:
channel.on_connection_response(response)
def on_l2cap_configure_request(self, connection, cid: int, request) -> None:
def on_l2cap_configure_request(
self, connection: 'Connection', cid: int, request
) -> None:
if (
channel := self.find_channel(connection.handle, request.destination_cid)
) is None:
@@ -1719,7 +1737,9 @@ class ChannelManager:
channel.on_configure_request(request)
def on_l2cap_configure_response(self, connection, cid: int, response) -> None:
def on_l2cap_configure_response(
self, connection: 'Connection', cid: int, response
) -> None:
if (
channel := self.find_channel(connection.handle, response.source_cid)
) is None:
@@ -1734,7 +1754,9 @@ class ChannelManager:
channel.on_configure_response(response)
def on_l2cap_disconnection_request(self, connection, cid: int, request) -> None:
def on_l2cap_disconnection_request(
self, connection: 'Connection', cid: int, request
) -> None:
if (
channel := self.find_channel(connection.handle, request.destination_cid)
) is None:
@@ -1749,7 +1771,9 @@ class ChannelManager:
channel.on_disconnection_request(request)
def on_l2cap_disconnection_response(self, connection, cid: int, response) -> None:
def on_l2cap_disconnection_response(
self, connection: 'Connection', cid: int, response
) -> None:
if (
channel := self.find_channel(connection.handle, response.source_cid)
) is None:
@@ -1764,7 +1788,9 @@ class ChannelManager:
channel.on_disconnection_response(response)
def on_l2cap_echo_request(self, connection, cid: int, request) -> None:
def on_l2cap_echo_request(
self, connection: 'Connection', cid: int, request
) -> None:
logger.debug(f'<<< Echo request: data={request.data.hex()}')
self.send_control_frame(
connection,
@@ -1772,11 +1798,15 @@ class ChannelManager:
L2CAP_Echo_Response(identifier=request.identifier, data=request.data),
)
def on_l2cap_echo_response(self, _connection, _cid: int, response) -> None:
def on_l2cap_echo_response(
self, _connection: 'Connection', _cid: int, response
) -> None:
logger.debug(f'<<< Echo response: data={response.data.hex()}')
# TODO notify listeners
def on_l2cap_information_request(self, connection, cid: int, request) -> None:
def on_l2cap_information_request(
self, connection: 'Connection', cid: int, request
) -> None:
if request.info_type == L2CAP_Information_Request.CONNECTIONLESS_MTU:
result = L2CAP_Information_Response.SUCCESS
data = self.connectionless_mtu.to_bytes(2, 'little')
@@ -1801,7 +1831,7 @@ class ChannelManager:
)
def on_l2cap_connection_parameter_update_request(
self, connection, cid: int, request
self, connection: 'Connection', cid: int, request
):
if connection.role == BT_CENTRAL_ROLE:
self.send_control_frame(
@@ -1834,13 +1864,13 @@ class ChannelManager:
)
def on_l2cap_connection_parameter_update_response(
self, connection, cid: int, response
self, connection: 'Connection', cid: int, response
) -> None:
# TODO: check response
pass
def on_l2cap_le_credit_based_connection_request(
self, connection, cid: int, request
self, connection: 'Connection', cid: int, request
) -> None:
if request.le_psm in self.le_coc_servers:
(server, max_credits, mtu, mps) = self.le_coc_servers[request.le_psm]
@@ -1944,7 +1974,7 @@ class ChannelManager:
)
def on_l2cap_le_credit_based_connection_response(
self, connection, _cid: int, response
self, connection: 'Connection', _cid: int, response
) -> None:
# Find the pending request by identifier
request = self.le_coc_requests.get(response.identifier)
@@ -1968,7 +1998,9 @@ class ChannelManager:
# Process the response
channel.on_connection_response(response)
def on_l2cap_le_flow_control_credit(self, connection, _cid: int, credit) -> None:
def on_l2cap_le_flow_control_credit(
self, connection: 'Connection', _cid: int, credit
) -> None:
channel = self.find_le_coc_channel(connection.handle, credit.cid)
if channel is None:
logger.warning(f'received credits for an unknown channel (cid={credit.cid}')
@@ -1983,7 +2015,7 @@ class ChannelManager:
del connection_channels[channel.source_cid]
async def open_le_coc(
self, connection, psm: int, max_credits: int, mtu: int, mps: int
self, connection: 'Connection', psm: int, max_credits: int, mtu: int, mps: int
) -> LeConnectionOrientedChannel:
self.check_le_coc_parameters(max_credits, mtu, mps)
@@ -2025,7 +2057,7 @@ class ChannelManager:
return channel
async def connect(self, connection, psm: int) -> Channel:
async def connect(self, connection: 'Connection', psm: int) -> Channel:
# NOTE: this implementation hard-codes BR/EDR
# Find a free CID for a new channel