# Files: bumble_mirror/bumble/avctp.py
# Snapshot: 2025-04-26 03:55:31 +08:00 (293 lines, 9.7 KiB, Python)

# Copyright 2021-2023 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# -----------------------------------------------------------------------------
# Imports
# -----------------------------------------------------------------------------
from __future__ import annotations
from enum import IntEnum
import logging
import struct
from typing import Callable, cast, Dict, Optional
from bumble.colors import color
from bumble import avc
from bumble import core
from bumble import l2cap
# -----------------------------------------------------------------------------
# Logging
# -----------------------------------------------------------------------------
logger = logging.getLogger(__name__)

# -----------------------------------------------------------------------------
# Constants
# -----------------------------------------------------------------------------
# L2CAP PSM of the AVCTP control channel.
AVCTP_PSM = 0x0017
# L2CAP PSM of the AVCTP browsing channel.
AVCTP_BROWSING_PSM = 0x001B
# -----------------------------------------------------------------------------
class MessageAssembler:
    """Reassemble AVCTP messages from one or more received AVCTP packets.

    Each call to `on_pdu` consumes one AVCTP packet. When a complete message
    has been received (a SINGLE packet, or a START ... CONTINUE* ... END
    sequence), the callback is invoked with
    ``(transaction_label, is_command, ipid, pid, payload)`` and the assembler
    resets itself for the next message.

    Packet header layouts (AVCTP specification, "Packet format"):
      SINGLE:   [TL|PT|C/R|IPID] [PID (2 octets)]                   + payload
      START:    [TL|PT|C/R|IPID] [number of packets] [PID (2 oct.)] + payload
      CONTINUE: [TL|PT|C/R|IPID]                                    + payload
      END:      [TL|PT|C/R|IPID]                                    + payload
    """

    Callback = Callable[[int, bool, bool, int, bytes], None]

    transaction_label: int  # Label of the message being assembled, -1 when idle
    pid: int  # PID taken from the SINGLE/START packet, -1 when idle
    c_r: int  # C/R bit of the message: 0 = command, 1 = response
    ipid: int  # IPID bit of the message
    payload: bytes  # Payload bytes accumulated so far
    number_of_packets: int  # Packet count announced by the START packet
    packets_received: int  # Packets received for the current message

    def __init__(self, callback: Callback) -> None:
        self.callback = callback
        self.reset()

    def reset(self) -> None:
        """Discard any partially assembled message and return to the idle state."""
        self.packets_received = 0
        self.transaction_label = -1
        self.pid = -1
        self.c_r = -1
        self.ipid = -1
        self.payload = b''
        self.number_of_packets = 0
        # Not read anywhere in this class; retained so existing readers still work.
        self.packet_count = 0

    def on_pdu(self, pdu: bytes) -> None:
        """Consume one AVCTP packet received from the L2CAP channel.

        Malformed or out-of-sequence packets are logged and cause the message
        being assembled (if any) to be discarded, instead of raising.
        """
        if not pdu:
            logger.warning("empty AVCTP packet")
            self.reset()
            return

        self.packets_received += 1

        header = pdu[0]
        transaction_label = header >> 4
        packet_type = Protocol.PacketType((header >> 2) & 3)
        c_r = (header >> 1) & 1
        ipid = header & 1

        # Only responses (C/R = 1) may carry the IPID flag.
        if c_r == 0 and ipid != 0:
            logger.warning("invalid IPID in command frame")
            self.reset()
            return

        if packet_type in (Protocol.PacketType.SINGLE, Protocol.PacketType.START):
            if self.transaction_label >= 0:
                # We are already in a transaction: drop it and start over
                # with this packet.
                logger.warning("received START or SINGLE fragment while in transaction")
                self.reset()
                self.packets_received = 1

            # START has an extra "number of packets" octet before the PID.
            is_start = packet_type == Protocol.PacketType.START
            pid_offset = 2 if is_start else 1
            if len(pdu) < pid_offset + 2:
                logger.warning("AVCTP packet too short")
                self.reset()
                return

            if is_start:
                self.number_of_packets = pdu[1]
            self.transaction_label = transaction_label
            self.c_r = c_r
            self.ipid = ipid
            self.pid = struct.unpack_from(">H", pdu, pid_offset)[0]
            # A new message starts here, so the payload is replaced, not appended.
            self.payload = bytes(pdu[pid_offset + 2 :])
        else:
            # CONTINUE and END packets only have the 1-octet header (no PID):
            # the PID of the message is the one received in the START packet.
            if transaction_label != self.transaction_label:
                logger.warning("transaction label does not match")
                self.reset()
                return
            if c_r != self.c_r:
                logger.warning("C/R does not match")
                self.reset()
                return
            if self.packets_received > self.number_of_packets:
                logger.warning("too many fragments in transaction")
                self.reset()
                return
            if (
                packet_type == Protocol.PacketType.END
                and self.packets_received != self.number_of_packets
            ):
                logger.warning("premature END")
                self.reset()
                return
            self.payload += pdu[1:]

        if packet_type in (Protocol.PacketType.SINGLE, Protocol.PacketType.END):
            self.on_message_complete()

    def on_message_complete(self) -> None:
        """Deliver the assembled message to the callback, then reset.

        Exceptions raised by the callback are logged and not propagated.
        """
        try:
            self.callback(
                self.transaction_label,
                self.c_r == 0,  # is_command
                self.ipid != 0,  # ipid
                self.pid,
                self.payload,
            )
        except Exception as error:
            logger.exception(color(f"!!! exception in callback: {error}", "red"))
        self.reset()
# -----------------------------------------------------------------------------
class Protocol:
    """AVCTP endpoint bound to a single L2CAP channel.

    Incoming PDUs are reassembled by a `MessageAssembler`, then each complete
    message is dispatched to the command or response handler registered for
    its PID. Outgoing messages are always sent as SINGLE (unfragmented)
    packets.
    """

    CommandHandler = Callable[[int, avc.CommandFrame], None]
    ResponseHandler = Callable[[int, Optional[avc.ResponseFrame]], None]

    command_handlers: Dict[int, CommandHandler]  # Command handlers, by PID
    response_handlers: Dict[int, ResponseHandler]  # Response handlers, by PID
    # Not assigned within this class; presumably managed by users -- TODO confirm.
    next_transaction_label: int
    message_assembler: MessageAssembler

    class PacketType(IntEnum):
        SINGLE = 0b00
        START = 0b01
        CONTINUE = 0b10
        END = 0b11

    def __init__(self, l2cap_channel: l2cap.ClassicChannel) -> None:
        self.l2cap_channel = l2cap_channel
        self.command_handlers = {}
        self.response_handlers = {}
        self.message_assembler = MessageAssembler(self.on_message)

        # Receive every PDU from the channel, and track its open/close events.
        l2cap_channel.sink = self.on_pdu
        for event, listener in (
            (l2cap_channel.EVENT_OPEN, self.on_l2cap_channel_open),
            (l2cap_channel.EVENT_CLOSE, self.on_l2cap_channel_close),
        ):
            l2cap_channel.on(event, listener)

    def on_l2cap_channel_open(self):
        """Log that the underlying L2CAP channel has opened."""
        logger.debug(color("<<< AVCTP channel open", "magenta"))

    def on_l2cap_channel_close(self):
        """Log that the underlying L2CAP channel has closed."""
        logger.debug(color("<<< AVCTP channel closed", "magenta"))

    def on_pdu(self, pdu: bytes) -> None:
        """Hand a PDU received on the channel to the message assembler."""
        self.message_assembler.on_pdu(pdu)

    def on_message(
        self,
        transaction_label: int,
        is_command: bool,
        ipid: bool,
        pid: int,
        payload: bytes,
    ) -> None:
        """Dispatch a fully reassembled message to the handler for its PID."""
        logger.debug(
            f"<<< AVCTP Message: pid={pid}, transaction_label={transaction_label}, "
            f"is_command={is_command}, ipid={ipid}, payload={payload.hex()}"
        )

        # Check for invalid PID responses.
        if ipid:
            logger.debug(f"received IPID for PID={pid}")

        if is_command:
            self._on_command_message(transaction_label, pid, payload)
        else:
            self._on_response_message(transaction_label, ipid, pid, payload)

    def _on_command_message(
        self, transaction_label: int, pid: int, payload: bytes
    ) -> None:
        """Deliver a command to its PID's handler, or answer with an IPID response."""
        handlers = self.command_handlers
        if pid in handlers:
            command_frame = cast(avc.CommandFrame, avc.Frame.from_bytes(payload))
            handlers[pid](transaction_label, command_frame)
            return
        logger.warning(f"no command handler for PID {pid}")
        self.send_ipid(transaction_label, pid)

    def _on_response_message(
        self, transaction_label: int, ipid: bool, pid: int, payload: bytes
    ) -> None:
        """Deliver a response to its PID's handler; unknown PIDs are dropped."""
        handlers = self.response_handlers
        if pid not in handlers:
            logger.warning(f"no response handler for PID {pid}")
            return
        # By convention, an IPID response reaches the handler as a None frame.
        response_frame: Optional[avc.ResponseFrame] = (
            None if ipid else cast(avc.ResponseFrame, avc.Frame.from_bytes(payload))
        )
        handlers[pid](transaction_label, response_frame)

    def send_message(
        self,
        transaction_label: int,
        is_command: bool,
        ipid: bool,
        pid: int,
        payload: bytes,
    ):
        """Send one message as a single, unfragmented AVCTP packet.

        The first octet packs the transaction label (high 4 bits), the packet
        type (always SINGLE here), the C/R bit (0 for a command, 1 for a
        response) and the IPID bit; it is followed by the big-endian 16-bit
        PID and then the payload.
        """
        # TODO: fragment large messages
        c_r_bit = 0 if is_command else 1
        ipid_bit = 1 if ipid else 0
        header = (
            (transaction_label << 4)
            | (Protocol.PacketType.SINGLE << 2)
            | (c_r_bit << 1)
            | ipid_bit
        )
        self.l2cap_channel.send_pdu(struct.pack(">BH", header, pid) + payload)

    def send_command(self, transaction_label: int, pid: int, payload: bytes) -> None:
        """Send a command message (C/R = 0, IPID = 0)."""
        logger.debug(
            f">>> AVCTP command: transaction_label={transaction_label}, "
            f"pid={pid}, payload={payload.hex()}"
        )
        self.send_message(
            transaction_label, is_command=True, ipid=False, pid=pid, payload=payload
        )

    def send_response(self, transaction_label: int, pid: int, payload: bytes):
        """Send a response message (C/R = 1, IPID = 0)."""
        logger.debug(
            f">>> AVCTP response: transaction_label={transaction_label}, "
            f"pid={pid}, payload={payload.hex()}"
        )
        self.send_message(
            transaction_label, is_command=False, ipid=False, pid=pid, payload=payload
        )

    def send_ipid(self, transaction_label: int, pid: int) -> None:
        """Send an empty response with the IPID (invalid PID) flag set."""
        logger.debug(
            f">>> AVCTP ipid: transaction_label={transaction_label}, pid={pid}"
        )
        self.send_message(
            transaction_label, is_command=False, ipid=True, pid=pid, payload=b''
        )

    def register_command_handler(
        self, pid: int, handler: Protocol.CommandHandler
    ) -> None:
        """Install `handler` for commands with this PID, replacing any previous one."""
        self.command_handlers[pid] = handler

    def unregister_command_handler(
        self, pid: int, handler: Protocol.CommandHandler
    ) -> None:
        """Remove the command handler for `pid`; it must be `handler`."""
        self._remove_handler(self.command_handlers, pid, handler, "command")

    def register_response_handler(
        self, pid: int, handler: Protocol.ResponseHandler
    ) -> None:
        """Install `handler` for responses with this PID, replacing any previous one."""
        self.response_handlers[pid] = handler

    def unregister_response_handler(
        self, pid: int, handler: Protocol.ResponseHandler
    ) -> None:
        """Remove the response handler for `pid`; it must be `handler`."""
        self._remove_handler(self.response_handlers, pid, handler, "response")

    @staticmethod
    def _remove_handler(
        handlers: Dict[int, Callable[..., None]],
        pid: int,
        handler: Callable[..., None],
        kind: str,
    ) -> None:
        """Delete `handlers[pid]`, raising if `handler` is not the one registered."""
        if pid not in handlers or handlers[pid] != handler:
            raise core.InvalidArgumentError(f"{kind} handler not registered")
        del handlers[pid]