forked from auracaster/bumble_mirror
use a dict instead of a series of ifs (+6 squashed commits)
Squashed commits: [90f2024] fix import order [0edd321] add a few docstrings [77a0ac0] wip [adcf159] wip [96cbd67] wip [d8bfbab] wip (+1 squashed commit) Squashed commits: [43b4d66] wip (+2 squashed commits) Squashed commits: [3dafaa8] wip [5844026] wip (+1 squashed commit) Squashed commits: [4cbb35a] wip (+1 squashed commit) Squashed commits: [4d2b6d3] wip (+4 squashed commits) Squashed commits: [f2da510] wip [318c119] wip [923b4eb] wip [9d46365] wip use a dict instead of a series of ifs (+6 squashed commits) Squashed commits: [90f2024] fix import order [0edd321] add a few docstrings [77a0ac0] wip [adcf159] wip [96cbd67] wip [d8bfbab] wip
This commit is contained in:
291
bumble/avctp.py
Normal file
291
bumble/avctp.py
Normal file
@@ -0,0 +1,291 @@
|
||||
# Copyright 2021-2023 Google LLC
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# https://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
# Imports
|
||||
# -----------------------------------------------------------------------------
|
||||
from __future__ import annotations
|
||||
from enum import IntEnum
|
||||
import logging
|
||||
import struct
|
||||
from typing import Callable, cast, Dict, Optional
|
||||
|
||||
from bumble.colors import color
|
||||
from bumble import avc
|
||||
from bumble import l2cap
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
# Logging
|
||||
# -----------------------------------------------------------------------------
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
# Constants
|
||||
# -----------------------------------------------------------------------------
|
||||
AVCTP_PSM = 0x0017
|
||||
AVCTP_BROWSING_PSM = 0x001B
|
||||
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
class MessageAssembler:
|
||||
Callback = Callable[[int, bool, bool, int, bytes], None]
|
||||
|
||||
transaction_label: int
|
||||
pid: int
|
||||
c_r: int
|
||||
ipid: int
|
||||
payload: bytes
|
||||
number_of_packets: int
|
||||
packets_received: int
|
||||
|
||||
def __init__(self, callback: Callback) -> None:
|
||||
self.callback = callback
|
||||
self.reset()
|
||||
|
||||
def reset(self) -> None:
|
||||
self.packets_received = 0
|
||||
self.transaction_label = -1
|
||||
self.pid = -1
|
||||
self.c_r = -1
|
||||
self.ipid = -1
|
||||
self.payload = b''
|
||||
self.number_of_packets = 0
|
||||
self.packet_count = 0
|
||||
|
||||
def on_pdu(self, pdu: bytes) -> None:
|
||||
self.packets_received += 1
|
||||
|
||||
transaction_label = pdu[0] >> 4
|
||||
packet_type = Protocol.PacketType((pdu[0] >> 2) & 3)
|
||||
c_r = (pdu[0] >> 1) & 1
|
||||
ipid = pdu[0] & 1
|
||||
|
||||
if c_r == 0 and ipid != 0:
|
||||
logger.warning("invalid IPID in command frame")
|
||||
self.reset()
|
||||
return
|
||||
|
||||
pid_offset = 1
|
||||
if packet_type in (Protocol.PacketType.SINGLE, Protocol.PacketType.START):
|
||||
if self.transaction_label >= 0:
|
||||
# We are already in a transaction
|
||||
logger.warning("received START or SINGLE fragment while in transaction")
|
||||
self.reset()
|
||||
self.packets_received = 1
|
||||
|
||||
if packet_type == Protocol.PacketType.START:
|
||||
self.number_of_packets = pdu[1]
|
||||
pid_offset = 2
|
||||
|
||||
pid = struct.unpack_from(">H", pdu, pid_offset)[0]
|
||||
self.payload += pdu[pid_offset + 2 :]
|
||||
|
||||
if packet_type in (Protocol.PacketType.CONTINUE, Protocol.PacketType.END):
|
||||
if transaction_label != self.transaction_label:
|
||||
logger.warning("transaction label does not match")
|
||||
self.reset()
|
||||
return
|
||||
|
||||
if pid != self.pid:
|
||||
logger.warning("PID does not match")
|
||||
self.reset()
|
||||
return
|
||||
|
||||
if c_r != self.c_r:
|
||||
logger.warning("C/R does not match")
|
||||
self.reset()
|
||||
return
|
||||
|
||||
if self.packets_received > self.number_of_packets:
|
||||
logger.warning("too many fragments in transaction")
|
||||
self.reset()
|
||||
return
|
||||
|
||||
if packet_type == Protocol.PacketType.END:
|
||||
if self.packets_received != self.number_of_packets:
|
||||
logger.warning("premature END")
|
||||
self.reset()
|
||||
return
|
||||
else:
|
||||
self.transaction_label = transaction_label
|
||||
self.c_r = c_r
|
||||
self.ipid = ipid
|
||||
self.pid = pid
|
||||
|
||||
if packet_type in (Protocol.PacketType.SINGLE, Protocol.PacketType.END):
|
||||
self.on_message_complete()
|
||||
|
||||
def on_message_complete(self):
|
||||
try:
|
||||
self.callback(
|
||||
self.transaction_label,
|
||||
self.c_r == 0,
|
||||
self.ipid != 0,
|
||||
self.pid,
|
||||
self.payload,
|
||||
)
|
||||
except Exception as error:
|
||||
logger.exception(color(f"!!! exception in callback: {error}", "red"))
|
||||
|
||||
self.reset()
|
||||
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
class Protocol:
|
||||
CommandHandler = Callable[[int, avc.CommandFrame], None]
|
||||
command_handlers: Dict[int, CommandHandler] # Command handlers, by PID
|
||||
ResponseHandler = Callable[[int, Optional[avc.ResponseFrame]], None]
|
||||
response_handlers: Dict[int, ResponseHandler] # Response handlers, by PID
|
||||
next_transaction_label: int
|
||||
message_assembler: MessageAssembler
|
||||
|
||||
class PacketType(IntEnum):
|
||||
SINGLE = 0b00
|
||||
START = 0b01
|
||||
CONTINUE = 0b10
|
||||
END = 0b11
|
||||
|
||||
def __init__(self, l2cap_channel: l2cap.ClassicChannel) -> None:
|
||||
self.command_handlers = {}
|
||||
self.response_handlers = {}
|
||||
self.l2cap_channel = l2cap_channel
|
||||
self.message_assembler = MessageAssembler(self.on_message)
|
||||
|
||||
# Register to receive PDUs from the channel
|
||||
l2cap_channel.sink = self.on_pdu
|
||||
l2cap_channel.on("open", self.on_l2cap_channel_open)
|
||||
l2cap_channel.on("close", self.on_l2cap_channel_close)
|
||||
|
||||
def on_l2cap_channel_open(self):
|
||||
logger.debug(color("<<< AVCTP channel open", "magenta"))
|
||||
|
||||
def on_l2cap_channel_close(self):
|
||||
logger.debug(color("<<< AVCTP channel closed", "magenta"))
|
||||
|
||||
def on_pdu(self, pdu: bytes) -> None:
|
||||
self.message_assembler.on_pdu(pdu)
|
||||
|
||||
def on_message(
|
||||
self,
|
||||
transaction_label: int,
|
||||
is_command: bool,
|
||||
ipid: bool,
|
||||
pid: int,
|
||||
payload: bytes,
|
||||
) -> None:
|
||||
logger.debug(
|
||||
f"<<< AVCTP Message: pid={pid}, "
|
||||
f"transaction_label={transaction_label}, "
|
||||
f"is_command={is_command}, "
|
||||
f"ipid={ipid}, "
|
||||
f"payload={payload.hex()}"
|
||||
)
|
||||
|
||||
# Check for invalid PID responses.
|
||||
if ipid:
|
||||
logger.debug(f"received IPID for PID={pid}")
|
||||
|
||||
# Find the appropriate handler.
|
||||
if is_command:
|
||||
if pid not in self.command_handlers:
|
||||
logger.warning(f"no command handler for PID {pid}")
|
||||
self.send_ipid(transaction_label, pid)
|
||||
return
|
||||
|
||||
command_frame = cast(avc.CommandFrame, avc.Frame.from_bytes(payload))
|
||||
self.command_handlers[pid](transaction_label, command_frame)
|
||||
else:
|
||||
if pid not in self.response_handlers:
|
||||
logger.warning(f"no response handler for PID {pid}")
|
||||
return
|
||||
|
||||
# By convention, for an ipid, send a None payload to the response handler.
|
||||
if ipid:
|
||||
response_frame = None
|
||||
else:
|
||||
response_frame = cast(avc.ResponseFrame, avc.Frame.from_bytes(payload))
|
||||
|
||||
self.response_handlers[pid](transaction_label, response_frame)
|
||||
|
||||
def send_message(
|
||||
self,
|
||||
transaction_label: int,
|
||||
is_command: bool,
|
||||
ipid: bool,
|
||||
pid: int,
|
||||
payload: bytes,
|
||||
):
|
||||
# TODO: fragment large messages
|
||||
packet_type = Protocol.PacketType.SINGLE
|
||||
pdu = (
|
||||
struct.pack(
|
||||
">BH",
|
||||
transaction_label << 4
|
||||
| packet_type << 2
|
||||
| (0 if is_command else 1) << 1
|
||||
| (1 if ipid else 0),
|
||||
pid,
|
||||
)
|
||||
+ payload
|
||||
)
|
||||
self.l2cap_channel.send_pdu(pdu)
|
||||
|
||||
def send_command(self, transaction_label: int, pid: int, payload: bytes) -> None:
|
||||
logger.debug(
|
||||
">>> AVCTP command: "
|
||||
f"transaction_label={transaction_label}, "
|
||||
f"pid={pid}, "
|
||||
f"payload={payload.hex()}"
|
||||
)
|
||||
self.send_message(transaction_label, True, False, pid, payload)
|
||||
|
||||
def send_response(self, transaction_label: int, pid: int, payload: bytes):
|
||||
logger.debug(
|
||||
">>> AVCTP response: "
|
||||
f"transaction_label={transaction_label}, "
|
||||
f"pid={pid}, "
|
||||
f"payload={payload.hex()}"
|
||||
)
|
||||
self.send_message(transaction_label, False, False, pid, payload)
|
||||
|
||||
def send_ipid(self, transaction_label: int, pid: int) -> None:
|
||||
logger.debug(
|
||||
">>> AVCTP ipid: " f"transaction_label={transaction_label}, " f"pid={pid}"
|
||||
)
|
||||
self.send_message(transaction_label, False, True, pid, b'')
|
||||
|
||||
def register_command_handler(
|
||||
self, pid: int, handler: Protocol.CommandHandler
|
||||
) -> None:
|
||||
self.command_handlers[pid] = handler
|
||||
|
||||
def unregister_command_handler(
|
||||
self, pid: int, handler: Protocol.CommandHandler
|
||||
) -> None:
|
||||
if pid not in self.command_handlers or self.command_handlers[pid] != handler:
|
||||
raise ValueError("command handler not registered")
|
||||
del self.command_handlers[pid]
|
||||
|
||||
def register_response_handler(
|
||||
self, pid: int, handler: Protocol.ResponseHandler
|
||||
) -> None:
|
||||
self.response_handlers[pid] = handler
|
||||
|
||||
def unregister_response_handler(
|
||||
self, pid: int, handler: Protocol.ResponseHandler
|
||||
) -> None:
|
||||
if pid not in self.response_handlers or self.response_handlers[pid] != handler:
|
||||
raise ValueError("response handler not registered")
|
||||
del self.response_handlers[pid]
|
||||
Reference in New Issue
Block a user