From e732f2589f173323fa9f9f8c3a3fdaa80f49526b Mon Sep 17 00:00:00 2001 From: Josh Wu Date: Wed, 4 Oct 2023 16:56:16 +0800 Subject: [PATCH] Refactor L2CAP API --- bumble/avdtp.py | 53 ++++++--- bumble/device.py | 101 ++++++++++++++-- bumble/l2cap.py | 295 +++++++++++++++++++++++++++++++++++------------ bumble/rfcomm.py | 8 +- bumble/sdp.py | 4 +- bumble/utils.py | 17 +++ 6 files changed, 374 insertions(+), 104 deletions(-) diff --git a/bumble/avdtp.py b/bumble/avdtp.py index bda168a..9a332f4 100644 --- a/bumble/avdtp.py +++ b/bumble/avdtp.py @@ -21,6 +21,7 @@ import struct import time import logging import enum +import warnings from pyee import EventEmitter from typing import ( Any, @@ -368,7 +369,7 @@ class MediaPacketPump: self.clock = clock self.pump_task = None - async def start(self, rtp_channel: l2cap.Channel) -> None: + async def start(self, rtp_channel: l2cap.ClassicChannel) -> None: async def pump_packets(): start_time = 0 start_timestamp = 0 @@ -1254,7 +1255,7 @@ class Protocol(EventEmitter): remote_endpoints: Dict[int, DiscoveredStreamEndPoint] streams: Dict[int, Stream] transaction_results: List[Optional[asyncio.Future[Message]]] - channel_connector: Callable[[], Awaitable[l2cap.Channel]] + channel_connector: Callable[[], Awaitable[l2cap.ClassicChannel]] class PacketType(enum.IntEnum): SINGLE_PACKET = 0 @@ -1262,19 +1263,23 @@ class Protocol(EventEmitter): CONTINUE_PACKET = 2 END_PACKET = 3 + @staticmethod + def packet_type_name(packet_type): + return name_or_number(Protocol.PACKET_TYPE_NAMES, packet_type) + @staticmethod async def connect( connection: device.Connection, version: Tuple[int, int] = (1, 3) ) -> Protocol: - connector = connection.create_l2cap_connector(AVDTP_PSM) - channel = await connector() + channel = await connection.create_l2cap_channel( + spec=l2cap.ClassicChannelSpec(psm=AVDTP_PSM) + ) protocol = Protocol(channel, version) - protocol.channel_connector = connector return protocol def __init__( - self, l2cap_channel: l2cap.Channel, version: Tuple[int, int] = (1, 3) + self, l2cap_channel: l2cap.ClassicChannel, version: Tuple[int, int] = (1, 3) ) -> None: super().__init__() self.l2cap_channel = l2cap_channel @@ -1712,8 +1717,13 @@ class Listener(EventEmitter): servers: Dict[int, Protocol] @staticmethod - def create_registrar(device): - return device.create_l2cap_registrar(AVDTP_PSM) + def create_registrar(device: device.Device): + warnings.warn("Please use Listener.for_device()", DeprecationWarning) + + def wrapper(handler: Callable[[l2cap.ClassicChannel], None]) -> None: + device.create_l2cap_server(l2cap.ClassicChannelSpec(psm=AVDTP_PSM), handler) + + return wrapper def set_server(self, connection: device.Connection, server: Protocol) -> None: self.servers[connection.handle] = server @@ -1722,15 +1732,28 @@ class Listener(EventEmitter): if connection.handle in self.servers: del self.servers[connection.handle] - def __init__(self, registrar, version=(1, 3)): + def __init__(self, registrar=None, version=(1, 3)): super().__init__() self.version = version self.servers = {} # Servers, by connection handle # Listen for incoming L2CAP connections - registrar(self.on_l2cap_connection) + if registrar: + warnings.warn("Please use Listener.for_device()", DeprecationWarning) + registrar(self.on_l2cap_connection) - def on_l2cap_connection(self, channel: l2cap.Channel) -> None: + @classmethod + def for_device( + cls, device: device.Device, version: Tuple[int, int] = (1, 3) + ) -> Listener: + listener = Listener(registrar=None, version=version) + l2cap_server = device.create_l2cap_server( + spec=l2cap.ClassicChannelSpec(psm=AVDTP_PSM) + ) + l2cap_server.on('connection', listener.on_l2cap_connection) + return listener + + def on_l2cap_connection(self, channel: l2cap.ClassicChannel) -> None: logger.debug(f'{color("<<< incoming L2CAP connection:", "magenta")} {channel}') if channel.connection.handle in self.servers: @@ -1759,7 +1782,7 @@ class Stream: Pair of a local and a remote stream endpoint that can stream from one to the other ''' - rtp_channel: Optional[l2cap.Channel] + rtp_channel: Optional[l2cap.ClassicChannel] @staticmethod def state_name(state: int) -> str: @@ -1792,7 +1815,11 @@ class Stream: self.change_state(AVDTP_OPEN_STATE) # Create a channel for RTP packets - self.rtp_channel = await self.protocol.channel_connector() + self.rtp_channel = ( + await self.protocol.l2cap_channel.connection.create_l2cap_channel( + l2cap.ClassicChannelSpec(psm=AVDTP_PSM) + ) + ) async def start(self) -> None: # Auto-open if needed diff --git a/bumble/device.py b/bumble/device.py index 056da72..6e2ac3b 100644 --- a/bumble/device.py +++ b/bumble/device.py @@ -33,6 +33,8 @@ from typing import ( Tuple, Type, Union, + cast, + overload, TYPE_CHECKING, ) @@ -151,6 +153,7 @@ from .utils import ( CompositeEventEmitter, setup_event_forwarding, composite_listener, + deprecated, ) from .keys import ( KeyStore, @@ -670,9 +673,7 @@ class Connection(CompositeEventEmitter): def send_l2cap_pdu(self, cid: int, pdu: bytes) -> None: self.device.send_l2cap_pdu(self.handle, cid, pdu) - def create_l2cap_connector(self, psm): - return self.device.create_l2cap_connector(self, psm) - + @deprecated("Please use create_l2cap_channel()") async def open_l2cap_channel( self, psm, @@ -682,6 +683,23 @@ class Connection(CompositeEventEmitter): ): return await self.device.open_l2cap_channel(self, psm, max_credits, mtu, mps) + @overload + async def create_l2cap_channel( + self, spec: l2cap.ClassicChannelSpec + ) -> l2cap.ClassicChannel: + ... + + @overload + async def create_l2cap_channel( + self, spec: l2cap.LeCreditBasedChannelSpec + ) -> l2cap.LeCreditBasedChannel: + ... + + async def create_l2cap_channel( + self, spec: Union[l2cap.ClassicChannelSpec, l2cap.LeCreditBasedChannelSpec] + ) -> Union[l2cap.ClassicChannel, l2cap.LeCreditBasedChannel]: + return await self.device.create_l2cap_channel(connection=self, spec=spec) + async def disconnect( self, reason: int = HCI_REMOTE_USER_TERMINATED_CONNECTION_ERROR ) -> None: @@ -1180,15 +1198,11 @@ class Device(CompositeEventEmitter): return None - def create_l2cap_connector(self, connection, psm): - return lambda: self.l2cap_channel_manager.connect(connection, psm) - - def create_l2cap_registrar(self, psm): - return lambda handler: self.register_l2cap_server(psm, handler) - + @deprecated("Please use create_l2cap_server()") def register_l2cap_server(self, psm, server) -> int: return self.l2cap_channel_manager.register_server(psm, server) + @deprecated("Please use create_l2cap_server()") def register_l2cap_channel_server( self, psm, @@ -1201,6 +1215,7 @@ class Device(CompositeEventEmitter): psm, server, max_credits, mtu, mps ) + @deprecated("Please use create_l2cap_channel()") async def open_l2cap_channel( self, connection, @@ -1213,6 +1228,74 @@ class Device(CompositeEventEmitter): connection, psm, max_credits, mtu, mps ) + @overload + async def create_l2cap_channel( + self, + connection: Connection, + spec: l2cap.ClassicChannelSpec, + ) -> l2cap.ClassicChannel: + ... + + @overload + async def create_l2cap_channel( + self, + connection: Connection, + spec: l2cap.LeCreditBasedChannelSpec, + ) -> l2cap.LeCreditBasedChannel: + ... + + async def create_l2cap_channel( + self, + connection: Connection, + spec: Union[l2cap.ClassicChannelSpec, l2cap.LeCreditBasedChannelSpec], + ) -> Union[l2cap.ClassicChannel, l2cap.LeCreditBasedChannel]: + if isinstance(spec, l2cap.ClassicChannelSpec): + return await self.l2cap_channel_manager.create_classic_channel( + connection=connection, spec=spec + ) + if isinstance(spec, l2cap.LeCreditBasedChannelSpec): + return await self.l2cap_channel_manager.create_le_credit_based_channel( + connection=connection, spec=spec + ) + + @overload + def create_l2cap_server( + self, + spec: l2cap.ClassicChannelSpec, + handler: Optional[Callable[[l2cap.ClassicChannel], Any]] = None, + ) -> l2cap.ClassicChannelServer: + ... + + @overload + def create_l2cap_server( + self, + spec: l2cap.LeCreditBasedChannelSpec, + handler: Optional[Callable[[l2cap.LeCreditBasedChannel], Any]] = None, + ) -> l2cap.LeCreditBasedChannelServer: + ... + + def create_l2cap_server( + self, + spec: Union[l2cap.ClassicChannelSpec, l2cap.LeCreditBasedChannelSpec], + handler: Union[ + Callable[[l2cap.ClassicChannel], Any], + Callable[[l2cap.LeCreditBasedChannel], Any], + None, + ] = None, + ) -> Union[l2cap.ClassicChannelServer, l2cap.LeCreditBasedChannelServer]: + if isinstance(spec, l2cap.ClassicChannelSpec): + return self.l2cap_channel_manager.create_classic_server( + spec=spec, + handler=cast(Callable[[l2cap.ClassicChannel], Any], handler), + ) + elif isinstance(spec, l2cap.LeCreditBasedChannelSpec): + return self.l2cap_channel_manager.create_le_credit_based_server( + handler=cast(Callable[[l2cap.LeCreditBasedChannel], Any], handler), + spec=spec, + ) + else: + raise ValueError(f'Unexpected mode {spec}') + def send_l2cap_pdu(self, connection_handle: int, cid: int, pdu: bytes) -> None: self.host.send_l2cap_pdu(connection_handle, cid, pdu) diff --git a/bumble/l2cap.py b/bumble/l2cap.py index cccb172..749f0d3 100644 --- a/bumble/l2cap.py +++ b/bumble/l2cap.py @@ -17,6 +17,7 @@ # ----------------------------------------------------------------------------- from __future__ import annotations import asyncio +import dataclasses import enum import logging import struct @@ -38,6 +39,7 @@ from typing import ( TYPE_CHECKING, ) +from .utils import deprecated from .colors import color from .core import BT_CENTRAL_ROLE, InvalidStateError, ProtocolError from .hci import ( @@ -167,6 +169,34 @@ L2CAP_MTU_CONFIGURATION_PARAMETER_TYPE = 0x01 # pylint: disable=invalid-name +@dataclasses.dataclass +class ClassicChannelSpec: + psm: Optional[int] = None + mtu: int = L2CAP_MIN_BR_EDR_MTU + + +@dataclasses.dataclass +class LeCreditBasedChannelSpec: + psm: Optional[int] = None + mtu: int = L2CAP_LE_CREDIT_BASED_CONNECTION_DEFAULT_MTU + mps: int = L2CAP_LE_CREDIT_BASED_CONNECTION_DEFAULT_MPS + max_credits: int = L2CAP_LE_CREDIT_BASED_CONNECTION_DEFAULT_INITIAL_CREDITS + + def __post_init__(self): + if ( + self.max_credits < 1 + or self.max_credits > L2CAP_LE_CREDIT_BASED_CONNECTION_MAX_CREDITS + ): + raise ValueError('max credits out of range') + if self.mtu < L2CAP_LE_CREDIT_BASED_CONNECTION_MIN_MTU: + raise ValueError('MTU too small') + if ( + self.mps < L2CAP_LE_CREDIT_BASED_CONNECTION_MIN_MPS + or self.mps > L2CAP_LE_CREDIT_BASED_CONNECTION_MAX_MPS + ): + raise ValueError('MPS out of range') + + class L2CAP_PDU: ''' See Bluetooth spec @ Vol 3, Part A - 3 DATA PACKET FORMAT @@ -676,7 +706,7 @@ class L2CAP_LE_Flow_Control_Credit(L2CAP_Control_Frame): # ----------------------------------------------------------------------------- -class Channel(EventEmitter): +class ClassicChannel(EventEmitter): class State(enum.IntEnum): # States CLOSED = 0x00 @@ -990,7 +1020,7 @@ class Channel(EventEmitter): # ----------------------------------------------------------------------------- -class LeConnectionOrientedChannel(EventEmitter): +class LeCreditBasedChannel(EventEmitter): """ LE Credit-based Connection Oriented Channel """ @@ -1004,7 +1034,7 @@ class LeConnectionOrientedChannel(EventEmitter): CONNECTION_ERROR = 5 out_queue: Deque[bytes] - connection_result: Optional[asyncio.Future[LeConnectionOrientedChannel]] + connection_result: Optional[asyncio.Future[LeCreditBasedChannel]] disconnection_result: Optional[asyncio.Future[None]] out_sdu: Optional[bytes] state: State @@ -1071,7 +1101,7 @@ class LeConnectionOrientedChannel(EventEmitter): def send_control_frame(self, frame: L2CAP_Control_Frame) -> None: self.manager.send_control_frame(self.connection, L2CAP_LE_SIGNALING_CID, frame) - async def connect(self) -> LeConnectionOrientedChannel: + async def connect(self) -> LeCreditBasedChannel: # Check that we're in the right state if self.state != self.State.INIT: raise InvalidStateError('not in a connectable state') @@ -1342,15 +1372,67 @@ class LeConnectionOrientedChannel(EventEmitter): ) +# ----------------------------------------------------------------------------- +class ClassicChannelServer(EventEmitter): + def __init__( + self, + manager: ChannelManager, + psm: int, + handler: Optional[Callable[[ClassicChannel], Any]], + mtu: int, + ) -> None: + super().__init__() + self.manager = manager + self.handler = handler + self.psm = psm + self.mtu = mtu + + def on_connection(self, channel: ClassicChannel) -> None: + self.emit('connection', channel) + if self.handler: + self.handler(channel) + + def close(self) -> None: + if self.psm in self.manager.servers: + del self.manager.servers[self.psm] + + +# ----------------------------------------------------------------------------- +class LeCreditBasedChannelServer(EventEmitter): + def __init__( + self, + manager: ChannelManager, + psm: int, + handler: Optional[Callable[[LeCreditBasedChannel], Any]], + max_credits: int, + mtu: int, + mps: int, + ) -> None: + super().__init__() + self.manager = manager + self.handler = handler + self.psm = psm + self.max_credits = max_credits + self.mtu = mtu + self.mps = mps + + def on_connection(self, channel: LeCreditBasedChannel) -> None: + self.emit('connection', channel) + if self.handler: + self.handler(channel) + + def close(self) -> None: + if self.psm in self.manager.le_coc_servers: + del self.manager.le_coc_servers[self.psm] + + # ----------------------------------------------------------------------------- class ChannelManager: identifiers: Dict[int, int] - channels: Dict[int, Dict[int, Union[Channel, LeConnectionOrientedChannel]]] - servers: Dict[int, Callable[[Channel], Any]] - le_coc_channels: Dict[int, Dict[int, LeConnectionOrientedChannel]] - le_coc_servers: Dict[ - int, Tuple[Callable[[LeConnectionOrientedChannel], Any], int, int, int] - ] + channels: Dict[int, Dict[int, Union[ClassicChannel, LeCreditBasedChannel]]] + servers: Dict[int, ClassicChannelServer] + le_coc_channels: Dict[int, Dict[int, LeCreditBasedChannel]] + le_coc_servers: Dict[int, LeCreditBasedChannelServer] le_coc_requests: Dict[int, L2CAP_LE_Credit_Based_Connection_Request] fixed_channels: Dict[int, Optional[Callable[[int, bytes], Any]]] _host: Optional[Host] @@ -1429,21 +1511,6 @@ class ChannelManager: raise RuntimeError('no free CID') - @staticmethod - def check_le_coc_parameters(max_credits: int, mtu: int, mps: int) -> None: - if ( - max_credits < 1 - or max_credits > L2CAP_LE_CREDIT_BASED_CONNECTION_MAX_CREDITS - ): - raise ValueError('max credits out of range') - if mtu < L2CAP_LE_CREDIT_BASED_CONNECTION_MIN_MTU: - raise ValueError('MTU too small') - if ( - mps < L2CAP_LE_CREDIT_BASED_CONNECTION_MIN_MPS - or mps > L2CAP_LE_CREDIT_BASED_CONNECTION_MAX_MPS - ): - raise ValueError('MPS out of range') - def next_identifier(self, connection: Connection) -> int: identifier = (self.identifiers.setdefault(connection.handle, 0) + 1) % 256 self.identifiers[connection.handle] = identifier @@ -1458,8 +1525,22 @@ class ChannelManager: if cid in self.fixed_channels: del self.fixed_channels[cid] - def register_server(self, psm: int, server: Callable[[Channel], Any]) -> int: - if psm == 0: + @deprecated("Please use create_classic_channel_server") + def register_server( + self, + psm: int, + server: Callable[[ClassicChannel], Any], + ) -> int: + return self.create_classic_server( + handler=server, spec=ClassicChannelSpec(psm=psm) + ).psm + + def create_classic_server( + self, + spec: ClassicChannelSpec, + handler: Optional[Callable[[ClassicChannel], Any]] = None, + ) -> ClassicChannelServer: + if spec.psm is None: # Find a free PSM for candidate in range( L2CAP_PSM_DYNAMIC_RANGE_START, L2CAP_PSM_DYNAMIC_RANGE_END + 1, 2 @@ -1468,62 +1549,75 @@ class ChannelManager: continue if candidate in self.servers: continue - psm = candidate + spec.psm = candidate break else: raise InvalidStateError('no free PSM') else: # Check that the PSM isn't already in use - if psm in self.servers: + if spec.psm in self.servers: raise ValueError('PSM already in use') # Check that the PSM is valid - if psm % 2 == 0: + if spec.psm % 2 == 0: raise ValueError('invalid PSM (not odd)') - check = psm >> 8 + check = spec.psm >> 8 while check: if check % 2 != 0: raise ValueError('invalid PSM') check >>= 8 - self.servers[psm] = server + self.servers[spec.psm] = ClassicChannelServer(self, spec.psm, handler, spec.mtu) - return psm + return self.servers[spec.psm] + @deprecated("Please use create_le_credit_based_server()") def register_le_coc_server( self, psm: int, - server: Callable[[LeConnectionOrientedChannel], Any], - max_credits: int = L2CAP_LE_CREDIT_BASED_CONNECTION_DEFAULT_INITIAL_CREDITS, - mtu: int = L2CAP_LE_CREDIT_BASED_CONNECTION_DEFAULT_MTU, - mps: int = L2CAP_LE_CREDIT_BASED_CONNECTION_DEFAULT_MPS, + server: Callable[[LeCreditBasedChannel], Any], + max_credits: int, + mtu: int, + mps: int, ) -> int: - self.check_le_coc_parameters(max_credits, mtu, mps) + return self.create_le_credit_based_server( + spec=LeCreditBasedChannelSpec( + psm=None if psm == 0 else psm, mtu=mtu, mps=mps, max_credits=max_credits + ), + handler=server, + ).psm - if psm == 0: + def create_le_credit_based_server( + self, + spec: LeCreditBasedChannelSpec, + handler: Optional[Callable[[LeCreditBasedChannel], Any]] = None, + ) -> LeCreditBasedChannelServer: + if spec.psm is None: # Find a free PSM for candidate in range( L2CAP_LE_PSM_DYNAMIC_RANGE_START, L2CAP_LE_PSM_DYNAMIC_RANGE_END + 1 ): if candidate in self.le_coc_servers: continue - psm = candidate + spec.psm = candidate break else: raise InvalidStateError('no free PSM') else: # Check that the PSM isn't already in use - if psm in self.le_coc_servers: + if spec.psm in self.le_coc_servers: raise ValueError('PSM already in use') - self.le_coc_servers[psm] = ( - server, - max_credits or L2CAP_LE_CREDIT_BASED_CONNECTION_DEFAULT_INITIAL_CREDITS, - mtu or L2CAP_LE_CREDIT_BASED_CONNECTION_DEFAULT_MTU, - mps or L2CAP_LE_CREDIT_BASED_CONNECTION_DEFAULT_MPS, + self.le_coc_servers[spec.psm] = LeCreditBasedChannelServer( + self, + spec.psm, + handler, + max_credits=spec.max_credits, + mtu=spec.mtu, + mps=spec.mps, ) - return psm + return self.le_coc_servers[spec.psm] def on_disconnection(self, connection_handle: int, _reason: int) -> None: logger.debug(f'disconnection from {connection_handle}, cleaning up channels') @@ -1650,13 +1744,13 @@ class ChannelManager: logger.debug( f'creating server channel with cid={source_cid} for psm {request.psm}' ) - channel = Channel( - self, connection, cid, request.psm, source_cid, L2CAP_MIN_BR_EDR_MTU + channel = ClassicChannel( + self, connection, cid, request.psm, source_cid, server.mtu ) connection_channels[source_cid] = channel # Notify - server(channel) + server.on_connection(channel) channel.on_connection_request(request) else: logger.warning( @@ -1878,7 +1972,7 @@ class ChannelManager: self, connection: Connection, cid: int, request ) -> None: if request.le_psm in self.le_coc_servers: - (server, max_credits, mtu, mps) = self.le_coc_servers[request.le_psm] + server = self.le_coc_servers[request.le_psm] # Check that the CID isn't already used le_connection_channels = self.le_coc_channels.setdefault( @@ -1892,8 +1986,8 @@ class ChannelManager: L2CAP_LE_Credit_Based_Connection_Response( identifier=request.identifier, destination_cid=0, - mtu=mtu, - mps=mps, + mtu=server.mtu, + mps=server.mps, initial_credits=0, # pylint: disable=line-too-long result=L2CAP_LE_Credit_Based_Connection_Response.CONNECTION_REFUSED_SOURCE_CID_ALREADY_ALLOCATED, @@ -1911,8 +2005,8 @@ class ChannelManager: L2CAP_LE_Credit_Based_Connection_Response( identifier=request.identifier, destination_cid=0, - mtu=mtu, - mps=mps, + mtu=server.mtu, + mps=server.mps, initial_credits=0, # pylint: disable=line-too-long result=L2CAP_LE_Credit_Based_Connection_Response.CONNECTION_REFUSED_NO_RESOURCES_AVAILABLE, @@ -1925,18 +2019,18 @@ class ChannelManager: f'creating LE CoC server channel with cid={source_cid} for psm ' f'{request.le_psm}' ) - channel = LeConnectionOrientedChannel( + channel = LeCreditBasedChannel( self, connection, request.le_psm, source_cid, request.source_cid, - mtu, - mps, + server.mtu, + server.mps, request.initial_credits, request.mtu, request.mps, - max_credits, + server.max_credits, True, ) connection_channels[source_cid] = channel @@ -1949,16 +2043,16 @@ class ChannelManager: L2CAP_LE_Credit_Based_Connection_Response( identifier=request.identifier, destination_cid=source_cid, - mtu=mtu, - mps=mps, - initial_credits=max_credits, + mtu=server.mtu, + mps=server.mps, + initial_credits=server.max_credits, # pylint: disable=line-too-long result=L2CAP_LE_Credit_Based_Connection_Response.CONNECTION_SUCCESSFUL, ), ) # Notify - server(channel) + server.on_connection(channel) else: logger.info( f'No LE server for connection 0x{connection.handle:04X} ' @@ -2013,37 +2107,51 @@ class ChannelManager: channel.on_credits(credit.credits) - def on_channel_closed(self, channel: Channel) -> None: + def on_channel_closed(self, channel: ClassicChannel) -> None: connection_channels = self.channels.get(channel.connection.handle) if connection_channels: if channel.source_cid in connection_channels: del connection_channels[channel.source_cid] + @deprecated("Please use create_le_credit_based_channel()") async def open_le_coc( self, connection: Connection, psm: int, max_credits: int, mtu: int, mps: int - ) -> LeConnectionOrientedChannel: - self.check_le_coc_parameters(max_credits, mtu, mps) + ) -> LeCreditBasedChannel: + return await self.create_le_credit_based_channel( + connection=connection, + spec=LeCreditBasedChannelSpec( + psm=psm, max_credits=max_credits, mtu=mtu, mps=mps + ), + ) + async def create_le_credit_based_channel( + self, + connection: Connection, + spec: LeCreditBasedChannelSpec, + ) -> LeCreditBasedChannel: # Find a free CID for the new channel connection_channels = self.channels.setdefault(connection.handle, {}) source_cid = self.find_free_le_cid(connection_channels) if source_cid is None: # Should never happen! raise RuntimeError('all CIDs already in use') + if spec.psm is None: + raise ValueError('PSM cannot be None') + # Create the channel - logger.debug(f'creating coc channel with cid={source_cid} for psm {psm}') - channel = LeConnectionOrientedChannel( + logger.debug(f'creating coc channel with cid={source_cid} for psm {spec.psm}') + channel = LeCreditBasedChannel( manager=self, connection=connection, - le_psm=psm, + le_psm=spec.psm, source_cid=source_cid, destination_cid=0, - mtu=mtu, - mps=mps, + mtu=spec.mtu, + mps=spec.mps, credits=0, peer_mtu=0, peer_mps=0, - peer_credits=max_credits, + peer_credits=spec.max_credits, connected=False, ) connection_channels[source_cid] = channel @@ -2062,7 +2170,15 @@ class ChannelManager: return channel - async def connect(self, connection: Connection, psm: int) -> Channel: + @deprecated("Please use create_classic_channel()") + async def connect(self, connection: Connection, psm: int) -> ClassicChannel: + return await self.create_classic_channel( + connection=connection, spec=ClassicChannelSpec(psm=psm) + ) + + async def create_classic_channel( + self, connection: Connection, spec: ClassicChannelSpec + ) -> ClassicChannel: # NOTE: this implementation hard-codes BR/EDR # Find a free CID for a new channel @@ -2071,10 +2187,20 @@ class ChannelManager: if source_cid is None: # Should never happen! raise RuntimeError('all CIDs already in use') + if spec.psm is None: + raise ValueError('PSM cannot be None') + # Create the channel - logger.debug(f'creating client channel with cid={source_cid} for psm {psm}') - channel = Channel( - self, connection, L2CAP_SIGNALING_CID, psm, source_cid, L2CAP_MIN_BR_EDR_MTU + logger.debug( + f'creating client channel with cid={source_cid} for psm {spec.psm}' + ) + channel = ClassicChannel( + self, + connection, + L2CAP_SIGNALING_CID, + spec.psm, + source_cid, + spec.mtu, ) connection_channels[source_cid] = channel @@ -2086,3 +2212,20 @@ class ChannelManager: raise e return channel + + +# ----------------------------------------------------------------------------- +# Deprecated Classes +# ----------------------------------------------------------------------------- + + +class Channel(ClassicChannel): + @deprecated("Please use ClassicChannel") + def __init__(self, *args, **kwargs) -> None: + super().__init__(*args, **kwargs) + + +class LeConnectionOrientedChannel(LeCreditBasedChannel): + @deprecated("Please use LeCreditBasedChannel") + def __init__(self, *args, **kwargs) -> None: + super().__init__(*args, **kwargs) diff --git a/bumble/rfcomm.py b/bumble/rfcomm.py index 02c18fa..4002dc7 100644 --- a/bumble/rfcomm.py +++ b/bumble/rfcomm.py @@ -674,7 +674,7 @@ class Multiplexer(EventEmitter): acceptor: Optional[Callable[[int], bool]] dlcs: Dict[int, DLC] - def __init__(self, l2cap_channel: l2cap.Channel, role: Role) -> None: + def __init__(self, l2cap_channel: l2cap.ClassicChannel, role: Role) -> None: super().__init__() self.role = role self.l2cap_channel = l2cap_channel @@ -887,7 +887,7 @@ class Multiplexer(EventEmitter): # ----------------------------------------------------------------------------- class Client: multiplexer: Optional[Multiplexer] - l2cap_channel: Optional[l2cap.Channel] + l2cap_channel: Optional[l2cap.ClassicChannel] def __init__(self, device: Device, connection: Connection) -> None: self.device = device @@ -960,11 +960,11 @@ class Server(EventEmitter): self.acceptors[channel] = acceptor return channel - def on_connection(self, l2cap_channel: l2cap.Channel) -> None: + def on_connection(self, l2cap_channel: l2cap.ClassicChannel) -> None: logger.debug(f'+++ new L2CAP connection: {l2cap_channel}') l2cap_channel.on('open', lambda: self.on_l2cap_channel_open(l2cap_channel)) - def on_l2cap_channel_open(self, l2cap_channel: l2cap.Channel) -> None: + def on_l2cap_channel_open(self, l2cap_channel: l2cap.ClassicChannel) -> None: logger.debug(f'$$$ L2CAP channel open: {l2cap_channel}') # Create a new multiplexer for the channel diff --git a/bumble/sdp.py b/bumble/sdp.py index 6428187..01e72da 100644 --- a/bumble/sdp.py +++ b/bumble/sdp.py @@ -758,7 +758,7 @@ class SDP_ServiceSearchAttributeResponse(SDP_PDU): # ----------------------------------------------------------------------------- class Client: - channel: Optional[l2cap.Channel] + channel: Optional[l2cap.ClassicChannel] def __init__(self, device: Device) -> None: self.device = device @@ -921,7 +921,7 @@ class Client: # ----------------------------------------------------------------------------- class Server: CONTINUATION_STATE = bytes([0x01, 0x43]) - channel: Optional[l2cap.Channel] + channel: Optional[l2cap.ClassicChannel] Service = NewType('Service', List[ServiceAttribute]) service_records: Dict[int, Service] current_response: Union[None, bytes, Tuple[int, List[int]]] diff --git a/bumble/utils.py b/bumble/utils.py index 03b201c..a562618 100644 --- a/bumble/utils.py +++ b/bumble/utils.py @@ -21,6 +21,7 @@ import logging import traceback import collections import sys +import warnings from typing import ( Awaitable, Set, @@ -427,3 +428,19 @@ def wrap_async(function): Wraps the provided function in an async function. """ return partial(async_call, function) + + +def deprecated(msg: str): + """ + Throw deprecation warning before execution + """ + + def wrapper(function): + @wraps(function) + def inner(*args, **kwargs): + warnings.warn(msg, DeprecationWarning) + return function(*args, **kwargs) + + return inner + + return wrapper