# Copyright 2021-2023 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#      https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

# -----------------------------------------------------------------------------
# Imports
# -----------------------------------------------------------------------------
from __future__ import annotations

import logging
import struct
from collections.abc import Callable
from enum import IntEnum

from bumble import core, l2cap
from bumble.colors import color

# -----------------------------------------------------------------------------
# Logging
# -----------------------------------------------------------------------------
logger = logging.getLogger(__name__)


# -----------------------------------------------------------------------------
# Constants
# -----------------------------------------------------------------------------
AVCTP_PSM = 0x0017
AVCTP_BROWSING_PSM = 0x001B


# -----------------------------------------------------------------------------
class MessageAssembler:
    """Reassemble AVCTP packets into complete messages.

    Each PDU passed to :meth:`on_pdu` is one AVCTP packet. A SINGLE packet
    carries a whole message. A fragmented message is a START packet, zero or
    more CONTINUE packets and an END packet.

    When a message is complete, the callback is invoked as
    ``callback(transaction_label, is_command, ipid, pid, payload)``.

    Malformed or inconsistent packets are logged with a warning. They discard
    any partially assembled message instead of raising.
    """

    Callback = Callable[[int, bool, bool, int, bytes], None]

    transaction_label: int  # label of the message being assembled, -1 if idle
    pid: int  # profile identifier of the message being assembled, -1 if idle
    c_r: int  # C/R bit of the message being assembled, -1 if idle
    ipid: int  # IPID bit of the message being assembled, -1 if idle
    payload: bytes  # payload bytes accumulated so far
    number_of_packets: int  # packet count announced by the START packet
    packets_received: int  # packets seen for the current message

    def __init__(self, callback: Callback) -> None:
        """Create an assembler that reports complete messages to ``callback``."""
        self.callback = callback
        self.reset()

    def reset(self) -> None:
        """Discard any partially assembled message and return to the idle state."""
        self.packets_received = 0
        self.transaction_label = -1
        self.pid = -1
        self.c_r = -1
        self.ipid = -1
        self.payload = b''
        self.number_of_packets = 0
        # NOTE(review): not read anywhere in this module; kept for compatibility.
        self.packet_count = 0

    def on_pdu(self, pdu: bytes) -> None:
        """Process one received AVCTP packet.

        Header octet layout: bits 7-4 transaction label, bits 3-2 packet type,
        bit 1 C/R (0 = command), bit 0 IPID.

        Packet formats:

        - SINGLE: header (1) + PID (2) + payload.
        - START: header (1) + number of packets (1) + PID (2) + payload.
        - CONTINUE / END: header (1) + payload. There is no PID field; the
          PID is the one carried by the START packet.

        Args:
            pdu: The raw AVCTP packet as received from the L2CAP channel.
        """
        self.packets_received += 1

        if not pdu:
            logger.warning("empty AVCTP packet")
            self.reset()
            return

        header = pdu[0]
        transaction_label = header >> 4
        packet_type = Protocol.PacketType((header >> 2) & 3)
        c_r = (header >> 1) & 1
        ipid = header & 1

        if c_r == 0 and ipid != 0:
            logger.warning("invalid IPID in command frame")
            self.reset()
            return

        if packet_type in (Protocol.PacketType.SINGLE, Protocol.PacketType.START):
            if self.transaction_label >= 0:
                # We are already in a transaction: drop it and start over with
                # this packet as the first one of a new message.
                logger.warning("received START or SINGLE fragment while in transaction")
                self.reset()
                self.packets_received = 1

            # START has an extra "number of packets" octet before the PID.
            pid_offset = 2 if packet_type == Protocol.PacketType.START else 1
            if len(pdu) < pid_offset + 2:
                logger.warning("AVCTP packet too short")
                self.reset()
                return

            if packet_type == Protocol.PacketType.START:
                self.number_of_packets = pdu[1]

            self.pid = struct.unpack_from(">H", pdu, pid_offset)[0]
            self.payload += pdu[pid_offset + 2 :]
            self.transaction_label = transaction_label
            self.c_r = c_r
            self.ipid = ipid
        else:
            # CONTINUE / END: validate against the message opened by START.
            if transaction_label != self.transaction_label:
                logger.warning("transaction label does not match")
                self.reset()
                return

            if c_r != self.c_r:
                logger.warning("C/R does not match")
                self.reset()
                return

            if self.packets_received > self.number_of_packets:
                logger.warning("too many fragments in transaction")
                self.reset()
                return

            if (
                packet_type == Protocol.PacketType.END
                and self.packets_received != self.number_of_packets
            ):
                logger.warning("premature END")
                self.reset()
                return

            # Only the 1-octet header precedes the payload in these packets.
            self.payload += pdu[1:]

        if packet_type in (Protocol.PacketType.SINGLE, Protocol.PacketType.END):
            self.on_message_complete()

    def on_message_complete(self) -> None:
        """Deliver the assembled message to the callback, then reset.

        Exceptions raised by the callback are logged and not propagated, so a
        faulty handler cannot break packet reception.
        """
        try:
            self.callback(
                self.transaction_label,
                self.c_r == 0,
                self.ipid != 0,
                self.pid,
                self.payload,
            )
        except Exception:
            logger.exception(color("!!! exception in callback", "red"))

        self.reset()


# -----------------------------------------------------------------------------
class Protocol:
    """AVCTP endpoint on top of an L2CAP classic channel.

    Incoming packets are reassembled by a :class:`MessageAssembler`. Each
    complete message is dispatched by PID to a registered command handler or
    response handler.
    """

    CommandHandler = Callable[[int, bytes], None]
    command_handlers: dict[int, CommandHandler]  # Command handlers, by PID
    ResponseHandler = Callable[[int, bytes | None], None]
    response_handlers: dict[int, ResponseHandler]  # Response handlers, by PID
    next_transaction_label: int
    message_assembler: MessageAssembler

    class PacketType(IntEnum):
        """Values of the 2-bit packet type field of the AVCTP header."""

        SINGLE = 0b00
        START = 0b01
        CONTINUE = 0b10
        END = 0b11

    def __init__(self, l2cap_channel: l2cap.ClassicChannel) -> None:
        """Attach to ``l2cap_channel`` and start receiving its PDUs."""
        self.command_handlers = {}
        self.response_handlers = {}
        self.l2cap_channel = l2cap_channel
        self.message_assembler = MessageAssembler(self.on_message)

        # Register to receive PDUs from the channel
        l2cap_channel.sink = self.on_pdu
        l2cap_channel.on(l2cap_channel.EVENT_OPEN, self.on_l2cap_channel_open)
        l2cap_channel.on(l2cap_channel.EVENT_CLOSE, self.on_l2cap_channel_close)

    def on_l2cap_channel_open(self) -> None:
        """Log that the underlying L2CAP channel was opened."""
        logger.debug(color("<<< AVCTP channel open", "magenta"))

    def on_l2cap_channel_close(self) -> None:
        """Log that the underlying L2CAP channel was closed."""
        logger.debug(color("<<< AVCTP channel closed", "magenta"))

    def on_pdu(self, pdu: bytes) -> None:
        """Feed a PDU received on the L2CAP channel to the message assembler."""
        self.message_assembler.on_pdu(pdu)

    def on_message(
        self,
        transaction_label: int,
        is_command: bool,
        ipid: bool,
        pid: int,
        payload: bytes,
    ) -> None:
        """Dispatch a complete AVCTP message to the handler registered for ``pid``.

        A command with no registered handler is answered with an IPID
        response. A response with an IPID set is delivered to its handler with
        a ``None`` payload.
        """
        logger.debug(
            f"<<< AVCTP Message: pid={pid}, "
            f"transaction_label={transaction_label}, "
            f"is_command={is_command}, "
            f"ipid={ipid}, "
            f"payload={payload.hex()}"
        )

        # Check for invalid PID responses.
        if ipid:
            logger.debug(f"received IPID for PID={pid}")

        # Find the appropriate handler.
        if is_command:
            command_handler = self.command_handlers.get(pid)
            if command_handler is None:
                logger.warning(f"no command handler for PID {pid}")
                self.send_ipid(transaction_label, pid)
                return
            command_handler(transaction_label, payload)
        else:
            response_handler = self.response_handlers.get(pid)
            if response_handler is None:
                logger.warning(f"no response handler for PID {pid}")
                return
            # By convention, for an ipid, send a None payload to the response handler.
            response_payload = None if ipid else payload
            response_handler(transaction_label, response_payload)

    def send_message(
        self,
        transaction_label: int,
        is_command: bool,
        ipid: bool,
        pid: int,
        payload: bytes,
    ) -> None:
        """Send one message as a SINGLE AVCTP packet.

        The header octet is built as: transaction label in bits 7-4, packet
        type in bits 3-2, C/R in bit 1 (0 = command), IPID in bit 0. The
        header is followed by the big-endian 16-bit PID and the payload.
        """
        # TODO: fragment large messages
        packet_type = Protocol.PacketType.SINGLE
        header = (
            transaction_label << 4
            | packet_type << 2
            | (0 if is_command else 1) << 1
            | (1 if ipid else 0)
        )
        pdu = struct.pack(">BH", header, pid) + payload
        self.l2cap_channel.write(pdu)

    def send_command(self, transaction_label: int, pid: int, payload: bytes) -> None:
        """Send a command message for ``pid``."""
        logger.debug(
            ">>> AVCTP command: "
            f"transaction_label={transaction_label}, "
            f"pid={pid}, "
            f"payload={payload.hex()}"
        )
        self.send_message(transaction_label, True, False, pid, payload)

    def send_response(self, transaction_label: int, pid: int, payload: bytes) -> None:
        """Send a response message for ``pid``."""
        logger.debug(
            ">>> AVCTP response: "
            f"transaction_label={transaction_label}, "
            f"pid={pid}, "
            f"payload={payload.hex()}"
        )
        self.send_message(transaction_label, False, False, pid, payload)

    def send_ipid(self, transaction_label: int, pid: int) -> None:
        """Send an empty response with the IPID bit set for ``pid``."""
        logger.debug(
            f">>> AVCTP ipid: transaction_label={transaction_label}, pid={pid}"
        )
        self.send_message(transaction_label, False, True, pid, b'')

    def register_command_handler(
        self, pid: int, handler: Protocol.CommandHandler
    ) -> None:
        """Register ``handler`` for commands on ``pid``, replacing any previous one."""
        self.command_handlers[pid] = handler

    def unregister_command_handler(
        self, pid: int, handler: Protocol.CommandHandler
    ) -> None:
        """Remove ``handler`` for commands on ``pid``.

        Raises:
            core.InvalidArgumentError: if ``handler`` is not the one
                registered for ``pid``.
        """
        if pid not in self.command_handlers or self.command_handlers[pid] != handler:
            raise core.InvalidArgumentError("command handler not registered")
        del self.command_handlers[pid]

    def register_response_handler(
        self, pid: int, handler: Protocol.ResponseHandler
    ) -> None:
        """Register ``handler`` for responses on ``pid``, replacing any previous one."""
        self.response_handlers[pid] = handler

    def unregister_response_handler(
        self, pid: int, handler: Protocol.ResponseHandler
    ) -> None:
        """Remove ``handler`` for responses on ``pid``.

        Raises:
            core.InvalidArgumentError: if ``handler`` is not the one
                registered for ``pid``.
        """
        if pid not in self.response_handlers or self.response_handlers[pid] != handler:
            raise core.InvalidArgumentError("response handler not registered")
        del self.response_handlers[pid]