forked from auracaster/bumble_mirror
Add type hint to L2CAP module
This commit is contained in:
+2
-2
@@ -17,7 +17,7 @@
|
||||
# -----------------------------------------------------------------------------
|
||||
from __future__ import annotations
|
||||
import struct
|
||||
from typing import List, Optional, Tuple, Union, cast
|
||||
from typing import List, Optional, Tuple, Union, cast, Dict
|
||||
|
||||
from .company_ids import COMPANY_IDENTIFIERS
|
||||
|
||||
@@ -53,7 +53,7 @@ def bit_flags_to_strings(bits, bit_flag_names):
|
||||
return names
|
||||
|
||||
|
||||
def name_or_number(dictionary, number, width=2):
|
||||
def name_or_number(dictionary: Dict[int, str], number: int, width: int = 2) -> str:
|
||||
name = dictionary.get(number)
|
||||
if name is not None:
|
||||
return name
|
||||
|
||||
+132
-98
@@ -22,7 +22,7 @@ import struct
|
||||
|
||||
from collections import deque
|
||||
from pyee import EventEmitter
|
||||
from typing import Dict, Type
|
||||
from typing import Dict, Type, List, Optional, Tuple, Callable, Any, Union, Deque
|
||||
|
||||
from .colors import color
|
||||
from .core import BT_CENTRAL_ROLE, InvalidStateError, ProtocolError
|
||||
@@ -155,7 +155,7 @@ class L2CAP_PDU:
|
||||
'''
|
||||
|
||||
@staticmethod
|
||||
def from_bytes(data):
|
||||
def from_bytes(data) -> L2CAP_PDU:
|
||||
# Sanity check
|
||||
if len(data) < 4:
|
||||
raise ValueError('not enough data for L2CAP header')
|
||||
@@ -165,18 +165,18 @@ class L2CAP_PDU:
|
||||
|
||||
return L2CAP_PDU(l2cap_pdu_cid, l2cap_pdu_payload)
|
||||
|
||||
def to_bytes(self):
|
||||
def to_bytes(self) -> bytes:
|
||||
header = struct.pack('<HH', len(self.payload), self.cid)
|
||||
return header + self.payload
|
||||
|
||||
def __init__(self, cid, payload):
|
||||
def __init__(self, cid, payload) -> None:
|
||||
self.cid = cid
|
||||
self.payload = payload
|
||||
|
||||
def __bytes__(self):
|
||||
def __bytes__(self) -> bytes:
|
||||
return self.to_bytes()
|
||||
|
||||
def __str__(self):
|
||||
def __str__(self) -> str:
|
||||
return f'{color("L2CAP", "green")} [CID={self.cid}]: {self.payload.hex()}'
|
||||
|
||||
|
||||
@@ -188,10 +188,10 @@ class L2CAP_Control_Frame:
|
||||
|
||||
classes: Dict[int, Type[L2CAP_Control_Frame]] = {}
|
||||
code = 0
|
||||
name = None
|
||||
name: str
|
||||
|
||||
@staticmethod
|
||||
def from_bytes(pdu):
|
||||
def from_bytes(pdu) -> L2CAP_Control_Frame:
|
||||
code = pdu[0]
|
||||
|
||||
cls = L2CAP_Control_Frame.classes.get(code)
|
||||
@@ -216,11 +216,11 @@ class L2CAP_Control_Frame:
|
||||
return self
|
||||
|
||||
@staticmethod
|
||||
def code_name(code):
|
||||
def code_name(code) -> str:
|
||||
return name_or_number(L2CAP_CONTROL_FRAME_NAMES, code)
|
||||
|
||||
@staticmethod
|
||||
def decode_configuration_options(data):
|
||||
def decode_configuration_options(data) -> List[Tuple[int, bytes]]:
|
||||
options = []
|
||||
while len(data) >= 2:
|
||||
value_type = data[0]
|
||||
@@ -232,7 +232,7 @@ class L2CAP_Control_Frame:
|
||||
return options
|
||||
|
||||
@staticmethod
|
||||
def encode_configuration_options(options):
|
||||
def encode_configuration_options(options) -> bytes:
|
||||
return b''.join(
|
||||
[bytes([option[0], len(option[1])]) + option[1] for option in options]
|
||||
)
|
||||
@@ -256,29 +256,29 @@ class L2CAP_Control_Frame:
|
||||
|
||||
return inner
|
||||
|
||||
def __init__(self, pdu=None, **kwargs):
|
||||
def __init__(self, pdu=None, **kwargs) -> None:
|
||||
self.identifier = kwargs.get('identifier', 0)
|
||||
if hasattr(self, 'fields') and kwargs:
|
||||
HCI_Object.init_from_fields(self, self.fields, kwargs)
|
||||
if pdu is None:
|
||||
data = HCI_Object.dict_to_bytes(kwargs, self.fields)
|
||||
pdu = (
|
||||
bytes([self.code, self.identifier])
|
||||
+ struct.pack('<H', len(data))
|
||||
+ data
|
||||
)
|
||||
if pdu is None:
|
||||
data = HCI_Object.dict_to_bytes(kwargs, self.fields)
|
||||
pdu = (
|
||||
bytes([self.code, self.identifier])
|
||||
+ struct.pack('<H', len(data))
|
||||
+ data
|
||||
)
|
||||
self.pdu = pdu
|
||||
|
||||
def init_from_bytes(self, pdu, offset):
|
||||
return HCI_Object.init_from_bytes(self, pdu, offset, self.fields)
|
||||
|
||||
def to_bytes(self):
|
||||
def to_bytes(self) -> bytes:
|
||||
return self.pdu
|
||||
|
||||
def __bytes__(self):
|
||||
def __bytes__(self) -> bytes:
|
||||
return self.to_bytes()
|
||||
|
||||
def __str__(self):
|
||||
def __str__(self) -> str:
|
||||
result = f'{color(self.name, "yellow")} [ID={self.identifier}]'
|
||||
if fields := getattr(self, 'fields', None):
|
||||
result += ':\n' + HCI_Object.format_fields(self.__dict__, fields, ' ')
|
||||
@@ -315,7 +315,7 @@ class L2CAP_Command_Reject(L2CAP_Control_Frame):
|
||||
}
|
||||
|
||||
@staticmethod
|
||||
def reason_name(reason):
|
||||
def reason_name(reason) -> str:
|
||||
return name_or_number(L2CAP_Command_Reject.REASON_NAMES, reason)
|
||||
|
||||
|
||||
@@ -343,7 +343,7 @@ class L2CAP_Connection_Request(L2CAP_Control_Frame):
|
||||
'''
|
||||
|
||||
@staticmethod
|
||||
def parse_psm(data, offset=0):
|
||||
def parse_psm(data, offset=0) -> Tuple[int, int]:
|
||||
psm_length = 2
|
||||
psm = data[offset] | data[offset + 1] << 8
|
||||
|
||||
@@ -355,7 +355,7 @@ class L2CAP_Connection_Request(L2CAP_Control_Frame):
|
||||
return offset + psm_length, psm
|
||||
|
||||
@staticmethod
|
||||
def serialize_psm(psm):
|
||||
def serialize_psm(psm) -> bytes:
|
||||
serialized = struct.pack('<H', psm & 0xFFFF)
|
||||
psm >>= 16
|
||||
while psm:
|
||||
@@ -405,7 +405,7 @@ class L2CAP_Connection_Response(L2CAP_Control_Frame):
|
||||
}
|
||||
|
||||
@staticmethod
|
||||
def result_name(result):
|
||||
def result_name(result) -> str:
|
||||
return name_or_number(L2CAP_Connection_Response.RESULT_NAMES, result)
|
||||
|
||||
|
||||
@@ -452,7 +452,7 @@ class L2CAP_Configure_Response(L2CAP_Control_Frame):
|
||||
}
|
||||
|
||||
@staticmethod
|
||||
def result_name(result):
|
||||
def result_name(result) -> str:
|
||||
return name_or_number(L2CAP_Configure_Response.RESULT_NAMES, result)
|
||||
|
||||
|
||||
@@ -529,7 +529,7 @@ class L2CAP_Information_Request(L2CAP_Control_Frame):
|
||||
}
|
||||
|
||||
@staticmethod
|
||||
def info_type_name(info_type):
|
||||
def info_type_name(info_type) -> str:
|
||||
return name_or_number(L2CAP_Information_Request.INFO_TYPE_NAMES, info_type)
|
||||
|
||||
|
||||
@@ -556,7 +556,7 @@ class L2CAP_Information_Response(L2CAP_Control_Frame):
|
||||
RESULT_NAMES = {SUCCESS: 'SUCCESS', NOT_SUPPORTED: 'NOT_SUPPORTED'}
|
||||
|
||||
@staticmethod
|
||||
def result_name(result):
|
||||
def result_name(result) -> str:
|
||||
return name_or_number(L2CAP_Information_Response.RESULT_NAMES, result)
|
||||
|
||||
|
||||
@@ -588,6 +588,8 @@ class L2CAP_LE_Credit_Based_Connection_Request(L2CAP_Control_Frame):
|
||||
(CODE 0x14)
|
||||
'''
|
||||
|
||||
source_cid: int
|
||||
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
@L2CAP_Control_Frame.subclass(
|
||||
@@ -640,7 +642,7 @@ class L2CAP_LE_Credit_Based_Connection_Response(L2CAP_Control_Frame):
|
||||
}
|
||||
|
||||
@staticmethod
|
||||
def result_name(result):
|
||||
def result_name(result) -> str:
|
||||
return name_or_number(
|
||||
L2CAP_LE_Credit_Based_Connection_Response.RESULT_NAMES, result
|
||||
)
|
||||
@@ -701,7 +703,14 @@ class Channel(EventEmitter):
|
||||
WAIT_CONTROL_IND: 'WAIT_CONTROL_IND',
|
||||
}
|
||||
|
||||
def __init__(self, manager, connection, signaling_cid, psm, source_cid, mtu):
|
||||
connection_result: Optional[asyncio.Future[None]]
|
||||
disconnection_result: Optional[asyncio.Future[None]]
|
||||
response: Optional[asyncio.Future[bytes]]
|
||||
sink: Optional[Callable[[bytes], Any]]
|
||||
|
||||
def __init__(
|
||||
self, manager, connection, signaling_cid, psm, source_cid, mtu
|
||||
) -> None:
|
||||
super().__init__()
|
||||
self.manager = manager
|
||||
self.connection = connection
|
||||
@@ -716,19 +725,19 @@ class Channel(EventEmitter):
|
||||
self.disconnection_result = None
|
||||
self.sink = None
|
||||
|
||||
def change_state(self, new_state):
|
||||
def change_state(self, new_state) -> None:
|
||||
logger.debug(
|
||||
f'{self} state change -> {color(Channel.STATE_NAMES[new_state], "cyan")}'
|
||||
)
|
||||
self.state = new_state
|
||||
|
||||
def send_pdu(self, pdu):
|
||||
def send_pdu(self, pdu) -> None:
|
||||
self.manager.send_pdu(self.connection, self.destination_cid, pdu)
|
||||
|
||||
def send_control_frame(self, frame):
|
||||
def send_control_frame(self, frame) -> None:
|
||||
self.manager.send_control_frame(self.connection, self.signaling_cid, frame)
|
||||
|
||||
async def send_request(self, request):
|
||||
async def send_request(self, request) -> bytes:
|
||||
# Check that there isn't already a request pending
|
||||
if self.response:
|
||||
raise InvalidStateError('request already pending')
|
||||
@@ -739,7 +748,7 @@ class Channel(EventEmitter):
|
||||
self.send_pdu(request)
|
||||
return await self.response
|
||||
|
||||
def on_pdu(self, pdu):
|
||||
def on_pdu(self, pdu) -> None:
|
||||
if self.response:
|
||||
self.response.set_result(pdu)
|
||||
self.response = None
|
||||
@@ -751,7 +760,7 @@ class Channel(EventEmitter):
|
||||
color('received pdu without a pending request or sink', 'red')
|
||||
)
|
||||
|
||||
async def connect(self):
|
||||
async def connect(self) -> None:
|
||||
if self.state != Channel.CLOSED:
|
||||
raise InvalidStateError('invalid state')
|
||||
|
||||
@@ -778,7 +787,7 @@ class Channel(EventEmitter):
|
||||
finally:
|
||||
self.connection_result = None
|
||||
|
||||
async def disconnect(self):
|
||||
async def disconnect(self) -> None:
|
||||
if self.state != Channel.OPEN:
|
||||
raise InvalidStateError('invalid state')
|
||||
|
||||
@@ -796,12 +805,12 @@ class Channel(EventEmitter):
|
||||
self.disconnection_result = asyncio.get_running_loop().create_future()
|
||||
return await self.disconnection_result
|
||||
|
||||
def abort(self):
|
||||
def abort(self) -> None:
|
||||
if self.state == self.OPEN:
|
||||
self.change_state(self.CLOSED)
|
||||
self.emit('close')
|
||||
|
||||
def send_configure_request(self):
|
||||
def send_configure_request(self) -> None:
|
||||
options = L2CAP_Control_Frame.encode_configuration_options(
|
||||
[
|
||||
(
|
||||
@@ -819,7 +828,7 @@ class Channel(EventEmitter):
|
||||
)
|
||||
)
|
||||
|
||||
def on_connection_request(self, request):
|
||||
def on_connection_request(self, request) -> None:
|
||||
self.destination_cid = request.source_cid
|
||||
self.change_state(Channel.WAIT_CONNECT)
|
||||
self.send_control_frame(
|
||||
@@ -858,7 +867,7 @@ class Channel(EventEmitter):
|
||||
)
|
||||
self.connection_result = None
|
||||
|
||||
def on_configure_request(self, request):
|
||||
def on_configure_request(self, request) -> None:
|
||||
if self.state not in (
|
||||
Channel.WAIT_CONFIG,
|
||||
Channel.WAIT_CONFIG_REQ,
|
||||
@@ -896,7 +905,7 @@ class Channel(EventEmitter):
|
||||
elif self.state == Channel.WAIT_CONFIG_REQ_RSP:
|
||||
self.change_state(Channel.WAIT_CONFIG_RSP)
|
||||
|
||||
def on_configure_response(self, response):
|
||||
def on_configure_response(self, response) -> None:
|
||||
if response.result == L2CAP_Configure_Response.SUCCESS:
|
||||
if self.state == Channel.WAIT_CONFIG_REQ_RSP:
|
||||
self.change_state(Channel.WAIT_CONFIG_REQ)
|
||||
@@ -930,7 +939,7 @@ class Channel(EventEmitter):
|
||||
)
|
||||
# TODO: decide how to fail gracefully
|
||||
|
||||
def on_disconnection_request(self, request):
|
||||
def on_disconnection_request(self, request) -> None:
|
||||
if self.state in (Channel.OPEN, Channel.WAIT_DISCONNECT):
|
||||
self.send_control_frame(
|
||||
L2CAP_Disconnection_Response(
|
||||
@@ -945,7 +954,7 @@ class Channel(EventEmitter):
|
||||
else:
|
||||
logger.warning(color('invalid state', 'red'))
|
||||
|
||||
def on_disconnection_response(self, response):
|
||||
def on_disconnection_response(self, response) -> None:
|
||||
if self.state != Channel.WAIT_DISCONNECT:
|
||||
logger.warning(color('invalid state', 'red'))
|
||||
return
|
||||
@@ -964,7 +973,7 @@ class Channel(EventEmitter):
|
||||
self.emit('close')
|
||||
self.manager.on_channel_closed(self)
|
||||
|
||||
def __str__(self):
|
||||
def __str__(self) -> str:
|
||||
return (
|
||||
f'Channel({self.source_cid}->{self.destination_cid}, '
|
||||
f'PSM={self.psm}, '
|
||||
@@ -995,8 +1004,13 @@ class LeConnectionOrientedChannel(EventEmitter):
|
||||
CONNECTION_ERROR: 'CONNECTION_ERROR',
|
||||
}
|
||||
|
||||
out_queue: Deque[bytes]
|
||||
connection_result: Optional[asyncio.Future[LeConnectionOrientedChannel]]
|
||||
disconnection_result: Optional[asyncio.Future[None]]
|
||||
out_sdu: Optional[bytes]
|
||||
|
||||
@staticmethod
|
||||
def state_name(state):
|
||||
def state_name(state) -> str:
|
||||
return name_or_number(LeConnectionOrientedChannel.STATE_NAMES, state)
|
||||
|
||||
def __init__(
|
||||
@@ -1013,7 +1027,7 @@ class LeConnectionOrientedChannel(EventEmitter):
|
||||
peer_mps,
|
||||
peer_credits,
|
||||
connected,
|
||||
):
|
||||
) -> None:
|
||||
super().__init__()
|
||||
self.manager = manager
|
||||
self.connection = connection
|
||||
@@ -1045,7 +1059,7 @@ class LeConnectionOrientedChannel(EventEmitter):
|
||||
else:
|
||||
self.state = LeConnectionOrientedChannel.INIT
|
||||
|
||||
def change_state(self, new_state):
|
||||
def change_state(self, new_state) -> None:
|
||||
logger.debug(
|
||||
f'{self} state change -> {color(self.state_name(new_state), "cyan")}'
|
||||
)
|
||||
@@ -1056,13 +1070,13 @@ class LeConnectionOrientedChannel(EventEmitter):
|
||||
elif new_state == self.DISCONNECTED:
|
||||
self.emit('close')
|
||||
|
||||
def send_pdu(self, pdu):
|
||||
def send_pdu(self, pdu) -> None:
|
||||
self.manager.send_pdu(self.connection, self.destination_cid, pdu)
|
||||
|
||||
def send_control_frame(self, frame):
|
||||
def send_control_frame(self, frame) -> None:
|
||||
self.manager.send_control_frame(self.connection, L2CAP_LE_SIGNALING_CID, frame)
|
||||
|
||||
async def connect(self):
|
||||
async def connect(self) -> LeConnectionOrientedChannel:
|
||||
# Check that we're in the right state
|
||||
if self.state != self.INIT:
|
||||
raise InvalidStateError('not in a connectable state')
|
||||
@@ -1090,7 +1104,7 @@ class LeConnectionOrientedChannel(EventEmitter):
|
||||
# Wait for the connection to succeed or fail
|
||||
return await self.connection_result
|
||||
|
||||
async def disconnect(self):
|
||||
async def disconnect(self) -> None:
|
||||
# Check that we're connected
|
||||
if self.state != self.CONNECTED:
|
||||
raise InvalidStateError('not connected')
|
||||
@@ -1110,11 +1124,11 @@ class LeConnectionOrientedChannel(EventEmitter):
|
||||
self.disconnection_result = asyncio.get_running_loop().create_future()
|
||||
return await self.disconnection_result
|
||||
|
||||
def abort(self):
|
||||
def abort(self) -> None:
|
||||
if self.state == self.CONNECTED:
|
||||
self.change_state(self.DISCONNECTED)
|
||||
|
||||
def on_pdu(self, pdu):
|
||||
def on_pdu(self, pdu) -> None:
|
||||
if self.sink is None:
|
||||
logger.warning('received pdu without a sink')
|
||||
return
|
||||
@@ -1180,7 +1194,7 @@ class LeConnectionOrientedChannel(EventEmitter):
|
||||
self.in_sdu = None
|
||||
self.in_sdu_length = 0
|
||||
|
||||
def on_connection_response(self, response):
|
||||
def on_connection_response(self, response) -> None:
|
||||
# Look for a matching pending response result
|
||||
if self.connection_result is None:
|
||||
logger.warning(
|
||||
@@ -1214,14 +1228,14 @@ class LeConnectionOrientedChannel(EventEmitter):
|
||||
# Cleanup
|
||||
self.connection_result = None
|
||||
|
||||
def on_credits(self, credits): # pylint: disable=redefined-builtin
|
||||
def on_credits(self, credits) -> None: # pylint: disable=redefined-builtin
|
||||
self.credits += credits
|
||||
logger.debug(f'received {credits} credits, total = {self.credits}')
|
||||
|
||||
# Try to send more data if we have any queued up
|
||||
self.process_output()
|
||||
|
||||
def on_disconnection_request(self, request):
|
||||
def on_disconnection_request(self, request) -> None:
|
||||
self.send_control_frame(
|
||||
L2CAP_Disconnection_Response(
|
||||
identifier=request.identifier,
|
||||
@@ -1232,7 +1246,7 @@ class LeConnectionOrientedChannel(EventEmitter):
|
||||
self.change_state(self.DISCONNECTED)
|
||||
self.flush_output()
|
||||
|
||||
def on_disconnection_response(self, response):
|
||||
def on_disconnection_response(self, response) -> None:
|
||||
if self.state != self.DISCONNECTING:
|
||||
logger.warning(color('invalid state', 'red'))
|
||||
return
|
||||
@@ -1249,11 +1263,11 @@ class LeConnectionOrientedChannel(EventEmitter):
|
||||
self.disconnection_result.set_result(None)
|
||||
self.disconnection_result = None
|
||||
|
||||
def flush_output(self):
|
||||
def flush_output(self) -> None:
|
||||
self.out_queue.clear()
|
||||
self.out_sdu = None
|
||||
|
||||
def process_output(self):
|
||||
def process_output(self) -> None:
|
||||
while self.credits > 0:
|
||||
if self.out_sdu is not None:
|
||||
# Finish the current SDU
|
||||
@@ -1296,7 +1310,7 @@ class LeConnectionOrientedChannel(EventEmitter):
|
||||
self.drained.set()
|
||||
return
|
||||
|
||||
def write(self, data):
|
||||
def write(self, data) -> None:
|
||||
if self.state != self.CONNECTED:
|
||||
logger.warning('not connected, dropping data')
|
||||
return
|
||||
@@ -1311,18 +1325,18 @@ class LeConnectionOrientedChannel(EventEmitter):
|
||||
# Send what we can
|
||||
self.process_output()
|
||||
|
||||
async def drain(self):
|
||||
async def drain(self) -> None:
|
||||
await self.drained.wait()
|
||||
|
||||
def pause_reading(self):
|
||||
def pause_reading(self) -> None:
|
||||
# TODO: not implemented yet
|
||||
pass
|
||||
|
||||
def resume_reading(self):
|
||||
def resume_reading(self) -> None:
|
||||
# TODO: not implemented yet
|
||||
pass
|
||||
|
||||
def __str__(self):
|
||||
def __str__(self) -> str:
|
||||
return (
|
||||
f'CoC({self.source_cid}->{self.destination_cid}, '
|
||||
f'State={self.state_name(self.state)}, '
|
||||
@@ -1335,9 +1349,19 @@ class LeConnectionOrientedChannel(EventEmitter):
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
class ChannelManager:
|
||||
identifiers: Dict[int, int]
|
||||
channels: Dict[int, Dict[int, Union[Channel, LeConnectionOrientedChannel]]]
|
||||
servers: Dict[int, Callable[[Channel], Any]]
|
||||
le_coc_channels: Dict[int, Dict[int, LeConnectionOrientedChannel]]
|
||||
le_coc_servers: Dict[
|
||||
int, Tuple[Callable[[LeConnectionOrientedChannel], Any], int, int, int]
|
||||
]
|
||||
le_coc_requests: Dict[int, L2CAP_LE_Credit_Based_Connection_Request]
|
||||
fixed_channels: Dict[int, Optional[Callable[[int, bytes], Any]]]
|
||||
|
||||
def __init__(
|
||||
self, extended_features=(), connectionless_mtu=L2CAP_DEFAULT_CONNECTIONLESS_MTU
|
||||
):
|
||||
) -> None:
|
||||
self._host = None
|
||||
self.identifiers = {} # Incrementing identifier values by connection
|
||||
self.channels = {} # All channels, mapped by connection and source cid
|
||||
@@ -1379,7 +1403,7 @@ class ChannelManager:
|
||||
return None
|
||||
|
||||
@staticmethod
|
||||
def find_free_br_edr_cid(channels):
|
||||
def find_free_br_edr_cid(channels) -> int:
|
||||
# Pick the smallest valid CID that's not already in the list
|
||||
# (not necessarily the most efficient algorithm, but the list of CID is
|
||||
# very small in practice)
|
||||
@@ -1392,7 +1416,7 @@ class ChannelManager:
|
||||
raise RuntimeError('no free CID available')
|
||||
|
||||
@staticmethod
|
||||
def find_free_le_cid(channels):
|
||||
def find_free_le_cid(channels) -> int:
|
||||
# Pick the smallest valid CID that's not already in the list
|
||||
# (not necessarily the most efficient algorithm, but the list of CID is
|
||||
# very small in practice)
|
||||
@@ -1405,7 +1429,7 @@ class ChannelManager:
|
||||
raise RuntimeError('no free CID')
|
||||
|
||||
@staticmethod
|
||||
def check_le_coc_parameters(max_credits, mtu, mps):
|
||||
def check_le_coc_parameters(max_credits, mtu, mps) -> None:
|
||||
if (
|
||||
max_credits < 1
|
||||
or max_credits > L2CAP_LE_CREDIT_BASED_CONNECTION_MAX_CREDITS
|
||||
@@ -1419,19 +1443,19 @@ class ChannelManager:
|
||||
):
|
||||
raise ValueError('MPS out of range')
|
||||
|
||||
def next_identifier(self, connection):
|
||||
def next_identifier(self, connection) -> int:
|
||||
identifier = (self.identifiers.setdefault(connection.handle, 0) + 1) % 256
|
||||
self.identifiers[connection.handle] = identifier
|
||||
return identifier
|
||||
|
||||
def register_fixed_channel(self, cid, handler):
|
||||
def register_fixed_channel(self, cid, handler) -> None:
|
||||
self.fixed_channels[cid] = handler
|
||||
|
||||
def deregister_fixed_channel(self, cid):
|
||||
def deregister_fixed_channel(self, cid) -> None:
|
||||
if cid in self.fixed_channels:
|
||||
del self.fixed_channels[cid]
|
||||
|
||||
def register_server(self, psm, server):
|
||||
def register_server(self, psm, server) -> int:
|
||||
if psm == 0:
|
||||
# Find a free PSM
|
||||
for candidate in range(
|
||||
@@ -1470,7 +1494,7 @@ class ChannelManager:
|
||||
max_credits=L2CAP_LE_CREDIT_BASED_CONNECTION_DEFAULT_INITIAL_CREDITS,
|
||||
mtu=L2CAP_LE_CREDIT_BASED_CONNECTION_DEFAULT_MTU,
|
||||
mps=L2CAP_LE_CREDIT_BASED_CONNECTION_DEFAULT_MPS,
|
||||
):
|
||||
) -> int:
|
||||
self.check_le_coc_parameters(max_credits, mtu, mps)
|
||||
|
||||
if psm == 0:
|
||||
@@ -1498,7 +1522,7 @@ class ChannelManager:
|
||||
|
||||
return psm
|
||||
|
||||
def on_disconnection(self, connection_handle, _reason):
|
||||
def on_disconnection(self, connection_handle, _reason) -> None:
|
||||
logger.debug(f'disconnection from {connection_handle}, cleaning up channels')
|
||||
if connection_handle in self.channels:
|
||||
for _, channel in self.channels[connection_handle].items():
|
||||
@@ -1511,7 +1535,7 @@ class ChannelManager:
|
||||
if connection_handle in self.identifiers:
|
||||
del self.identifiers[connection_handle]
|
||||
|
||||
def send_pdu(self, connection, cid, pdu):
|
||||
def send_pdu(self, connection, cid, pdu) -> None:
|
||||
pdu_str = pdu.hex() if isinstance(pdu, bytes) else str(pdu)
|
||||
logger.debug(
|
||||
f'{color(">>> Sending L2CAP PDU", "blue")} '
|
||||
@@ -1520,14 +1544,16 @@ class ChannelManager:
|
||||
)
|
||||
self.host.send_l2cap_pdu(connection.handle, cid, bytes(pdu))
|
||||
|
||||
def on_pdu(self, connection, cid, pdu):
|
||||
def on_pdu(self, connection, cid, pdu) -> None:
|
||||
if cid in (L2CAP_SIGNALING_CID, L2CAP_LE_SIGNALING_CID):
|
||||
# Parse the L2CAP payload into a Control Frame object
|
||||
control_frame = L2CAP_Control_Frame.from_bytes(pdu)
|
||||
|
||||
self.on_control_frame(connection, cid, control_frame)
|
||||
elif cid in self.fixed_channels:
|
||||
self.fixed_channels[cid](connection.handle, pdu)
|
||||
handler = self.fixed_channels[cid]
|
||||
assert handler is not None
|
||||
handler(connection.handle, pdu)
|
||||
else:
|
||||
if (channel := self.find_channel(connection.handle, cid)) is None:
|
||||
logger.warning(
|
||||
@@ -1539,7 +1565,7 @@ class ChannelManager:
|
||||
|
||||
channel.on_pdu(pdu)
|
||||
|
||||
def send_control_frame(self, connection, cid, control_frame):
|
||||
def send_control_frame(self, connection, cid, control_frame) -> None:
|
||||
logger.debug(
|
||||
f'{color(">>> Sending L2CAP Signaling Control Frame", "blue")} '
|
||||
f'on connection [0x{connection.handle:04X}] (CID={cid}) '
|
||||
@@ -1547,7 +1573,7 @@ class ChannelManager:
|
||||
)
|
||||
self.host.send_l2cap_pdu(connection.handle, cid, bytes(control_frame))
|
||||
|
||||
def on_control_frame(self, connection, cid, control_frame):
|
||||
def on_control_frame(self, connection, cid, control_frame) -> None:
|
||||
logger.debug(
|
||||
f'{color("<<< Received L2CAP Signaling Control Frame", "green")} '
|
||||
f'on connection [0x{connection.handle:04X}] (CID={cid}) '
|
||||
@@ -1584,10 +1610,10 @@ class ChannelManager:
|
||||
),
|
||||
)
|
||||
|
||||
def on_l2cap_command_reject(self, _connection, _cid, packet):
|
||||
def on_l2cap_command_reject(self, _connection, _cid, packet) -> None:
|
||||
logger.warning(f'{color("!!! Command rejected:", "red")} {packet.reason}')
|
||||
|
||||
def on_l2cap_connection_request(self, connection, cid, request):
|
||||
def on_l2cap_connection_request(self, connection, cid, request) -> None:
|
||||
# Check if there's a server for this PSM
|
||||
server = self.servers.get(request.psm)
|
||||
if server:
|
||||
@@ -1639,7 +1665,7 @@ class ChannelManager:
|
||||
),
|
||||
)
|
||||
|
||||
def on_l2cap_connection_response(self, connection, cid, response):
|
||||
def on_l2cap_connection_response(self, connection, cid, response) -> None:
|
||||
if (
|
||||
channel := self.find_channel(connection.handle, response.source_cid)
|
||||
) is None:
|
||||
@@ -1654,7 +1680,7 @@ class ChannelManager:
|
||||
|
||||
channel.on_connection_response(response)
|
||||
|
||||
def on_l2cap_configure_request(self, connection, cid, request):
|
||||
def on_l2cap_configure_request(self, connection, cid, request) -> None:
|
||||
if (
|
||||
channel := self.find_channel(connection.handle, request.destination_cid)
|
||||
) is None:
|
||||
@@ -1669,7 +1695,7 @@ class ChannelManager:
|
||||
|
||||
channel.on_configure_request(request)
|
||||
|
||||
def on_l2cap_configure_response(self, connection, cid, response):
|
||||
def on_l2cap_configure_response(self, connection, cid, response) -> None:
|
||||
if (
|
||||
channel := self.find_channel(connection.handle, response.source_cid)
|
||||
) is None:
|
||||
@@ -1684,7 +1710,7 @@ class ChannelManager:
|
||||
|
||||
channel.on_configure_response(response)
|
||||
|
||||
def on_l2cap_disconnection_request(self, connection, cid, request):
|
||||
def on_l2cap_disconnection_request(self, connection, cid, request) -> None:
|
||||
if (
|
||||
channel := self.find_channel(connection.handle, request.destination_cid)
|
||||
) is None:
|
||||
@@ -1699,7 +1725,7 @@ class ChannelManager:
|
||||
|
||||
channel.on_disconnection_request(request)
|
||||
|
||||
def on_l2cap_disconnection_response(self, connection, cid, response):
|
||||
def on_l2cap_disconnection_response(self, connection, cid, response) -> None:
|
||||
if (
|
||||
channel := self.find_channel(connection.handle, response.source_cid)
|
||||
) is None:
|
||||
@@ -1714,7 +1740,7 @@ class ChannelManager:
|
||||
|
||||
channel.on_disconnection_response(response)
|
||||
|
||||
def on_l2cap_echo_request(self, connection, cid, request):
|
||||
def on_l2cap_echo_request(self, connection, cid, request) -> None:
|
||||
logger.debug(f'<<< Echo request: data={request.data.hex()}')
|
||||
self.send_control_frame(
|
||||
connection,
|
||||
@@ -1722,11 +1748,11 @@ class ChannelManager:
|
||||
L2CAP_Echo_Response(identifier=request.identifier, data=request.data),
|
||||
)
|
||||
|
||||
def on_l2cap_echo_response(self, _connection, _cid, response):
|
||||
def on_l2cap_echo_response(self, _connection, _cid, response) -> None:
|
||||
logger.debug(f'<<< Echo response: data={response.data.hex()}')
|
||||
# TODO notify listeners
|
||||
|
||||
def on_l2cap_information_request(self, connection, cid, request):
|
||||
def on_l2cap_information_request(self, connection, cid, request) -> None:
|
||||
if request.info_type == L2CAP_Information_Request.CONNECTIONLESS_MTU:
|
||||
result = L2CAP_Information_Response.SUCCESS
|
||||
data = self.connectionless_mtu.to_bytes(2, 'little')
|
||||
@@ -1781,11 +1807,15 @@ class ChannelManager:
|
||||
),
|
||||
)
|
||||
|
||||
def on_l2cap_connection_parameter_update_response(self, connection, cid, response):
|
||||
def on_l2cap_connection_parameter_update_response(
|
||||
self, connection, cid, response
|
||||
) -> None:
|
||||
# TODO: check response
|
||||
pass
|
||||
|
||||
def on_l2cap_le_credit_based_connection_request(self, connection, cid, request):
|
||||
def on_l2cap_le_credit_based_connection_request(
|
||||
self, connection, cid, request
|
||||
) -> None:
|
||||
if request.le_psm in self.le_coc_servers:
|
||||
(server, max_credits, mtu, mps) = self.le_coc_servers[request.le_psm]
|
||||
|
||||
@@ -1887,7 +1917,9 @@ class ChannelManager:
|
||||
),
|
||||
)
|
||||
|
||||
def on_l2cap_le_credit_based_connection_response(self, connection, _cid, response):
|
||||
def on_l2cap_le_credit_based_connection_response(
|
||||
self, connection, _cid, response
|
||||
) -> None:
|
||||
# Find the pending request by identifier
|
||||
request = self.le_coc_requests.get(response.identifier)
|
||||
if request is None:
|
||||
@@ -1910,7 +1942,7 @@ class ChannelManager:
|
||||
# Process the response
|
||||
channel.on_connection_response(response)
|
||||
|
||||
def on_l2cap_le_flow_control_credit(self, connection, _cid, credit):
|
||||
def on_l2cap_le_flow_control_credit(self, connection, _cid, credit) -> None:
|
||||
channel = self.find_le_coc_channel(connection.handle, credit.cid)
|
||||
if channel is None:
|
||||
logger.warning(f'received credits for an unknown channel (cid={credit.cid}')
|
||||
@@ -1918,13 +1950,15 @@ class ChannelManager:
|
||||
|
||||
channel.on_credits(credit.credits)
|
||||
|
||||
def on_channel_closed(self, channel):
|
||||
def on_channel_closed(self, channel) -> None:
|
||||
connection_channels = self.channels.get(channel.connection.handle)
|
||||
if connection_channels:
|
||||
if channel.source_cid in connection_channels:
|
||||
del connection_channels[channel.source_cid]
|
||||
|
||||
async def open_le_coc(self, connection, psm, max_credits, mtu, mps):
|
||||
async def open_le_coc(
|
||||
self, connection, psm, max_credits, mtu, mps
|
||||
) -> LeConnectionOrientedChannel:
|
||||
self.check_le_coc_parameters(max_credits, mtu, mps)
|
||||
|
||||
# Find a free CID for the new channel
|
||||
@@ -1965,7 +1999,7 @@ class ChannelManager:
|
||||
|
||||
return channel
|
||||
|
||||
async def connect(self, connection, psm):
|
||||
async def connect(self, connection, psm) -> Channel:
|
||||
# NOTE: this implementation hard-codes BR/EDR
|
||||
|
||||
# Find a free CID for a new channel
|
||||
|
||||
Reference in New Issue
Block a user