From 9c70c487b9a97c7b41482c8e97f1ef542b2c561f Mon Sep 17 00:00:00 2001 From: Josh Wu Date: Thu, 27 Jul 2023 10:46:22 +0800 Subject: [PATCH] Add type hint to L2CAP module --- bumble/core.py | 4 +- bumble/l2cap.py | 230 +++++++++++++++++++++++++++--------------------- 2 files changed, 134 insertions(+), 100 deletions(-) diff --git a/bumble/core.py b/bumble/core.py index 2700bb04..2e3f4af7 100644 --- a/bumble/core.py +++ b/bumble/core.py @@ -17,7 +17,7 @@ # ----------------------------------------------------------------------------- from __future__ import annotations import struct -from typing import List, Optional, Tuple, Union, cast +from typing import List, Optional, Tuple, Union, cast, Dict from .company_ids import COMPANY_IDENTIFIERS @@ -53,7 +53,7 @@ def bit_flags_to_strings(bits, bit_flag_names): return names -def name_or_number(dictionary, number, width=2): +def name_or_number(dictionary: Dict[int, str], number: int, width: int = 2) -> str: name = dictionary.get(number) if name is not None: return name diff --git a/bumble/l2cap.py b/bumble/l2cap.py index ef7fdab2..76ae5852 100644 --- a/bumble/l2cap.py +++ b/bumble/l2cap.py @@ -22,7 +22,7 @@ import struct from collections import deque from pyee import EventEmitter -from typing import Dict, Type +from typing import Dict, Type, List, Optional, Tuple, Callable, Any, Union, Deque from .colors import color from .core import BT_CENTRAL_ROLE, InvalidStateError, ProtocolError @@ -155,7 +155,7 @@ class L2CAP_PDU: ''' @staticmethod - def from_bytes(data): + def from_bytes(data) -> L2CAP_PDU: # Sanity check if len(data) < 4: raise ValueError('not enough data for L2CAP header') @@ -165,18 +165,18 @@ class L2CAP_PDU: return L2CAP_PDU(l2cap_pdu_cid, l2cap_pdu_payload) - def to_bytes(self): + def to_bytes(self) -> bytes: header = struct.pack(' None: self.cid = cid self.payload = payload - def __bytes__(self): + def __bytes__(self) -> bytes: return self.to_bytes() - def __str__(self): + def __str__(self) -> str: return f'{color("L2CAP", "green")} [CID={self.cid}]: {self.payload.hex()}' @@ -188,10 +188,10 @@ class L2CAP_Control_Frame: classes: Dict[int, Type[L2CAP_Control_Frame]] = {} code = 0 - name = None + name: str @staticmethod - def from_bytes(pdu): + def from_bytes(pdu) -> L2CAP_Control_Frame: code = pdu[0] cls = L2CAP_Control_Frame.classes.get(code) @@ -216,11 +216,11 @@ class L2CAP_Control_Frame: return self @staticmethod - def code_name(code): + def code_name(code) -> str: return name_or_number(L2CAP_CONTROL_FRAME_NAMES, code) @staticmethod - def decode_configuration_options(data): + def decode_configuration_options(data) -> List[Tuple[int, bytes]]: options = [] while len(data) >= 2: value_type = data[0] @@ -232,7 +232,7 @@ class L2CAP_Control_Frame: return options @staticmethod - def encode_configuration_options(options): + def encode_configuration_options(options) -> bytes: return b''.join( [bytes([option[0], len(option[1])]) + option[1] for option in options] ) @@ -256,29 +256,29 @@ class L2CAP_Control_Frame: return inner - def __init__(self, pdu=None, **kwargs): + def __init__(self, pdu=None, **kwargs) -> None: self.identifier = kwargs.get('identifier', 0) if hasattr(self, 'fields') and kwargs: HCI_Object.init_from_fields(self, self.fields, kwargs) - if pdu is None: - data = HCI_Object.dict_to_bytes(kwargs, self.fields) - pdu = ( - bytes([self.code, self.identifier]) - + struct.pack(' bytes: return self.pdu - def __bytes__(self): + def __bytes__(self) -> bytes: return self.to_bytes() - def __str__(self): + def __str__(self) -> str: result = f'{color(self.name, "yellow")} [ID={self.identifier}]' if fields := getattr(self, 'fields', None): result += ':\n' + HCI_Object.format_fields(self.__dict__, fields, ' ') @@ -315,7 +315,7 @@ class L2CAP_Command_Reject(L2CAP_Control_Frame): } @staticmethod - def reason_name(reason): + def reason_name(reason) -> str: return name_or_number(L2CAP_Command_Reject.REASON_NAMES, reason) @@ -343,7 +343,7 @@ class L2CAP_Connection_Request(L2CAP_Control_Frame): ''' @staticmethod - def parse_psm(data, offset=0): + def parse_psm(data, offset=0) -> Tuple[int, int]: psm_length = 2 psm = data[offset] | data[offset + 1] << 8 @@ -355,7 +355,7 @@ class L2CAP_Connection_Request(L2CAP_Control_Frame): return offset + psm_length, psm @staticmethod - def serialize_psm(psm): + def serialize_psm(psm) -> bytes: serialized = struct.pack('>= 16 while psm: @@ -405,7 +405,7 @@ class L2CAP_Connection_Response(L2CAP_Control_Frame): } @staticmethod - def result_name(result): + def result_name(result) -> str: return name_or_number(L2CAP_Connection_Response.RESULT_NAMES, result) @@ -452,7 +452,7 @@ class L2CAP_Configure_Response(L2CAP_Control_Frame): } @staticmethod - def result_name(result): + def result_name(result) -> str: return name_or_number(L2CAP_Configure_Response.RESULT_NAMES, result) @@ -529,7 +529,7 @@ class L2CAP_Information_Request(L2CAP_Control_Frame): } @staticmethod - def info_type_name(info_type): + def info_type_name(info_type) -> str: return name_or_number(L2CAP_Information_Request.INFO_TYPE_NAMES, info_type) @@ -556,7 +556,7 @@ class L2CAP_Information_Response(L2CAP_Control_Frame): RESULT_NAMES = {SUCCESS: 'SUCCESS', NOT_SUPPORTED: 'NOT_SUPPORTED'} @staticmethod - def result_name(result): + def result_name(result) -> str: return name_or_number(L2CAP_Information_Response.RESULT_NAMES, result) @@ -588,6 +588,8 @@ class L2CAP_LE_Credit_Based_Connection_Request(L2CAP_Control_Frame): (CODE 0x14) ''' + source_cid: int + # ----------------------------------------------------------------------------- @L2CAP_Control_Frame.subclass( @@ -640,7 +642,7 @@ class L2CAP_LE_Credit_Based_Connection_Response(L2CAP_Control_Frame): } @staticmethod - def result_name(result): + def result_name(result) -> str: return name_or_number( L2CAP_LE_Credit_Based_Connection_Response.RESULT_NAMES, result ) @@ -701,7 +703,14 @@ class Channel(EventEmitter): WAIT_CONTROL_IND: 'WAIT_CONTROL_IND', } - def __init__(self, manager, connection, signaling_cid, psm, source_cid, mtu): + connection_result: Optional[asyncio.Future[None]] + disconnection_result: Optional[asyncio.Future[None]] + response: Optional[asyncio.Future[bytes]] + sink: Optional[Callable[[bytes], Any]] + + def __init__( + self, manager, connection, signaling_cid, psm, source_cid, mtu + ) -> None: super().__init__() self.manager = manager self.connection = connection @@ -716,19 +725,19 @@ class Channel(EventEmitter): self.disconnection_result = None self.sink = None - def change_state(self, new_state): + def change_state(self, new_state) -> None: logger.debug( f'{self} state change -> {color(Channel.STATE_NAMES[new_state], "cyan")}' ) self.state = new_state - def send_pdu(self, pdu): + def send_pdu(self, pdu) -> None: self.manager.send_pdu(self.connection, self.destination_cid, pdu) - def send_control_frame(self, frame): + def send_control_frame(self, frame) -> None: self.manager.send_control_frame(self.connection, self.signaling_cid, frame) - async def send_request(self, request): + async def send_request(self, request) -> bytes: # Check that there isn't already a request pending if self.response: raise InvalidStateError('request already pending') @@ -739,7 +748,7 @@ class Channel(EventEmitter): self.send_pdu(request) return await self.response - def on_pdu(self, pdu): + def on_pdu(self, pdu) -> None: if self.response: self.response.set_result(pdu) self.response = None @@ -751,7 +760,7 @@ class Channel(EventEmitter): color('received pdu without a pending request or sink', 'red') ) - async def connect(self): + async def connect(self) -> None: if self.state != Channel.CLOSED: raise InvalidStateError('invalid state') @@ -778,7 +787,7 @@ class Channel(EventEmitter): finally: self.connection_result = None - async def disconnect(self): + async def disconnect(self) -> None: if self.state != Channel.OPEN: raise InvalidStateError('invalid state') @@ -796,12 +805,12 @@ class Channel(EventEmitter): self.disconnection_result = asyncio.get_running_loop().create_future() return await self.disconnection_result - def abort(self): + def abort(self) -> None: if self.state == self.OPEN: self.change_state(self.CLOSED) self.emit('close') - def send_configure_request(self): + def send_configure_request(self) -> None: options = L2CAP_Control_Frame.encode_configuration_options( [ ( @@ -819,7 +828,7 @@ class Channel(EventEmitter): ) ) - def on_connection_request(self, request): + def on_connection_request(self, request) -> None: self.destination_cid = request.source_cid self.change_state(Channel.WAIT_CONNECT) self.send_control_frame( @@ -858,7 +867,7 @@ class Channel(EventEmitter): ) self.connection_result = None - def on_configure_request(self, request): + def on_configure_request(self, request) -> None: if self.state not in ( Channel.WAIT_CONFIG, Channel.WAIT_CONFIG_REQ, @@ -896,7 +905,7 @@ class Channel(EventEmitter): elif self.state == Channel.WAIT_CONFIG_REQ_RSP: self.change_state(Channel.WAIT_CONFIG_RSP) - def on_configure_response(self, response): + def on_configure_response(self, response) -> None: if response.result == L2CAP_Configure_Response.SUCCESS: if self.state == Channel.WAIT_CONFIG_REQ_RSP: self.change_state(Channel.WAIT_CONFIG_REQ) @@ -930,7 +939,7 @@ class Channel(EventEmitter): ) # TODO: decide how to fail gracefully - def on_disconnection_request(self, request): + def on_disconnection_request(self, request) -> None: if self.state in (Channel.OPEN, Channel.WAIT_DISCONNECT): self.send_control_frame( L2CAP_Disconnection_Response( @@ -945,7 +954,7 @@ class Channel(EventEmitter): else: logger.warning(color('invalid state', 'red')) - def on_disconnection_response(self, response): + def on_disconnection_response(self, response) -> None: if self.state != Channel.WAIT_DISCONNECT: logger.warning(color('invalid state', 'red')) return @@ -964,7 +973,7 @@ class Channel(EventEmitter): self.emit('close') self.manager.on_channel_closed(self) - def __str__(self): + def __str__(self) -> str: return ( f'Channel({self.source_cid}->{self.destination_cid}, ' f'PSM={self.psm}, ' @@ -995,8 +1004,13 @@ class LeConnectionOrientedChannel(EventEmitter): CONNECTION_ERROR: 'CONNECTION_ERROR', } + out_queue: Deque[bytes] + connection_result: Optional[asyncio.Future[LeConnectionOrientedChannel]] + disconnection_result: Optional[asyncio.Future[None]] + out_sdu: Optional[bytes] + @staticmethod - def state_name(state): + def state_name(state) -> str: return name_or_number(LeConnectionOrientedChannel.STATE_NAMES, state) def __init__( @@ -1013,7 +1027,7 @@ class LeConnectionOrientedChannel(EventEmitter): peer_mps, peer_credits, connected, - ): + ) -> None: super().__init__() self.manager = manager self.connection = connection @@ -1045,7 +1059,7 @@ class LeConnectionOrientedChannel(EventEmitter): else: self.state = LeConnectionOrientedChannel.INIT - def change_state(self, new_state): + def change_state(self, new_state) -> None: logger.debug( f'{self} state change -> {color(self.state_name(new_state), "cyan")}' ) @@ -1056,13 +1070,13 @@ class LeConnectionOrientedChannel(EventEmitter): elif new_state == self.DISCONNECTED: self.emit('close') - def send_pdu(self, pdu): + def send_pdu(self, pdu) -> None: self.manager.send_pdu(self.connection, self.destination_cid, pdu) - def send_control_frame(self, frame): + def send_control_frame(self, frame) -> None: self.manager.send_control_frame(self.connection, L2CAP_LE_SIGNALING_CID, frame) - async def connect(self): + async def connect(self) -> LeConnectionOrientedChannel: # Check that we're in the right state if self.state != self.INIT: raise InvalidStateError('not in a connectable state') @@ -1090,7 +1104,7 @@ class LeConnectionOrientedChannel(EventEmitter): # Wait for the connection to succeed or fail return await self.connection_result - async def disconnect(self): + async def disconnect(self) -> None: # Check that we're connected if self.state != self.CONNECTED: raise InvalidStateError('not connected') @@ -1110,11 +1124,11 @@ class LeConnectionOrientedChannel(EventEmitter): self.disconnection_result = asyncio.get_running_loop().create_future() return await self.disconnection_result - def abort(self): + def abort(self) -> None: if self.state == self.CONNECTED: self.change_state(self.DISCONNECTED) - def on_pdu(self, pdu): + def on_pdu(self, pdu) -> None: if self.sink is None: logger.warning('received pdu without a sink') return @@ -1180,7 +1194,7 @@ class LeConnectionOrientedChannel(EventEmitter): self.in_sdu = None self.in_sdu_length = 0 - def on_connection_response(self, response): + def on_connection_response(self, response) -> None: # Look for a matching pending response result if self.connection_result is None: logger.warning( @@ -1214,14 +1228,14 @@ class LeConnectionOrientedChannel(EventEmitter): # Cleanup self.connection_result = None - def on_credits(self, credits): # pylint: disable=redefined-builtin + def on_credits(self, credits) -> None: # pylint: disable=redefined-builtin self.credits += credits logger.debug(f'received {credits} credits, total = {self.credits}') # Try to send more data if we have any queued up self.process_output() - def on_disconnection_request(self, request): + def on_disconnection_request(self, request) -> None: self.send_control_frame( L2CAP_Disconnection_Response( identifier=request.identifier, @@ -1232,7 +1246,7 @@ class LeConnectionOrientedChannel(EventEmitter): self.change_state(self.DISCONNECTED) self.flush_output() - def on_disconnection_response(self, response): + def on_disconnection_response(self, response) -> None: if self.state != self.DISCONNECTING: logger.warning(color('invalid state', 'red')) return @@ -1249,11 +1263,11 @@ class LeConnectionOrientedChannel(EventEmitter): self.disconnection_result.set_result(None) self.disconnection_result = None - def flush_output(self): + def flush_output(self) -> None: self.out_queue.clear() self.out_sdu = None - def process_output(self): + def process_output(self) -> None: while self.credits > 0: if self.out_sdu is not None: # Finish the current SDU @@ -1296,7 +1310,7 @@ class LeConnectionOrientedChannel(EventEmitter): self.drained.set() return - def write(self, data): + def write(self, data) -> None: if self.state != self.CONNECTED: logger.warning('not connected, dropping data') return @@ -1311,18 +1325,18 @@ class LeConnectionOrientedChannel(EventEmitter): # Send what we can self.process_output() - async def drain(self): + async def drain(self) -> None: await self.drained.wait() - def pause_reading(self): + def pause_reading(self) -> None: # TODO: not implemented yet pass - def resume_reading(self): + def resume_reading(self) -> None: # TODO: not implemented yet pass - def __str__(self): + def __str__(self) -> str: return ( f'CoC({self.source_cid}->{self.destination_cid}, ' f'State={self.state_name(self.state)}, ' @@ -1335,9 +1349,19 @@ class LeConnectionOrientedChannel(EventEmitter): # ----------------------------------------------------------------------------- class ChannelManager: + identifiers: Dict[int, int] + channels: Dict[int, Dict[int, Union[Channel, LeConnectionOrientedChannel]]] + servers: Dict[int, Callable[[Channel], Any]] + le_coc_channels: Dict[int, Dict[int, LeConnectionOrientedChannel]] + le_coc_servers: Dict[ + int, Tuple[Callable[[LeConnectionOrientedChannel], Any], int, int, int] + ] + le_coc_requests: Dict[int, L2CAP_LE_Credit_Based_Connection_Request] + fixed_channels: Dict[int, Optional[Callable[[int, bytes], Any]]] + def __init__( self, extended_features=(), connectionless_mtu=L2CAP_DEFAULT_CONNECTIONLESS_MTU - ): + ) -> None: self._host = None self.identifiers = {} # Incrementing identifier values by connection self.channels = {} # All channels, mapped by connection and source cid @@ -1379,7 +1403,7 @@ class ChannelManager: return None @staticmethod - def find_free_br_edr_cid(channels): + def find_free_br_edr_cid(channels) -> int: # Pick the smallest valid CID that's not already in the list # (not necessarily the most efficient algorithm, but the list of CID is # very small in practice) @@ -1392,7 +1416,7 @@ class ChannelManager: raise RuntimeError('no free CID available') @staticmethod - def find_free_le_cid(channels): + def find_free_le_cid(channels) -> int: # Pick the smallest valid CID that's not already in the list # (not necessarily the most efficient algorithm, but the list of CID is # very small in practice) @@ -1405,7 +1429,7 @@ class ChannelManager: raise RuntimeError('no free CID') @staticmethod - def check_le_coc_parameters(max_credits, mtu, mps): + def check_le_coc_parameters(max_credits, mtu, mps) -> None: if ( max_credits < 1 or max_credits > L2CAP_LE_CREDIT_BASED_CONNECTION_MAX_CREDITS @@ -1419,19 +1443,19 @@ class ChannelManager: ): raise ValueError('MPS out of range') - def next_identifier(self, connection): + def next_identifier(self, connection) -> int: identifier = (self.identifiers.setdefault(connection.handle, 0) + 1) % 256 self.identifiers[connection.handle] = identifier return identifier - def register_fixed_channel(self, cid, handler): + def register_fixed_channel(self, cid, handler) -> None: self.fixed_channels[cid] = handler - def deregister_fixed_channel(self, cid): + def deregister_fixed_channel(self, cid) -> None: if cid in self.fixed_channels: del self.fixed_channels[cid] - def register_server(self, psm, server): + def register_server(self, psm, server) -> int: if psm == 0: # Find a free PSM for candidate in range( @@ -1470,7 +1494,7 @@ class ChannelManager: max_credits=L2CAP_LE_CREDIT_BASED_CONNECTION_DEFAULT_INITIAL_CREDITS, mtu=L2CAP_LE_CREDIT_BASED_CONNECTION_DEFAULT_MTU, mps=L2CAP_LE_CREDIT_BASED_CONNECTION_DEFAULT_MPS, - ): + ) -> int: self.check_le_coc_parameters(max_credits, mtu, mps) if psm == 0: @@ -1498,7 +1522,7 @@ class ChannelManager: return psm - def on_disconnection(self, connection_handle, _reason): + def on_disconnection(self, connection_handle, _reason) -> None: logger.debug(f'disconnection from {connection_handle}, cleaning up channels') if connection_handle in self.channels: for _, channel in self.channels[connection_handle].items(): @@ -1511,7 +1535,7 @@ class ChannelManager: if connection_handle in self.identifiers: del self.identifiers[connection_handle] - def send_pdu(self, connection, cid, pdu): + def send_pdu(self, connection, cid, pdu) -> None: pdu_str = pdu.hex() if isinstance(pdu, bytes) else str(pdu) logger.debug( f'{color(">>> Sending L2CAP PDU", "blue")} ' @@ -1520,14 +1544,16 @@ class ChannelManager: ) self.host.send_l2cap_pdu(connection.handle, cid, bytes(pdu)) - def on_pdu(self, connection, cid, pdu): + def on_pdu(self, connection, cid, pdu) -> None: if cid in (L2CAP_SIGNALING_CID, L2CAP_LE_SIGNALING_CID): # Parse the L2CAP payload into a Control Frame object control_frame = L2CAP_Control_Frame.from_bytes(pdu) self.on_control_frame(connection, cid, control_frame) elif cid in self.fixed_channels: - self.fixed_channels[cid](connection.handle, pdu) + handler = self.fixed_channels[cid] + assert handler is not None + handler(connection.handle, pdu) else: if (channel := self.find_channel(connection.handle, cid)) is None: logger.warning( @@ -1539,7 +1565,7 @@ class ChannelManager: channel.on_pdu(pdu) - def send_control_frame(self, connection, cid, control_frame): + def send_control_frame(self, connection, cid, control_frame) -> None: logger.debug( f'{color(">>> Sending L2CAP Signaling Control Frame", "blue")} ' f'on connection [0x{connection.handle:04X}] (CID={cid}) ' @@ -1547,7 +1573,7 @@ class ChannelManager: ) self.host.send_l2cap_pdu(connection.handle, cid, bytes(control_frame)) - def on_control_frame(self, connection, cid, control_frame): + def on_control_frame(self, connection, cid, control_frame) -> None: logger.debug( f'{color("<<< Received L2CAP Signaling Control Frame", "green")} ' f'on connection [0x{connection.handle:04X}] (CID={cid}) ' @@ -1584,10 +1610,10 @@ class ChannelManager: ), ) - def on_l2cap_command_reject(self, _connection, _cid, packet): + def on_l2cap_command_reject(self, _connection, _cid, packet) -> None: logger.warning(f'{color("!!! Command rejected:", "red")} {packet.reason}') - def on_l2cap_connection_request(self, connection, cid, request): + def on_l2cap_connection_request(self, connection, cid, request) -> None: # Check if there's a server for this PSM server = self.servers.get(request.psm) if server: @@ -1639,7 +1665,7 @@ class ChannelManager: ), ) - def on_l2cap_connection_response(self, connection, cid, response): + def on_l2cap_connection_response(self, connection, cid, response) -> None: if ( channel := self.find_channel(connection.handle, response.source_cid) ) is None: @@ -1654,7 +1680,7 @@ class ChannelManager: channel.on_connection_response(response) - def on_l2cap_configure_request(self, connection, cid, request): + def on_l2cap_configure_request(self, connection, cid, request) -> None: if ( channel := self.find_channel(connection.handle, request.destination_cid) ) is None: @@ -1669,7 +1695,7 @@ class ChannelManager: channel.on_configure_request(request) - def on_l2cap_configure_response(self, connection, cid, response): + def on_l2cap_configure_response(self, connection, cid, response) -> None: if ( channel := self.find_channel(connection.handle, response.source_cid) ) is None: @@ -1684,7 +1710,7 @@ class ChannelManager: channel.on_configure_response(response) - def on_l2cap_disconnection_request(self, connection, cid, request): + def on_l2cap_disconnection_request(self, connection, cid, request) -> None: if ( channel := self.find_channel(connection.handle, request.destination_cid) ) is None: @@ -1699,7 +1725,7 @@ class ChannelManager: channel.on_disconnection_request(request) - def on_l2cap_disconnection_response(self, connection, cid, response): + def on_l2cap_disconnection_response(self, connection, cid, response) -> None: if ( channel := self.find_channel(connection.handle, response.source_cid) ) is None: @@ -1714,7 +1740,7 @@ class ChannelManager: channel.on_disconnection_response(response) - def on_l2cap_echo_request(self, connection, cid, request): + def on_l2cap_echo_request(self, connection, cid, request) -> None: logger.debug(f'<<< Echo request: data={request.data.hex()}') self.send_control_frame( connection, @@ -1722,11 +1748,11 @@ class ChannelManager: L2CAP_Echo_Response(identifier=request.identifier, data=request.data), ) - def on_l2cap_echo_response(self, _connection, _cid, response): + def on_l2cap_echo_response(self, _connection, _cid, response) -> None: logger.debug(f'<<< Echo response: data={response.data.hex()}') # TODO notify listeners - def on_l2cap_information_request(self, connection, cid, request): + def on_l2cap_information_request(self, connection, cid, request) -> None: if request.info_type == L2CAP_Information_Request.CONNECTIONLESS_MTU: result = L2CAP_Information_Response.SUCCESS data = self.connectionless_mtu.to_bytes(2, 'little') @@ -1781,11 +1807,15 @@ class ChannelManager: ), ) - def on_l2cap_connection_parameter_update_response(self, connection, cid, response): + def on_l2cap_connection_parameter_update_response( + self, connection, cid, response + ) -> None: # TODO: check response pass - def on_l2cap_le_credit_based_connection_request(self, connection, cid, request): + def on_l2cap_le_credit_based_connection_request( + self, connection, cid, request + ) -> None: if request.le_psm in self.le_coc_servers: (server, max_credits, mtu, mps) = self.le_coc_servers[request.le_psm] @@ -1887,7 +1917,9 @@ class ChannelManager: ), ) - def on_l2cap_le_credit_based_connection_response(self, connection, _cid, response): + def on_l2cap_le_credit_based_connection_response( + self, connection, _cid, response + ) -> None: # Find the pending request by identifier request = self.le_coc_requests.get(response.identifier) if request is None: @@ -1910,7 +1942,7 @@ class ChannelManager: # Process the response channel.on_connection_response(response) - def on_l2cap_le_flow_control_credit(self, connection, _cid, credit): + def on_l2cap_le_flow_control_credit(self, connection, _cid, credit) -> None: channel = self.find_le_coc_channel(connection.handle, credit.cid) if channel is None: logger.warning(f'received credits for an unknown channel (cid={credit.cid}') @@ -1918,13 +1950,15 @@ class ChannelManager: channel.on_credits(credit.credits) - def on_channel_closed(self, channel): + def on_channel_closed(self, channel) -> None: connection_channels = self.channels.get(channel.connection.handle) if connection_channels: if channel.source_cid in connection_channels: del connection_channels[channel.source_cid] - async def open_le_coc(self, connection, psm, max_credits, mtu, mps): + async def open_le_coc( + self, connection, psm, max_credits, mtu, mps + ) -> LeConnectionOrientedChannel: self.check_le_coc_parameters(max_credits, mtu, mps) # Find a free CID for the new channel @@ -1965,7 +1999,7 @@ class ChannelManager: return channel - async def connect(self, connection, psm): + async def connect(self, connection, psm) -> Channel: # NOTE: this implementation hard-codes BR/EDR # Find a free CID for a new channel