Add type hint to L2CAP module

This commit is contained in:
Josh Wu
2023-07-27 10:46:22 +08:00
committed by Lucas Abel
parent 43234d7c3e
commit 9c70c487b9
2 changed files with 134 additions and 100 deletions
+2 -2
View File
@@ -17,7 +17,7 @@
# -----------------------------------------------------------------------------
from __future__ import annotations
import struct
from typing import List, Optional, Tuple, Union, cast
from typing import List, Optional, Tuple, Union, cast, Dict
from .company_ids import COMPANY_IDENTIFIERS
@@ -53,7 +53,7 @@ def bit_flags_to_strings(bits, bit_flag_names):
return names
def name_or_number(dictionary, number, width=2):
def name_or_number(dictionary: Dict[int, str], number: int, width: int = 2) -> str:
name = dictionary.get(number)
if name is not None:
return name
+132 -98
View File
@@ -22,7 +22,7 @@ import struct
from collections import deque
from pyee import EventEmitter
from typing import Dict, Type
from typing import Dict, Type, List, Optional, Tuple, Callable, Any, Union, Deque
from .colors import color
from .core import BT_CENTRAL_ROLE, InvalidStateError, ProtocolError
@@ -155,7 +155,7 @@ class L2CAP_PDU:
'''
@staticmethod
def from_bytes(data):
def from_bytes(data) -> L2CAP_PDU:
# Sanity check
if len(data) < 4:
raise ValueError('not enough data for L2CAP header')
@@ -165,18 +165,18 @@ class L2CAP_PDU:
return L2CAP_PDU(l2cap_pdu_cid, l2cap_pdu_payload)
def to_bytes(self):
def to_bytes(self) -> bytes:
header = struct.pack('<HH', len(self.payload), self.cid)
return header + self.payload
def __init__(self, cid, payload):
def __init__(self, cid, payload) -> None:
self.cid = cid
self.payload = payload
def __bytes__(self):
def __bytes__(self) -> bytes:
return self.to_bytes()
def __str__(self):
def __str__(self) -> str:
return f'{color("L2CAP", "green")} [CID={self.cid}]: {self.payload.hex()}'
@@ -188,10 +188,10 @@ class L2CAP_Control_Frame:
classes: Dict[int, Type[L2CAP_Control_Frame]] = {}
code = 0
name = None
name: str
@staticmethod
def from_bytes(pdu):
def from_bytes(pdu) -> L2CAP_Control_Frame:
code = pdu[0]
cls = L2CAP_Control_Frame.classes.get(code)
@@ -216,11 +216,11 @@ class L2CAP_Control_Frame:
return self
@staticmethod
def code_name(code):
def code_name(code) -> str:
return name_or_number(L2CAP_CONTROL_FRAME_NAMES, code)
@staticmethod
def decode_configuration_options(data):
def decode_configuration_options(data) -> List[Tuple[int, bytes]]:
options = []
while len(data) >= 2:
value_type = data[0]
@@ -232,7 +232,7 @@ class L2CAP_Control_Frame:
return options
@staticmethod
def encode_configuration_options(options):
def encode_configuration_options(options) -> bytes:
return b''.join(
[bytes([option[0], len(option[1])]) + option[1] for option in options]
)
@@ -256,29 +256,29 @@ class L2CAP_Control_Frame:
return inner
def __init__(self, pdu=None, **kwargs):
def __init__(self, pdu=None, **kwargs) -> None:
self.identifier = kwargs.get('identifier', 0)
if hasattr(self, 'fields') and kwargs:
HCI_Object.init_from_fields(self, self.fields, kwargs)
if pdu is None:
data = HCI_Object.dict_to_bytes(kwargs, self.fields)
pdu = (
bytes([self.code, self.identifier])
+ struct.pack('<H', len(data))
+ data
)
if pdu is None:
data = HCI_Object.dict_to_bytes(kwargs, self.fields)
pdu = (
bytes([self.code, self.identifier])
+ struct.pack('<H', len(data))
+ data
)
self.pdu = pdu
def init_from_bytes(self, pdu, offset):
return HCI_Object.init_from_bytes(self, pdu, offset, self.fields)
def to_bytes(self):
def to_bytes(self) -> bytes:
return self.pdu
def __bytes__(self):
def __bytes__(self) -> bytes:
return self.to_bytes()
def __str__(self):
def __str__(self) -> str:
result = f'{color(self.name, "yellow")} [ID={self.identifier}]'
if fields := getattr(self, 'fields', None):
result += ':\n' + HCI_Object.format_fields(self.__dict__, fields, ' ')
@@ -315,7 +315,7 @@ class L2CAP_Command_Reject(L2CAP_Control_Frame):
}
@staticmethod
def reason_name(reason):
def reason_name(reason) -> str:
return name_or_number(L2CAP_Command_Reject.REASON_NAMES, reason)
@@ -343,7 +343,7 @@ class L2CAP_Connection_Request(L2CAP_Control_Frame):
'''
@staticmethod
def parse_psm(data, offset=0):
def parse_psm(data, offset=0) -> Tuple[int, int]:
psm_length = 2
psm = data[offset] | data[offset + 1] << 8
@@ -355,7 +355,7 @@ class L2CAP_Connection_Request(L2CAP_Control_Frame):
return offset + psm_length, psm
@staticmethod
def serialize_psm(psm):
def serialize_psm(psm) -> bytes:
serialized = struct.pack('<H', psm & 0xFFFF)
psm >>= 16
while psm:
@@ -405,7 +405,7 @@ class L2CAP_Connection_Response(L2CAP_Control_Frame):
}
@staticmethod
def result_name(result):
def result_name(result) -> str:
return name_or_number(L2CAP_Connection_Response.RESULT_NAMES, result)
@@ -452,7 +452,7 @@ class L2CAP_Configure_Response(L2CAP_Control_Frame):
}
@staticmethod
def result_name(result):
def result_name(result) -> str:
return name_or_number(L2CAP_Configure_Response.RESULT_NAMES, result)
@@ -529,7 +529,7 @@ class L2CAP_Information_Request(L2CAP_Control_Frame):
}
@staticmethod
def info_type_name(info_type):
def info_type_name(info_type) -> str:
return name_or_number(L2CAP_Information_Request.INFO_TYPE_NAMES, info_type)
@@ -556,7 +556,7 @@ class L2CAP_Information_Response(L2CAP_Control_Frame):
RESULT_NAMES = {SUCCESS: 'SUCCESS', NOT_SUPPORTED: 'NOT_SUPPORTED'}
@staticmethod
def result_name(result):
def result_name(result) -> str:
return name_or_number(L2CAP_Information_Response.RESULT_NAMES, result)
@@ -588,6 +588,8 @@ class L2CAP_LE_Credit_Based_Connection_Request(L2CAP_Control_Frame):
(CODE 0x14)
'''
source_cid: int
# -----------------------------------------------------------------------------
@L2CAP_Control_Frame.subclass(
@@ -640,7 +642,7 @@ class L2CAP_LE_Credit_Based_Connection_Response(L2CAP_Control_Frame):
}
@staticmethod
def result_name(result):
def result_name(result) -> str:
return name_or_number(
L2CAP_LE_Credit_Based_Connection_Response.RESULT_NAMES, result
)
@@ -701,7 +703,14 @@ class Channel(EventEmitter):
WAIT_CONTROL_IND: 'WAIT_CONTROL_IND',
}
def __init__(self, manager, connection, signaling_cid, psm, source_cid, mtu):
connection_result: Optional[asyncio.Future[None]]
disconnection_result: Optional[asyncio.Future[None]]
response: Optional[asyncio.Future[bytes]]
sink: Optional[Callable[[bytes], Any]]
def __init__(
self, manager, connection, signaling_cid, psm, source_cid, mtu
) -> None:
super().__init__()
self.manager = manager
self.connection = connection
@@ -716,19 +725,19 @@ class Channel(EventEmitter):
self.disconnection_result = None
self.sink = None
def change_state(self, new_state):
def change_state(self, new_state) -> None:
logger.debug(
f'{self} state change -> {color(Channel.STATE_NAMES[new_state], "cyan")}'
)
self.state = new_state
def send_pdu(self, pdu):
def send_pdu(self, pdu) -> None:
self.manager.send_pdu(self.connection, self.destination_cid, pdu)
def send_control_frame(self, frame):
def send_control_frame(self, frame) -> None:
self.manager.send_control_frame(self.connection, self.signaling_cid, frame)
async def send_request(self, request):
async def send_request(self, request) -> bytes:
# Check that there isn't already a request pending
if self.response:
raise InvalidStateError('request already pending')
@@ -739,7 +748,7 @@ class Channel(EventEmitter):
self.send_pdu(request)
return await self.response
def on_pdu(self, pdu):
def on_pdu(self, pdu) -> None:
if self.response:
self.response.set_result(pdu)
self.response = None
@@ -751,7 +760,7 @@ class Channel(EventEmitter):
color('received pdu without a pending request or sink', 'red')
)
async def connect(self):
async def connect(self) -> None:
if self.state != Channel.CLOSED:
raise InvalidStateError('invalid state')
@@ -778,7 +787,7 @@ class Channel(EventEmitter):
finally:
self.connection_result = None
async def disconnect(self):
async def disconnect(self) -> None:
if self.state != Channel.OPEN:
raise InvalidStateError('invalid state')
@@ -796,12 +805,12 @@ class Channel(EventEmitter):
self.disconnection_result = asyncio.get_running_loop().create_future()
return await self.disconnection_result
def abort(self):
def abort(self) -> None:
if self.state == self.OPEN:
self.change_state(self.CLOSED)
self.emit('close')
def send_configure_request(self):
def send_configure_request(self) -> None:
options = L2CAP_Control_Frame.encode_configuration_options(
[
(
@@ -819,7 +828,7 @@ class Channel(EventEmitter):
)
)
def on_connection_request(self, request):
def on_connection_request(self, request) -> None:
self.destination_cid = request.source_cid
self.change_state(Channel.WAIT_CONNECT)
self.send_control_frame(
@@ -858,7 +867,7 @@ class Channel(EventEmitter):
)
self.connection_result = None
def on_configure_request(self, request):
def on_configure_request(self, request) -> None:
if self.state not in (
Channel.WAIT_CONFIG,
Channel.WAIT_CONFIG_REQ,
@@ -896,7 +905,7 @@ class Channel(EventEmitter):
elif self.state == Channel.WAIT_CONFIG_REQ_RSP:
self.change_state(Channel.WAIT_CONFIG_RSP)
def on_configure_response(self, response):
def on_configure_response(self, response) -> None:
if response.result == L2CAP_Configure_Response.SUCCESS:
if self.state == Channel.WAIT_CONFIG_REQ_RSP:
self.change_state(Channel.WAIT_CONFIG_REQ)
@@ -930,7 +939,7 @@ class Channel(EventEmitter):
)
# TODO: decide how to fail gracefully
def on_disconnection_request(self, request):
def on_disconnection_request(self, request) -> None:
if self.state in (Channel.OPEN, Channel.WAIT_DISCONNECT):
self.send_control_frame(
L2CAP_Disconnection_Response(
@@ -945,7 +954,7 @@ class Channel(EventEmitter):
else:
logger.warning(color('invalid state', 'red'))
def on_disconnection_response(self, response):
def on_disconnection_response(self, response) -> None:
if self.state != Channel.WAIT_DISCONNECT:
logger.warning(color('invalid state', 'red'))
return
@@ -964,7 +973,7 @@ class Channel(EventEmitter):
self.emit('close')
self.manager.on_channel_closed(self)
def __str__(self):
def __str__(self) -> str:
return (
f'Channel({self.source_cid}->{self.destination_cid}, '
f'PSM={self.psm}, '
@@ -995,8 +1004,13 @@ class LeConnectionOrientedChannel(EventEmitter):
CONNECTION_ERROR: 'CONNECTION_ERROR',
}
out_queue: Deque[bytes]
connection_result: Optional[asyncio.Future[LeConnectionOrientedChannel]]
disconnection_result: Optional[asyncio.Future[None]]
out_sdu: Optional[bytes]
@staticmethod
def state_name(state):
def state_name(state) -> str:
return name_or_number(LeConnectionOrientedChannel.STATE_NAMES, state)
def __init__(
@@ -1013,7 +1027,7 @@ class LeConnectionOrientedChannel(EventEmitter):
peer_mps,
peer_credits,
connected,
):
) -> None:
super().__init__()
self.manager = manager
self.connection = connection
@@ -1045,7 +1059,7 @@ class LeConnectionOrientedChannel(EventEmitter):
else:
self.state = LeConnectionOrientedChannel.INIT
def change_state(self, new_state):
def change_state(self, new_state) -> None:
logger.debug(
f'{self} state change -> {color(self.state_name(new_state), "cyan")}'
)
@@ -1056,13 +1070,13 @@ class LeConnectionOrientedChannel(EventEmitter):
elif new_state == self.DISCONNECTED:
self.emit('close')
def send_pdu(self, pdu):
def send_pdu(self, pdu) -> None:
self.manager.send_pdu(self.connection, self.destination_cid, pdu)
def send_control_frame(self, frame):
def send_control_frame(self, frame) -> None:
self.manager.send_control_frame(self.connection, L2CAP_LE_SIGNALING_CID, frame)
async def connect(self):
async def connect(self) -> LeConnectionOrientedChannel:
# Check that we're in the right state
if self.state != self.INIT:
raise InvalidStateError('not in a connectable state')
@@ -1090,7 +1104,7 @@ class LeConnectionOrientedChannel(EventEmitter):
# Wait for the connection to succeed or fail
return await self.connection_result
async def disconnect(self):
async def disconnect(self) -> None:
# Check that we're connected
if self.state != self.CONNECTED:
raise InvalidStateError('not connected')
@@ -1110,11 +1124,11 @@ class LeConnectionOrientedChannel(EventEmitter):
self.disconnection_result = asyncio.get_running_loop().create_future()
return await self.disconnection_result
def abort(self):
def abort(self) -> None:
if self.state == self.CONNECTED:
self.change_state(self.DISCONNECTED)
def on_pdu(self, pdu):
def on_pdu(self, pdu) -> None:
if self.sink is None:
logger.warning('received pdu without a sink')
return
@@ -1180,7 +1194,7 @@ class LeConnectionOrientedChannel(EventEmitter):
self.in_sdu = None
self.in_sdu_length = 0
def on_connection_response(self, response):
def on_connection_response(self, response) -> None:
# Look for a matching pending response result
if self.connection_result is None:
logger.warning(
@@ -1214,14 +1228,14 @@ class LeConnectionOrientedChannel(EventEmitter):
# Cleanup
self.connection_result = None
def on_credits(self, credits): # pylint: disable=redefined-builtin
def on_credits(self, credits) -> None: # pylint: disable=redefined-builtin
self.credits += credits
logger.debug(f'received {credits} credits, total = {self.credits}')
# Try to send more data if we have any queued up
self.process_output()
def on_disconnection_request(self, request):
def on_disconnection_request(self, request) -> None:
self.send_control_frame(
L2CAP_Disconnection_Response(
identifier=request.identifier,
@@ -1232,7 +1246,7 @@ class LeConnectionOrientedChannel(EventEmitter):
self.change_state(self.DISCONNECTED)
self.flush_output()
def on_disconnection_response(self, response):
def on_disconnection_response(self, response) -> None:
if self.state != self.DISCONNECTING:
logger.warning(color('invalid state', 'red'))
return
@@ -1249,11 +1263,11 @@ class LeConnectionOrientedChannel(EventEmitter):
self.disconnection_result.set_result(None)
self.disconnection_result = None
def flush_output(self):
def flush_output(self) -> None:
self.out_queue.clear()
self.out_sdu = None
def process_output(self):
def process_output(self) -> None:
while self.credits > 0:
if self.out_sdu is not None:
# Finish the current SDU
@@ -1296,7 +1310,7 @@ class LeConnectionOrientedChannel(EventEmitter):
self.drained.set()
return
def write(self, data):
def write(self, data) -> None:
if self.state != self.CONNECTED:
logger.warning('not connected, dropping data')
return
@@ -1311,18 +1325,18 @@ class LeConnectionOrientedChannel(EventEmitter):
# Send what we can
self.process_output()
async def drain(self):
async def drain(self) -> None:
await self.drained.wait()
def pause_reading(self):
def pause_reading(self) -> None:
# TODO: not implemented yet
pass
def resume_reading(self):
def resume_reading(self) -> None:
# TODO: not implemented yet
pass
def __str__(self):
def __str__(self) -> str:
return (
f'CoC({self.source_cid}->{self.destination_cid}, '
f'State={self.state_name(self.state)}, '
@@ -1335,9 +1349,19 @@ class LeConnectionOrientedChannel(EventEmitter):
# -----------------------------------------------------------------------------
class ChannelManager:
identifiers: Dict[int, int]
channels: Dict[int, Dict[int, Union[Channel, LeConnectionOrientedChannel]]]
servers: Dict[int, Callable[[Channel], Any]]
le_coc_channels: Dict[int, Dict[int, LeConnectionOrientedChannel]]
le_coc_servers: Dict[
int, Tuple[Callable[[LeConnectionOrientedChannel], Any], int, int, int]
]
le_coc_requests: Dict[int, L2CAP_LE_Credit_Based_Connection_Request]
fixed_channels: Dict[int, Optional[Callable[[int, bytes], Any]]]
def __init__(
self, extended_features=(), connectionless_mtu=L2CAP_DEFAULT_CONNECTIONLESS_MTU
):
) -> None:
self._host = None
self.identifiers = {} # Incrementing identifier values by connection
self.channels = {} # All channels, mapped by connection and source cid
@@ -1379,7 +1403,7 @@ class ChannelManager:
return None
@staticmethod
def find_free_br_edr_cid(channels):
def find_free_br_edr_cid(channels) -> int:
# Pick the smallest valid CID that's not already in the list
# (not necessarily the most efficient algorithm, but the list of CID is
# very small in practice)
@@ -1392,7 +1416,7 @@ class ChannelManager:
raise RuntimeError('no free CID available')
@staticmethod
def find_free_le_cid(channels):
def find_free_le_cid(channels) -> int:
# Pick the smallest valid CID that's not already in the list
# (not necessarily the most efficient algorithm, but the list of CID is
# very small in practice)
@@ -1405,7 +1429,7 @@ class ChannelManager:
raise RuntimeError('no free CID')
@staticmethod
def check_le_coc_parameters(max_credits, mtu, mps):
def check_le_coc_parameters(max_credits, mtu, mps) -> None:
if (
max_credits < 1
or max_credits > L2CAP_LE_CREDIT_BASED_CONNECTION_MAX_CREDITS
@@ -1419,19 +1443,19 @@ class ChannelManager:
):
raise ValueError('MPS out of range')
def next_identifier(self, connection):
def next_identifier(self, connection) -> int:
identifier = (self.identifiers.setdefault(connection.handle, 0) + 1) % 256
self.identifiers[connection.handle] = identifier
return identifier
def register_fixed_channel(self, cid, handler):
def register_fixed_channel(self, cid, handler) -> None:
self.fixed_channels[cid] = handler
def deregister_fixed_channel(self, cid):
def deregister_fixed_channel(self, cid) -> None:
if cid in self.fixed_channels:
del self.fixed_channels[cid]
def register_server(self, psm, server):
def register_server(self, psm, server) -> int:
if psm == 0:
# Find a free PSM
for candidate in range(
@@ -1470,7 +1494,7 @@ class ChannelManager:
max_credits=L2CAP_LE_CREDIT_BASED_CONNECTION_DEFAULT_INITIAL_CREDITS,
mtu=L2CAP_LE_CREDIT_BASED_CONNECTION_DEFAULT_MTU,
mps=L2CAP_LE_CREDIT_BASED_CONNECTION_DEFAULT_MPS,
):
) -> int:
self.check_le_coc_parameters(max_credits, mtu, mps)
if psm == 0:
@@ -1498,7 +1522,7 @@ class ChannelManager:
return psm
def on_disconnection(self, connection_handle, _reason):
def on_disconnection(self, connection_handle, _reason) -> None:
logger.debug(f'disconnection from {connection_handle}, cleaning up channels')
if connection_handle in self.channels:
for _, channel in self.channels[connection_handle].items():
@@ -1511,7 +1535,7 @@ class ChannelManager:
if connection_handle in self.identifiers:
del self.identifiers[connection_handle]
def send_pdu(self, connection, cid, pdu):
def send_pdu(self, connection, cid, pdu) -> None:
pdu_str = pdu.hex() if isinstance(pdu, bytes) else str(pdu)
logger.debug(
f'{color(">>> Sending L2CAP PDU", "blue")} '
@@ -1520,14 +1544,16 @@ class ChannelManager:
)
self.host.send_l2cap_pdu(connection.handle, cid, bytes(pdu))
def on_pdu(self, connection, cid, pdu):
def on_pdu(self, connection, cid, pdu) -> None:
if cid in (L2CAP_SIGNALING_CID, L2CAP_LE_SIGNALING_CID):
# Parse the L2CAP payload into a Control Frame object
control_frame = L2CAP_Control_Frame.from_bytes(pdu)
self.on_control_frame(connection, cid, control_frame)
elif cid in self.fixed_channels:
self.fixed_channels[cid](connection.handle, pdu)
handler = self.fixed_channels[cid]
assert handler is not None
handler(connection.handle, pdu)
else:
if (channel := self.find_channel(connection.handle, cid)) is None:
logger.warning(
@@ -1539,7 +1565,7 @@ class ChannelManager:
channel.on_pdu(pdu)
def send_control_frame(self, connection, cid, control_frame):
def send_control_frame(self, connection, cid, control_frame) -> None:
logger.debug(
f'{color(">>> Sending L2CAP Signaling Control Frame", "blue")} '
f'on connection [0x{connection.handle:04X}] (CID={cid}) '
@@ -1547,7 +1573,7 @@ class ChannelManager:
)
self.host.send_l2cap_pdu(connection.handle, cid, bytes(control_frame))
def on_control_frame(self, connection, cid, control_frame):
def on_control_frame(self, connection, cid, control_frame) -> None:
logger.debug(
f'{color("<<< Received L2CAP Signaling Control Frame", "green")} '
f'on connection [0x{connection.handle:04X}] (CID={cid}) '
@@ -1584,10 +1610,10 @@ class ChannelManager:
),
)
def on_l2cap_command_reject(self, _connection, _cid, packet):
def on_l2cap_command_reject(self, _connection, _cid, packet) -> None:
logger.warning(f'{color("!!! Command rejected:", "red")} {packet.reason}')
def on_l2cap_connection_request(self, connection, cid, request):
def on_l2cap_connection_request(self, connection, cid, request) -> None:
# Check if there's a server for this PSM
server = self.servers.get(request.psm)
if server:
@@ -1639,7 +1665,7 @@ class ChannelManager:
),
)
def on_l2cap_connection_response(self, connection, cid, response):
def on_l2cap_connection_response(self, connection, cid, response) -> None:
if (
channel := self.find_channel(connection.handle, response.source_cid)
) is None:
@@ -1654,7 +1680,7 @@ class ChannelManager:
channel.on_connection_response(response)
def on_l2cap_configure_request(self, connection, cid, request):
def on_l2cap_configure_request(self, connection, cid, request) -> None:
if (
channel := self.find_channel(connection.handle, request.destination_cid)
) is None:
@@ -1669,7 +1695,7 @@ class ChannelManager:
channel.on_configure_request(request)
def on_l2cap_configure_response(self, connection, cid, response):
def on_l2cap_configure_response(self, connection, cid, response) -> None:
if (
channel := self.find_channel(connection.handle, response.source_cid)
) is None:
@@ -1684,7 +1710,7 @@ class ChannelManager:
channel.on_configure_response(response)
def on_l2cap_disconnection_request(self, connection, cid, request):
def on_l2cap_disconnection_request(self, connection, cid, request) -> None:
if (
channel := self.find_channel(connection.handle, request.destination_cid)
) is None:
@@ -1699,7 +1725,7 @@ class ChannelManager:
channel.on_disconnection_request(request)
def on_l2cap_disconnection_response(self, connection, cid, response):
def on_l2cap_disconnection_response(self, connection, cid, response) -> None:
if (
channel := self.find_channel(connection.handle, response.source_cid)
) is None:
@@ -1714,7 +1740,7 @@ class ChannelManager:
channel.on_disconnection_response(response)
def on_l2cap_echo_request(self, connection, cid, request):
def on_l2cap_echo_request(self, connection, cid, request) -> None:
logger.debug(f'<<< Echo request: data={request.data.hex()}')
self.send_control_frame(
connection,
@@ -1722,11 +1748,11 @@ class ChannelManager:
L2CAP_Echo_Response(identifier=request.identifier, data=request.data),
)
def on_l2cap_echo_response(self, _connection, _cid, response):
def on_l2cap_echo_response(self, _connection, _cid, response) -> None:
logger.debug(f'<<< Echo response: data={response.data.hex()}')
# TODO notify listeners
def on_l2cap_information_request(self, connection, cid, request):
def on_l2cap_information_request(self, connection, cid, request) -> None:
if request.info_type == L2CAP_Information_Request.CONNECTIONLESS_MTU:
result = L2CAP_Information_Response.SUCCESS
data = self.connectionless_mtu.to_bytes(2, 'little')
@@ -1781,11 +1807,15 @@ class ChannelManager:
),
)
def on_l2cap_connection_parameter_update_response(self, connection, cid, response):
def on_l2cap_connection_parameter_update_response(
self, connection, cid, response
) -> None:
# TODO: check response
pass
def on_l2cap_le_credit_based_connection_request(self, connection, cid, request):
def on_l2cap_le_credit_based_connection_request(
self, connection, cid, request
) -> None:
if request.le_psm in self.le_coc_servers:
(server, max_credits, mtu, mps) = self.le_coc_servers[request.le_psm]
@@ -1887,7 +1917,9 @@ class ChannelManager:
),
)
def on_l2cap_le_credit_based_connection_response(self, connection, _cid, response):
def on_l2cap_le_credit_based_connection_response(
self, connection, _cid, response
) -> None:
# Find the pending request by identifier
request = self.le_coc_requests.get(response.identifier)
if request is None:
@@ -1910,7 +1942,7 @@ class ChannelManager:
# Process the response
channel.on_connection_response(response)
def on_l2cap_le_flow_control_credit(self, connection, _cid, credit):
def on_l2cap_le_flow_control_credit(self, connection, _cid, credit) -> None:
channel = self.find_le_coc_channel(connection.handle, credit.cid)
if channel is None:
logger.warning(f'received credits for an unknown channel (cid={credit.cid}')
@@ -1918,13 +1950,15 @@ class ChannelManager:
channel.on_credits(credit.credits)
def on_channel_closed(self, channel):
def on_channel_closed(self, channel) -> None:
connection_channels = self.channels.get(channel.connection.handle)
if connection_channels:
if channel.source_cid in connection_channels:
del connection_channels[channel.source_cid]
async def open_le_coc(self, connection, psm, max_credits, mtu, mps):
async def open_le_coc(
self, connection, psm, max_credits, mtu, mps
) -> LeConnectionOrientedChannel:
self.check_le_coc_parameters(max_credits, mtu, mps)
# Find a free CID for the new channel
@@ -1965,7 +1999,7 @@ class ChannelManager:
return channel
async def connect(self, connection, psm):
async def connect(self, connection, psm) -> Channel:
# NOTE: this implementation hard-codes BR/EDR
# Find a free CID for a new channel