From 46eb81466d60cb04d999e364cbb7c86eb9b8a483 Mon Sep 17 00:00:00 2001 From: Josh Wu Date: Thu, 27 Jul 2023 15:26:02 +0800 Subject: [PATCH] Add more argement hints in L2CAP --- bumble/l2cap.py | 170 ++++++++++++++++++++++++++++-------------------- 1 file changed, 98 insertions(+), 72 deletions(-) diff --git a/bumble/l2cap.py b/bumble/l2cap.py index 76ae585..119cb13 100644 --- a/bumble/l2cap.py +++ b/bumble/l2cap.py @@ -22,7 +22,18 @@ import struct from collections import deque from pyee import EventEmitter -from typing import Dict, Type, List, Optional, Tuple, Callable, Any, Union, Deque +from typing import ( + Dict, + Type, + List, + Optional, + Tuple, + Callable, + Any, + Union, + Deque, + Iterable, +) from .colors import color from .core import BT_CENTRAL_ROLE, InvalidStateError, ProtocolError @@ -155,7 +166,7 @@ class L2CAP_PDU: ''' @staticmethod - def from_bytes(data) -> L2CAP_PDU: + def from_bytes(data: bytes) -> L2CAP_PDU: # Sanity check if len(data) < 4: raise ValueError('not enough data for L2CAP header') @@ -169,7 +180,7 @@ class L2CAP_PDU: header = struct.pack(' None: + def __init__(self, cid: int, payload: bytes) -> None: self.cid = cid self.payload = payload @@ -191,7 +202,7 @@ class L2CAP_Control_Frame: name: str @staticmethod - def from_bytes(pdu) -> L2CAP_Control_Frame: + def from_bytes(pdu: bytes) -> L2CAP_Control_Frame: code = pdu[0] cls = L2CAP_Control_Frame.classes.get(code) @@ -216,11 +227,11 @@ class L2CAP_Control_Frame: return self @staticmethod - def code_name(code) -> str: + def code_name(code: int) -> str: return name_or_number(L2CAP_CONTROL_FRAME_NAMES, code) @staticmethod - def decode_configuration_options(data) -> List[Tuple[int, bytes]]: + def decode_configuration_options(data: bytes) -> List[Tuple[int, bytes]]: options = [] while len(data) >= 2: value_type = data[0] @@ -232,7 +243,7 @@ class L2CAP_Control_Frame: return options @staticmethod - def encode_configuration_options(options) -> bytes: + def encode_configuration_options(options: List[Tuple[int, bytes]]) -> bytes: return b''.join( [bytes([option[0], len(option[1])]) + option[1] for option in options] ) @@ -258,8 +269,9 @@ class L2CAP_Control_Frame: def __init__(self, pdu=None, **kwargs) -> None: self.identifier = kwargs.get('identifier', 0) - if hasattr(self, 'fields') and kwargs: - HCI_Object.init_from_fields(self, self.fields, kwargs) + if hasattr(self, 'fields'): + if kwargs: + HCI_Object.init_from_fields(self, self.fields, kwargs) if pdu is None: data = HCI_Object.dict_to_bytes(kwargs, self.fields) pdu = ( @@ -315,7 +327,7 @@ class L2CAP_Command_Reject(L2CAP_Control_Frame): } @staticmethod - def reason_name(reason) -> str: + def reason_name(reason: int) -> str: return name_or_number(L2CAP_Command_Reject.REASON_NAMES, reason) @@ -343,7 +355,7 @@ class L2CAP_Connection_Request(L2CAP_Control_Frame): ''' @staticmethod - def parse_psm(data, offset=0) -> Tuple[int, int]: + def parse_psm(data: bytes, offset: int = 0) -> Tuple[int, int]: psm_length = 2 psm = data[offset] | data[offset + 1] << 8 @@ -355,7 +367,7 @@ class L2CAP_Connection_Request(L2CAP_Control_Frame): return offset + psm_length, psm @staticmethod - def serialize_psm(psm) -> bytes: + def serialize_psm(psm: int) -> bytes: serialized = struct.pack('>= 16 while psm: @@ -405,7 +417,7 @@ class L2CAP_Connection_Response(L2CAP_Control_Frame): } @staticmethod - def result_name(result) -> str: + def result_name(result: int) -> str: return name_or_number(L2CAP_Connection_Response.RESULT_NAMES, result) @@ -452,7 +464,7 @@ class L2CAP_Configure_Response(L2CAP_Control_Frame): } @staticmethod - def result_name(result) -> str: + def result_name(result: int) -> str: return name_or_number(L2CAP_Configure_Response.RESULT_NAMES, result) @@ -529,7 +541,7 @@ class L2CAP_Information_Request(L2CAP_Control_Frame): } @staticmethod - def info_type_name(info_type) -> str: + def info_type_name(info_type: int) -> str: return name_or_number(L2CAP_Information_Request.INFO_TYPE_NAMES, info_type) @@ -556,7 +568,7 @@ class L2CAP_Information_Response(L2CAP_Control_Frame): RESULT_NAMES = {SUCCESS: 'SUCCESS', NOT_SUPPORTED: 'NOT_SUPPORTED'} @staticmethod - def result_name(result) -> str: + def result_name(result: int) -> str: return name_or_number(L2CAP_Information_Response.RESULT_NAMES, result) @@ -642,7 +654,7 @@ class L2CAP_LE_Credit_Based_Connection_Response(L2CAP_Control_Frame): } @staticmethod - def result_name(result) -> str: + def result_name(result: int) -> str: return name_or_number( L2CAP_LE_Credit_Based_Connection_Response.RESULT_NAMES, result ) @@ -707,9 +719,16 @@ class Channel(EventEmitter): disconnection_result: Optional[asyncio.Future[None]] response: Optional[asyncio.Future[bytes]] sink: Optional[Callable[[bytes], Any]] + state: int def __init__( - self, manager, connection, signaling_cid, psm, source_cid, mtu + self, + manager: 'ChannelManager', + connection, + signaling_cid: int, + psm: int, + source_cid: int, + mtu: int, ) -> None: super().__init__() self.manager = manager @@ -725,7 +744,7 @@ class Channel(EventEmitter): self.disconnection_result = None self.sink = None - def change_state(self, new_state) -> None: + def change_state(self, new_state: int) -> None: logger.debug( f'{self} state change -> {color(Channel.STATE_NAMES[new_state], "cyan")}' ) @@ -1008,25 +1027,26 @@ class LeConnectionOrientedChannel(EventEmitter): connection_result: Optional[asyncio.Future[LeConnectionOrientedChannel]] disconnection_result: Optional[asyncio.Future[None]] out_sdu: Optional[bytes] + state: int @staticmethod - def state_name(state) -> str: + def state_name(state: int) -> str: return name_or_number(LeConnectionOrientedChannel.STATE_NAMES, state) def __init__( self, - manager, + manager: 'ChannelManager', connection, - le_psm, - source_cid, - destination_cid, - mtu, - mps, - credits, # pylint: disable=redefined-builtin - peer_mtu, - peer_mps, - peer_credits, - connected, + le_psm: int, + source_cid: int, + destination_cid: int, + mtu: int, + mps: int, + credits: int, # pylint: disable=redefined-builtin + peer_mtu: int, + peer_mps: int, + peer_credits: int, + connected: bool, ) -> None: super().__init__() self.manager = manager @@ -1059,7 +1079,7 @@ class LeConnectionOrientedChannel(EventEmitter): else: self.state = LeConnectionOrientedChannel.INIT - def change_state(self, new_state) -> None: + def change_state(self, new_state: int) -> None: logger.debug( f'{self} state change -> {color(self.state_name(new_state), "cyan")}' ) @@ -1228,7 +1248,7 @@ class LeConnectionOrientedChannel(EventEmitter): # Cleanup self.connection_result = None - def on_credits(self, credits) -> None: # pylint: disable=redefined-builtin + def on_credits(self, credits: int) -> None: # pylint: disable=redefined-builtin self.credits += credits logger.debug(f'received {credits} credits, total = {self.credits}') @@ -1310,7 +1330,7 @@ class LeConnectionOrientedChannel(EventEmitter): self.drained.set() return - def write(self, data) -> None: + def write(self, data: bytes) -> None: if self.state != self.CONNECTED: logger.warning('not connected, dropping data') return @@ -1360,7 +1380,9 @@ class ChannelManager: fixed_channels: Dict[int, Optional[Callable[[int, bytes], Any]]] def __init__( - self, extended_features=(), connectionless_mtu=L2CAP_DEFAULT_CONNECTIONLESS_MTU + self, + extended_features: Iterable[int] = (), + connectionless_mtu: int = L2CAP_DEFAULT_CONNECTIONLESS_MTU, ) -> None: self._host = None self.identifiers = {} # Incrementing identifier values by connection @@ -1390,20 +1412,20 @@ class ChannelManager: if host is not None: host.on('disconnection', self.on_disconnection) - def find_channel(self, connection_handle, cid): + def find_channel(self, connection_handle: int, cid: int): if connection_channels := self.channels.get(connection_handle): return connection_channels.get(cid) return None - def find_le_coc_channel(self, connection_handle, cid): + def find_le_coc_channel(self, connection_handle: int, cid: int): if connection_channels := self.le_coc_channels.get(connection_handle): return connection_channels.get(cid) return None @staticmethod - def find_free_br_edr_cid(channels) -> int: + def find_free_br_edr_cid(channels: Iterable[int]) -> int: # Pick the smallest valid CID that's not already in the list # (not necessarily the most efficient algorithm, but the list of CID is # very small in practice) @@ -1416,7 +1438,7 @@ class ChannelManager: raise RuntimeError('no free CID available') @staticmethod - def find_free_le_cid(channels) -> int: + def find_free_le_cid(channels: Iterable[int]) -> int: # Pick the smallest valid CID that's not already in the list # (not necessarily the most efficient algorithm, but the list of CID is # very small in practice) @@ -1429,7 +1451,7 @@ class ChannelManager: raise RuntimeError('no free CID') @staticmethod - def check_le_coc_parameters(max_credits, mtu, mps) -> None: + def check_le_coc_parameters(max_credits: int, mtu: int, mps: int) -> None: if ( max_credits < 1 or max_credits > L2CAP_LE_CREDIT_BASED_CONNECTION_MAX_CREDITS @@ -1448,14 +1470,16 @@ class ChannelManager: self.identifiers[connection.handle] = identifier return identifier - def register_fixed_channel(self, cid, handler) -> None: + def register_fixed_channel( + self, cid: int, handler: Callable[[int, bytes], Any] + ) -> None: self.fixed_channels[cid] = handler - def deregister_fixed_channel(self, cid) -> None: + def deregister_fixed_channel(self, cid: int) -> None: if cid in self.fixed_channels: del self.fixed_channels[cid] - def register_server(self, psm, server) -> int: + def register_server(self, psm: int, server: Callable[[Channel], Any]) -> int: if psm == 0: # Find a free PSM for candidate in range( @@ -1489,11 +1513,11 @@ class ChannelManager: def register_le_coc_server( self, - psm, - server, - max_credits=L2CAP_LE_CREDIT_BASED_CONNECTION_DEFAULT_INITIAL_CREDITS, - mtu=L2CAP_LE_CREDIT_BASED_CONNECTION_DEFAULT_MTU, - mps=L2CAP_LE_CREDIT_BASED_CONNECTION_DEFAULT_MPS, + psm: int, + server: Callable[[LeConnectionOrientedChannel], Any], + max_credits: int = L2CAP_LE_CREDIT_BASED_CONNECTION_DEFAULT_INITIAL_CREDITS, + mtu: int = L2CAP_LE_CREDIT_BASED_CONNECTION_DEFAULT_MTU, + mps: int = L2CAP_LE_CREDIT_BASED_CONNECTION_DEFAULT_MPS, ) -> int: self.check_le_coc_parameters(max_credits, mtu, mps) @@ -1522,7 +1546,7 @@ class ChannelManager: return psm - def on_disconnection(self, connection_handle, _reason) -> None: + def on_disconnection(self, connection_handle: int, _reason: int) -> None: logger.debug(f'disconnection from {connection_handle}, cleaning up channels') if connection_handle in self.channels: for _, channel in self.channels[connection_handle].items(): @@ -1535,7 +1559,7 @@ class ChannelManager: if connection_handle in self.identifiers: del self.identifiers[connection_handle] - def send_pdu(self, connection, cid, pdu) -> None: + def send_pdu(self, connection, cid: int, pdu) -> None: pdu_str = pdu.hex() if isinstance(pdu, bytes) else str(pdu) logger.debug( f'{color(">>> Sending L2CAP PDU", "blue")} ' @@ -1544,7 +1568,7 @@ class ChannelManager: ) self.host.send_l2cap_pdu(connection.handle, cid, bytes(pdu)) - def on_pdu(self, connection, cid, pdu) -> None: + def on_pdu(self, connection, cid: int, pdu) -> None: if cid in (L2CAP_SIGNALING_CID, L2CAP_LE_SIGNALING_CID): # Parse the L2CAP payload into a Control Frame object control_frame = L2CAP_Control_Frame.from_bytes(pdu) @@ -1565,7 +1589,7 @@ class ChannelManager: channel.on_pdu(pdu) - def send_control_frame(self, connection, cid, control_frame) -> None: + def send_control_frame(self, connection, cid: int, control_frame) -> None: logger.debug( f'{color(">>> Sending L2CAP Signaling Control Frame", "blue")} ' f'on connection [0x{connection.handle:04X}] (CID={cid}) ' @@ -1573,7 +1597,7 @@ class ChannelManager: ) self.host.send_l2cap_pdu(connection.handle, cid, bytes(control_frame)) - def on_control_frame(self, connection, cid, control_frame) -> None: + def on_control_frame(self, connection, cid: int, control_frame) -> None: logger.debug( f'{color("<<< Received L2CAP Signaling Control Frame", "green")} ' f'on connection [0x{connection.handle:04X}] (CID={cid}) ' @@ -1610,10 +1634,10 @@ class ChannelManager: ), ) - def on_l2cap_command_reject(self, _connection, _cid, packet) -> None: + def on_l2cap_command_reject(self, _connection, _cid: int, packet) -> None: logger.warning(f'{color("!!! Command rejected:", "red")} {packet.reason}') - def on_l2cap_connection_request(self, connection, cid, request) -> None: + def on_l2cap_connection_request(self, connection, cid: int, request) -> None: # Check if there's a server for this PSM server = self.servers.get(request.psm) if server: @@ -1665,7 +1689,7 @@ class ChannelManager: ), ) - def on_l2cap_connection_response(self, connection, cid, response) -> None: + def on_l2cap_connection_response(self, connection, cid: int, response) -> None: if ( channel := self.find_channel(connection.handle, response.source_cid) ) is None: @@ -1680,7 +1704,7 @@ class ChannelManager: channel.on_connection_response(response) - def on_l2cap_configure_request(self, connection, cid, request) -> None: + def on_l2cap_configure_request(self, connection, cid: int, request) -> None: if ( channel := self.find_channel(connection.handle, request.destination_cid) ) is None: @@ -1695,7 +1719,7 @@ class ChannelManager: channel.on_configure_request(request) - def on_l2cap_configure_response(self, connection, cid, response) -> None: + def on_l2cap_configure_response(self, connection, cid: int, response) -> None: if ( channel := self.find_channel(connection.handle, response.source_cid) ) is None: @@ -1710,7 +1734,7 @@ class ChannelManager: channel.on_configure_response(response) - def on_l2cap_disconnection_request(self, connection, cid, request) -> None: + def on_l2cap_disconnection_request(self, connection, cid: int, request) -> None: if ( channel := self.find_channel(connection.handle, request.destination_cid) ) is None: @@ -1725,7 +1749,7 @@ class ChannelManager: channel.on_disconnection_request(request) - def on_l2cap_disconnection_response(self, connection, cid, response) -> None: + def on_l2cap_disconnection_response(self, connection, cid: int, response) -> None: if ( channel := self.find_channel(connection.handle, response.source_cid) ) is None: @@ -1740,7 +1764,7 @@ class ChannelManager: channel.on_disconnection_response(response) - def on_l2cap_echo_request(self, connection, cid, request) -> None: + def on_l2cap_echo_request(self, connection, cid: int, request) -> None: logger.debug(f'<<< Echo request: data={request.data.hex()}') self.send_control_frame( connection, @@ -1748,11 +1772,11 @@ class ChannelManager: L2CAP_Echo_Response(identifier=request.identifier, data=request.data), ) - def on_l2cap_echo_response(self, _connection, _cid, response) -> None: + def on_l2cap_echo_response(self, _connection, _cid: int, response) -> None: logger.debug(f'<<< Echo response: data={response.data.hex()}') # TODO notify listeners - def on_l2cap_information_request(self, connection, cid, request) -> None: + def on_l2cap_information_request(self, connection, cid: int, request) -> None: if request.info_type == L2CAP_Information_Request.CONNECTIONLESS_MTU: result = L2CAP_Information_Response.SUCCESS data = self.connectionless_mtu.to_bytes(2, 'little') @@ -1776,7 +1800,9 @@ class ChannelManager: ), ) - def on_l2cap_connection_parameter_update_request(self, connection, cid, request): + def on_l2cap_connection_parameter_update_request( + self, connection, cid: int, request + ): if connection.role == BT_CENTRAL_ROLE: self.send_control_frame( connection, @@ -1795,7 +1821,7 @@ class ChannelManager: supervision_timeout=request.timeout, min_ce_length=0, max_ce_length=0, - ) + ) # type: ignore[call-arg] ) else: self.send_control_frame( @@ -1808,13 +1834,13 @@ class ChannelManager: ) def on_l2cap_connection_parameter_update_response( - self, connection, cid, response + self, connection, cid: int, response ) -> None: # TODO: check response pass def on_l2cap_le_credit_based_connection_request( - self, connection, cid, request + self, connection, cid: int, request ) -> None: if request.le_psm in self.le_coc_servers: (server, max_credits, mtu, mps) = self.le_coc_servers[request.le_psm] @@ -1918,7 +1944,7 @@ class ChannelManager: ) def on_l2cap_le_credit_based_connection_response( - self, connection, _cid, response + self, connection, _cid: int, response ) -> None: # Find the pending request by identifier request = self.le_coc_requests.get(response.identifier) @@ -1942,7 +1968,7 @@ class ChannelManager: # Process the response channel.on_connection_response(response) - def on_l2cap_le_flow_control_credit(self, connection, _cid, credit) -> None: + def on_l2cap_le_flow_control_credit(self, connection, _cid: int, credit) -> None: channel = self.find_le_coc_channel(connection.handle, credit.cid) if channel is None: logger.warning(f'received credits for an unknown channel (cid={credit.cid}') @@ -1950,14 +1976,14 @@ class ChannelManager: channel.on_credits(credit.credits) - def on_channel_closed(self, channel) -> None: + def on_channel_closed(self, channel: Channel) -> None: connection_channels = self.channels.get(channel.connection.handle) if connection_channels: if channel.source_cid in connection_channels: del connection_channels[channel.source_cid] async def open_le_coc( - self, connection, psm, max_credits, mtu, mps + self, connection, psm: int, max_credits: int, mtu: int, mps: int ) -> LeConnectionOrientedChannel: self.check_le_coc_parameters(max_credits, mtu, mps) @@ -1999,7 +2025,7 @@ class ChannelManager: return channel - async def connect(self, connection, psm) -> Channel: + async def connect(self, connection, psm: int) -> Channel: # NOTE: this implementation hard-codes BR/EDR # Find a free CID for a new channel