From 4d74339c04d540045dec1f5f16e074771b8f36ba Mon Sep 17 00:00:00 2001 From: Josh Wu Date: Mon, 5 Jun 2023 16:28:52 +0800 Subject: [PATCH] Add typing for RFCOMM --- bumble/rfcomm.py | 177 +++++++++++++++++++++++++++++------------------ 1 file changed, 110 insertions(+), 67 deletions(-) diff --git a/bumble/rfcomm.py b/bumble/rfcomm.py index 71be8dc9..0176a78a 100644 --- a/bumble/rfcomm.py +++ b/bumble/rfcomm.py @@ -19,8 +19,9 @@ import logging import asyncio from pyee import EventEmitter +from typing import Optional, Tuple, Callable, Dict, Union -from . import core +from . import core, l2cap from .colors import color from .core import BT_BR_EDR_TRANSPORT, InvalidStateError, ProtocolError @@ -105,7 +106,7 @@ RFCOMM_DYNAMIC_CHANNEL_NUMBER_END = 30 # ----------------------------------------------------------------------------- -def compute_fcs(buffer): +def compute_fcs(buffer: bytes) -> int: result = 0xFF for byte in buffer: result = CRC_TABLE[result ^ byte] @@ -114,7 +115,15 @@ def compute_fcs(buffer): # ----------------------------------------------------------------------------- class RFCOMM_Frame: - def __init__(self, frame_type, c_r, dlci, p_f, information=b'', with_credits=False): + def __init__( + self, + frame_type: int, + c_r: int, + dlci: int, + p_f: int, + information: bytes = b'', + with_credits: bool = False, + ) -> None: self.type = frame_type self.c_r = c_r self.dlci = dlci @@ -136,11 +145,11 @@ class RFCOMM_Frame: else: self.fcs = compute_fcs(bytes([self.address, self.control]) + self.length) - def type_name(self): + def type_name(self) -> str: return RFCOMM_FRAME_TYPE_NAMES[self.type] @staticmethod - def parse_mcc(data): + def parse_mcc(data) -> Tuple[int, int, bytes]: mcc_type = data[0] >> 2 c_r = (data[0] >> 1) & 1 length = data[1] @@ -154,36 +163,36 @@ class RFCOMM_Frame: return (mcc_type, c_r, value) @staticmethod - def make_mcc(mcc_type, c_r, data): + def make_mcc(mcc_type: int, c_r: int, data: bytes) -> bytes: return ( bytes([(mcc_type << 2 | c_r << 1 | 1) & 0xFF, (len(data) & 0x7F) << 1 | 1]) + data ) @staticmethod - def sabm(c_r, dlci): + def sabm(c_r: int, dlci: int): return RFCOMM_Frame(RFCOMM_SABM_FRAME, c_r, dlci, 1) @staticmethod - def ua(c_r, dlci): + def ua(c_r: int, dlci: int): return RFCOMM_Frame(RFCOMM_UA_FRAME, c_r, dlci, 1) @staticmethod - def dm(c_r, dlci): + def dm(c_r: int, dlci: int): return RFCOMM_Frame(RFCOMM_DM_FRAME, c_r, dlci, 1) @staticmethod - def disc(c_r, dlci): + def disc(c_r: int, dlci: int): return RFCOMM_Frame(RFCOMM_DISC_FRAME, c_r, dlci, 1) @staticmethod - def uih(c_r, dlci, information, p_f=0): + def uih(c_r: int, dlci: int, information: bytes, p_f: int = 0): return RFCOMM_Frame( RFCOMM_UIH_FRAME, c_r, dlci, p_f, information, with_credits=(p_f == 1) ) @staticmethod - def from_bytes(data): + def from_bytes(data: bytes): # Extract fields dlci = (data[0] >> 2) & 0x3F c_r = (data[0] >> 1) & 0x01 @@ -227,15 +236,23 @@ class RFCOMM_Frame: # ----------------------------------------------------------------------------- class RFCOMM_MCC_PN: + dlci: int + cl: int + priority: int + ack_timer: int + max_frame_size: int + max_retransmissions: int + window_size: int + def __init__( self, - dlci, - cl, - priority, - ack_timer, - max_frame_size, - max_retransmissions, - window_size, + dlci: int, + cl: int, + priority: int, + ack_timer: int, + max_frame_size: int, + max_retransmissions: int, + window_size: int, ): self.dlci = dlci self.cl = cl @@ -246,7 +263,7 @@ class RFCOMM_MCC_PN: self.window_size = window_size @staticmethod - def from_bytes(data): + def from_bytes(data: bytes): return RFCOMM_MCC_PN( dlci=data[0], cl=data[1], @@ -285,7 +302,14 @@ class RFCOMM_MCC_PN: # ----------------------------------------------------------------------------- class RFCOMM_MCC_MSC: - def __init__(self, dlci, fc, rtc, rtr, ic, dv): + dlci: int + fc: int + rtc: int + rtr: int + ic: int + dv: int + + def __init__(self, dlci: int, fc: int, rtc: int, rtr: int, ic: int, dv: int): self.dlci = dlci self.fc = fc self.rtc = rtc @@ -294,7 +318,7 @@ class RFCOMM_MCC_MSC: self.dv = dv @staticmethod - def from_bytes(data): + def from_bytes(data: bytes): return RFCOMM_MCC_MSC( dlci=data[0] >> 2, fc=data[1] >> 1 & 1, @@ -347,7 +371,12 @@ class DLC(EventEmitter): RESET: 'RESET', } - def __init__(self, multiplexer, dlci, max_frame_size, initial_tx_credits): + connection_result: Optional[asyncio.Future] + sink: Optional[Callable[[bytes], None]] + + def __init__( + self, multiplexer, dlci: int, max_frame_size: int, initial_tx_credits: int + ): super().__init__() self.multiplexer = multiplexer self.dlci = dlci @@ -368,23 +397,23 @@ class DLC(EventEmitter): ) @staticmethod - def state_name(state): + def state_name(state: int) -> str: return DLC.STATE_NAMES[state] - def change_state(self, new_state): + def change_state(self, new_state: int) -> None: logger.debug( f'{self} state change -> {color(self.state_name(new_state), "magenta")}' ) self.state = new_state - def send_frame(self, frame): + def send_frame(self, frame: RFCOMM_Frame) -> None: self.multiplexer.send_frame(frame) - def on_frame(self, frame): + def on_frame(self, frame: RFCOMM_Frame) -> None: handler = getattr(self, f'on_{frame.type_name()}_frame'.lower()) handler(frame) - def on_sabm_frame(self, _frame): + def on_sabm_frame(self, _frame) -> None: if self.state != DLC.CONNECTING: logger.warning( color('!!! received SABM when not in CONNECTING state', 'red') @@ -404,7 +433,7 @@ class DLC(EventEmitter): self.change_state(DLC.CONNECTED) self.emit('open') - def on_ua_frame(self, _frame): + def on_ua_frame(self, _frame) -> None: if self.state != DLC.CONNECTING: logger.warning( color('!!! received SABM when not in CONNECTING state', 'red') @@ -422,15 +451,15 @@ class DLC(EventEmitter): self.change_state(DLC.CONNECTED) self.multiplexer.on_dlc_open_complete(self) - def on_dm_frame(self, frame): + def on_dm_frame(self, frame) -> None: # TODO: handle all states pass - def on_disc_frame(self, _frame): + def on_disc_frame(self, _frame) -> None: # TODO: handle all states self.send_frame(RFCOMM_Frame.ua(c_r=1 - self.c_r, dlci=self.dlci)) - def on_uih_frame(self, frame): + def on_uih_frame(self, frame: RFCOMM_Frame) -> None: data = frame.information if frame.p_f == 1: # With credits @@ -460,10 +489,10 @@ class DLC(EventEmitter): # Check if there's anything to send (including credits) self.process_tx() - def on_ui_frame(self, frame): + def on_ui_frame(self, frame) -> None: pass - def on_mcc_msc(self, c_r, msc): + def on_mcc_msc(self, c_r, msc) -> None: if c_r: # Command logger.debug(f'<<< MCC MSC Command: {msc}') @@ -477,7 +506,7 @@ class DLC(EventEmitter): # Response logger.debug(f'<<< MCC MSC Response: {msc}') - def connect(self): + def connect(self) -> None: if self.state != DLC.INIT: raise InvalidStateError('invalid state') @@ -485,7 +514,7 @@ class DLC(EventEmitter): self.connection_result = asyncio.get_running_loop().create_future() self.send_frame(RFCOMM_Frame.sabm(c_r=self.c_r, dlci=self.dlci)) - def accept(self): + def accept(self) -> None: if self.state != DLC.INIT: raise InvalidStateError('invalid state') @@ -503,13 +532,13 @@ class DLC(EventEmitter): self.send_frame(RFCOMM_Frame.uih(c_r=self.c_r, dlci=0, information=mcc)) self.change_state(DLC.CONNECTING) - def rx_credits_needed(self): + def rx_credits_needed(self) -> int: if self.rx_credits <= self.rx_threshold: return RFCOMM_DEFAULT_INITIAL_RX_CREDITS - self.rx_credits return 0 - def process_tx(self): + def process_tx(self) -> None: # Send anything we can (or an empty frame if we need to send rx credits) rx_credits_needed = self.rx_credits_needed() while (self.tx_buffer and self.tx_credits > 0) or rx_credits_needed > 0: @@ -547,7 +576,7 @@ class DLC(EventEmitter): rx_credits_needed = 0 # Stream protocol - def write(self, data): + def write(self, data: Union[bytes, str]) -> None: # We can only send bytes if not isinstance(data, bytes): if isinstance(data, str): @@ -559,7 +588,7 @@ class DLC(EventEmitter): self.tx_buffer += data self.process_tx() - def drain(self): + def drain(self) -> None: # TODO pass @@ -592,7 +621,13 @@ class Multiplexer(EventEmitter): RESET: 'RESET', } - def __init__(self, l2cap_channel, role): + connection_result: Optional[asyncio.Future] + disconnection_result: Optional[asyncio.Future] + open_result: Optional[asyncio.Future] + acceptor: Optional[Callable[[int], bool]] + dlcs: Dict[int, DLC] + + def __init__(self, l2cap_channel: l2cap.Channel, role: int) -> None: super().__init__() self.role = role self.l2cap_channel = l2cap_channel @@ -607,20 +642,20 @@ class Multiplexer(EventEmitter): l2cap_channel.sink = self.on_pdu @staticmethod - def state_name(state): + def state_name(state: int): return Multiplexer.STATE_NAMES[state] - def change_state(self, new_state): + def change_state(self, new_state: int) -> None: logger.debug( f'{self} state change -> {color(self.state_name(new_state), "cyan")}' ) self.state = new_state - def send_frame(self, frame): + def send_frame(self, frame: RFCOMM_Frame) -> None: logger.debug(f'>>> Multiplexer sending {frame}') self.l2cap_channel.send_pdu(frame) - def on_pdu(self, pdu): + def on_pdu(self, pdu: bytes) -> None: frame = RFCOMM_Frame.from_bytes(pdu) logger.debug(f'<<< Multiplexer received {frame}') @@ -640,18 +675,18 @@ class Multiplexer(EventEmitter): return dlc.on_frame(frame) - def on_frame(self, frame): + def on_frame(self, frame: RFCOMM_Frame) -> None: handler = getattr(self, f'on_{frame.type_name()}_frame'.lower()) handler(frame) - def on_sabm_frame(self, _frame): + def on_sabm_frame(self, _frame) -> None: if self.state != Multiplexer.INIT: logger.debug('not in INIT state, ignoring SABM') return self.change_state(Multiplexer.CONNECTED) self.send_frame(RFCOMM_Frame.ua(c_r=1, dlci=0)) - def on_ua_frame(self, _frame): + def on_ua_frame(self, _frame) -> None: if self.state == Multiplexer.CONNECTING: self.change_state(Multiplexer.CONNECTED) if self.connection_result: @@ -663,7 +698,7 @@ class Multiplexer(EventEmitter): self.disconnection_result.set_result(None) self.disconnection_result = None - def on_dm_frame(self, _frame): + def on_dm_frame(self, _frame) -> None: if self.state == Multiplexer.OPENING: self.change_state(Multiplexer.CONNECTED) if self.open_result: @@ -678,13 +713,13 @@ class Multiplexer(EventEmitter): else: logger.warning(f'unexpected state for DM: {self}') - def on_disc_frame(self, _frame): + def on_disc_frame(self, _frame) -> None: self.change_state(Multiplexer.DISCONNECTED) self.send_frame( RFCOMM_Frame.ua(c_r=0 if self.role == Multiplexer.INITIATOR else 1, dlci=0) ) - def on_uih_frame(self, frame): + def on_uih_frame(self, frame: RFCOMM_Frame) -> None: (mcc_type, c_r, value) = RFCOMM_Frame.parse_mcc(frame.information) if mcc_type == RFCOMM_MCC_PN_TYPE: @@ -694,10 +729,10 @@ class Multiplexer(EventEmitter): mcs = RFCOMM_MCC_MSC.from_bytes(value) self.on_mcc_msc(c_r, mcs) - def on_ui_frame(self, frame): + def on_ui_frame(self, frame) -> None: pass - def on_mcc_pn(self, c_r, pn): + def on_mcc_pn(self, c_r, pn) -> None: if c_r == 1: # Command logger.debug(f'<<< PN Command: {pn}') @@ -736,14 +771,14 @@ class Multiplexer(EventEmitter): else: logger.warning('ignoring PN response') - def on_mcc_msc(self, c_r, msc): + def on_mcc_msc(self, c_r, msc) -> None: dlc = self.dlcs.get(msc.dlci) if dlc is None: logger.warning(f'no dlc for DLCI {msc.dlci}') return dlc.on_mcc_msc(c_r, msc) - async def connect(self): + async def connect(self) -> None: if self.state != Multiplexer.INIT: raise InvalidStateError('invalid state') @@ -752,7 +787,7 @@ class Multiplexer(EventEmitter): self.send_frame(RFCOMM_Frame.sabm(c_r=1, dlci=0)) return await self.connection_result - async def disconnect(self): + async def disconnect(self) -> None: if self.state != Multiplexer.CONNECTED: return @@ -765,7 +800,7 @@ class Multiplexer(EventEmitter): ) await self.disconnection_result - async def open_dlc(self, channel): + async def open_dlc(self, channel: int) -> DLC: if self.state != Multiplexer.CONNECTED: if self.state == Multiplexer.OPENING: raise InvalidStateError('open already in progress') @@ -796,7 +831,7 @@ class Multiplexer(EventEmitter): self.open_result = None return result - def on_dlc_open_complete(self, dlc): + def on_dlc_open_complete(self, dlc: DLC): logger.debug(f'DLC [{dlc.dlci}] open complete') self.change_state(Multiplexer.CONNECTED) if self.open_result: @@ -808,13 +843,16 @@ class Multiplexer(EventEmitter): # ----------------------------------------------------------------------------- class Client: - def __init__(self, device, connection): + multiplexer: Optional[Multiplexer] + l2cap_channel: Optional[l2cap.Channel] + + def __init__(self, device, connection) -> None: self.device = device self.connection = connection self.l2cap_channel = None self.multiplexer = None - async def start(self): + async def start(self) -> Multiplexer: # Create a new L2CAP connection try: self.l2cap_channel = await self.device.l2cap_channel_manager.connect( @@ -824,6 +862,7 @@ class Client: logger.warning(f'L2CAP connection failed: {error}') raise + assert self.l2cap_channel is not None # Create a mutliplexer to manage DLCs with the server self.multiplexer = Multiplexer(self.l2cap_channel, Multiplexer.INITIATOR) @@ -832,7 +871,9 @@ class Client: return self.multiplexer - async def shutdown(self): + async def shutdown(self) -> None: + if self.multiplexer is None: + return # Disconnect the multiplexer await self.multiplexer.disconnect() self.multiplexer = None @@ -843,7 +884,9 @@ class Client: # ----------------------------------------------------------------------------- class Server(EventEmitter): - def __init__(self, device): + acceptors: Dict[int, Callable[[DLC], None]] + + def __init__(self, device) -> None: super().__init__() self.device = device self.multiplexer = None @@ -852,7 +895,7 @@ class Server(EventEmitter): # Register ourselves with the L2CAP channel manager device.register_l2cap_server(RFCOMM_PSM, self.on_connection) - def listen(self, acceptor, channel=0): + def listen(self, acceptor: Callable[[DLC], None], channel: int = 0) -> int: if channel: if channel in self.acceptors: # Busy @@ -874,11 +917,11 @@ class Server(EventEmitter): self.acceptors[channel] = acceptor return channel - def on_connection(self, l2cap_channel): + def on_connection(self, l2cap_channel: l2cap.Channel) -> None: logger.debug(f'+++ new L2CAP connection: {l2cap_channel}') l2cap_channel.on('open', lambda: self.on_l2cap_channel_open(l2cap_channel)) - def on_l2cap_channel_open(self, l2cap_channel): + def on_l2cap_channel_open(self, l2cap_channel: l2cap.Channel) -> None: logger.debug(f'$$$ L2CAP channel open: {l2cap_channel}') # Create a new multiplexer for the channel @@ -889,10 +932,10 @@ class Server(EventEmitter): # Notify self.emit('start', multiplexer) - def accept_dlc(self, channel_number): + def accept_dlc(self, channel_number: int) -> bool: return channel_number in self.acceptors - def on_dlc(self, dlc): + def on_dlc(self, dlc: DLC) -> None: logger.debug(f'@@@ new DLC connected: {dlc}') # Let the acceptor know