Compare commits

...

9 Commits

Author SHA1 Message Date
Gilles Boccon-Gibod
aa9af61cbe improve exception messages 2023-03-20 12:14:28 -07:00
Gilles Boccon-Gibod
dc3ac3060e add auto-snooping for transports 2023-03-20 11:06:50 -07:00
Gilles Boccon-Gibod
e77723a5f9 Merge pull request #135 from google/gbg/snoop
add snoop support
2023-03-07 09:16:33 -08:00
Gilles Boccon-Gibod
fe8cf51432 Merge pull request #139 from google/gbg/hotfix-001
two small hotfixes
2023-03-07 09:16:15 -08:00
Gilles Boccon-Gibod
97a0e115ae two small hotfixes 2023-03-05 20:24:16 -08:00
Lucas Abel
46e7aac77c Merge pull request #138 from rahularya50/aryarahul/fix-att-perms
Add support for ATT permissions on server-side
2023-03-03 16:18:45 -08:00
Rahul Arya
08a6f4fa49 Add support for ATT permissions on server-side 2023-03-03 16:11:33 -08:00
Lucas Abel
ca063eda0b Merge pull request #132 from rahularya50/aryarahul/fix-uuid
Fix UUID byte-order in serialization
2023-03-03 15:48:50 -08:00
Rahul Arya
c97ba4319f Fix UUID byte-order in serialization 2023-03-03 22:38:21 +00:00
11 changed files with 356 additions and 68 deletions

View File

@@ -25,12 +25,14 @@
from __future__ import annotations
import struct
from pyee import EventEmitter
from typing import Dict, Type
from typing import Dict, Type, TYPE_CHECKING
from bumble.core import UUID, name_or_number
from bumble.hci import HCI_Object, key_with_value
from bumble.colors import color
if TYPE_CHECKING:
from bumble.device import Connection
# -----------------------------------------------------------------------------
# Constants
@@ -749,7 +751,25 @@ class Attribute(EventEmitter):
def decode_value(self, value_bytes):
return value_bytes
def read_value(self, connection):
def read_value(self, connection: Connection):
if (
self.permissions & self.READ_REQUIRES_ENCRYPTION
) and not connection.encryption:
raise ATT_Error(
error_code=ATT_INSUFFICIENT_ENCRYPTION_ERROR, att_handle=self.handle
)
if (
self.permissions & self.READ_REQUIRES_AUTHENTICATION
) and not connection.authenticated:
raise ATT_Error(
error_code=ATT_INSUFFICIENT_AUTHENTICATION_ERROR, att_handle=self.handle
)
if self.permissions & self.READ_REQUIRES_AUTHORIZATION:
# TODO: handle authorization better
raise ATT_Error(
error_code=ATT_INSUFFICIENT_AUTHORIZATION_ERROR, att_handle=self.handle
)
if read := getattr(self.value, 'read', None):
try:
value = read(connection) # pylint: disable=not-callable
@@ -762,7 +782,25 @@ class Attribute(EventEmitter):
return self.encode_value(value)
def write_value(self, connection, value_bytes):
def write_value(self, connection: Connection, value_bytes):
if (
self.permissions & self.WRITE_REQUIRES_ENCRYPTION
) and not connection.encryption:
raise ATT_Error(
error_code=ATT_INSUFFICIENT_ENCRYPTION_ERROR, att_handle=self.handle
)
if (
self.permissions & self.WRITE_REQUIRES_AUTHENTICATION
) and not connection.authenticated:
raise ATT_Error(
error_code=ATT_INSUFFICIENT_AUTHENTICATION_ERROR, att_handle=self.handle
)
if self.permissions & self.WRITE_REQUIRES_AUTHORIZATION:
# TODO: handle authorization better
raise ATT_Error(
error_code=ATT_INSUFFICIENT_AUTHORIZATION_ERROR, att_handle=self.handle
)
value = self.decode_value(value_bytes)
if write := getattr(self.value, 'write', None):

View File

@@ -144,9 +144,12 @@ class ConnectionError(BaseError): # pylint: disable=redefined-builtin
class UUID:
'''
See Bluetooth spec Vol 3, Part B - 2.5.1 UUID
Note that this class expects and works in little-endian byte-order throughout.
The exception is when interacting with strings, which are in big-endian byte-order.
'''
BASE_UUID = bytes.fromhex('00001000800000805F9B34FB')
BASE_UUID = bytes.fromhex('00001000800000805F9B34FB')[::-1] # little-endian
UUIDS: List[UUID] = [] # Registry of all instances created
def __init__(self, uuid_str_or_int, name=None):
@@ -209,13 +212,20 @@ class UUID:
return offset + 2, cls.from_bytes(uuid_as_bytes[offset : offset + 2])
def to_bytes(self, force_128=False):
if len(self.uuid_bytes) == 16 or not force_128:
'''
Serialize UUID in little-endian byte-order
'''
if not force_128:
return self.uuid_bytes
if len(self.uuid_bytes) == 4:
return self.uuid_bytes + UUID.BASE_UUID
return self.uuid_bytes + bytes([0, 0]) + UUID.BASE_UUID
if len(self.uuid_bytes) == 2:
return self.BASE_UUID + self.uuid_bytes + bytes([0, 0])
elif len(self.uuid_bytes) == 4:
return self.BASE_UUID + self.uuid_bytes
elif len(self.uuid_bytes) == 16:
return self.uuid_bytes
else:
assert False, "unreachable"
def to_pdu_bytes(self):
'''

View File

@@ -534,6 +534,9 @@ class Connection(CompositeEventEmitter):
def on_connection_parameters_update_failure(self, error):
pass
def on_connection_data_length_change(self):
pass
def on_connection_phy_update(self):
pass
@@ -2008,7 +2011,7 @@ class Device(CompositeEventEmitter):
NOTE: the name of the parameters may look odd, but it just follows the names
used in the Bluetooth spec.
'''
await self.send_command(
result = await self.send_command(
HCI_LE_Connection_Update_Command(
connection_handle=connection.handle,
connection_interval_min=connection_interval_min,
@@ -2017,9 +2020,10 @@ class Device(CompositeEventEmitter):
supervision_timeout=supervision_timeout,
min_ce_length=min_ce_length,
max_ce_length=max_ce_length,
),
check_result=True,
)
)
if result.status != HCI_Command_Status_Event.PENDING:
raise HCI_StatusError(result)
async def get_connection_rssi(self, connection):
result = await self.send_command(

View File

@@ -61,7 +61,6 @@ from .att import (
from .gatt import (
GATT_CHARACTERISTIC_ATTRIBUTE_TYPE,
GATT_CLIENT_CHARACTERISTIC_CONFIGURATION_DESCRIPTOR,
GATT_INCLUDE_ATTRIBUTE_TYPE,
GATT_MAX_ATTRIBUTE_VALUE_SIZE,
GATT_PRIMARY_SERVICE_ATTRIBUTE_TYPE,
GATT_REQUEST_TIMEOUT,
@@ -543,8 +542,6 @@ class Server(EventEmitter):
if attribute.handle >= request.starting_handle
and attribute.handle <= request.ending_handle
):
# TODO: check permissions
this_uuid_size = len(attribute.type.to_pdu_bytes())
if attributes:
@@ -638,6 +635,13 @@ class Server(EventEmitter):
'''
pdu_space_available = connection.att_mtu - 2
response = ATT_Error_Response(
request_opcode_in_error=request.op_code,
attribute_handle_in_error=request.starting_handle,
error_code=ATT_ATTRIBUTE_NOT_FOUND_ERROR,
)
attributes = []
for attribute in (
attribute
@@ -647,10 +651,21 @@ class Server(EventEmitter):
and attribute.handle <= request.ending_handle
and pdu_space_available
):
# TODO: check permissions
try:
attribute_value = attribute.read_value(connection)
except ATT_Error as error:
# If the first attribute is unreadable, return an error
# Otherwise return attributes up to this point
if not attributes:
response = ATT_Error_Response(
request_opcode_in_error=request.op_code,
attribute_handle_in_error=attribute.handle,
error_code=error.error_code,
)
break
# Check the attribute value size
attribute_value = attribute.read_value(connection)
max_attribute_size = min(connection.att_mtu - 4, 253)
if len(attribute_value) > max_attribute_size:
# We need to truncate
@@ -676,11 +691,7 @@ class Server(EventEmitter):
length=entry_size, attribute_data_list=b''.join(attribute_data_list)
)
else:
response = ATT_Error_Response(
request_opcode_in_error=request.op_code,
attribute_handle_in_error=request.starting_handle,
error_code=ATT_ATTRIBUTE_NOT_FOUND_ERROR,
)
logging.warning(f"not found {request}")
self.send_response(connection, response)
@@ -690,10 +701,17 @@ class Server(EventEmitter):
'''
if attribute := self.get_attribute(request.attribute_handle):
# TODO: check permissions
value = attribute.read_value(connection)
value_size = min(connection.att_mtu - 1, len(value))
response = ATT_Read_Response(attribute_value=value[:value_size])
try:
value = attribute.read_value(connection)
except ATT_Error as error:
response = ATT_Error_Response(
request_opcode_in_error=request.op_code,
attribute_handle_in_error=request.attribute_handle,
error_code=error.error_code,
)
else:
value_size = min(connection.att_mtu - 1, len(value))
response = ATT_Read_Response(attribute_value=value[:value_size])
else:
response = ATT_Error_Response(
request_opcode_in_error=request.op_code,
@@ -708,29 +726,36 @@ class Server(EventEmitter):
'''
if attribute := self.get_attribute(request.attribute_handle):
# TODO: check permissions
value = attribute.read_value(connection)
if request.value_offset > len(value):
try:
value = attribute.read_value(connection)
except ATT_Error as error:
response = ATT_Error_Response(
request_opcode_in_error=request.op_code,
attribute_handle_in_error=request.attribute_handle,
error_code=ATT_INVALID_OFFSET_ERROR,
)
elif len(value) <= connection.att_mtu - 1:
response = ATT_Error_Response(
request_opcode_in_error=request.op_code,
attribute_handle_in_error=request.attribute_handle,
error_code=ATT_ATTRIBUTE_NOT_LONG_ERROR,
error_code=error.error_code,
)
else:
part_size = min(
connection.att_mtu - 1, len(value) - request.value_offset
)
response = ATT_Read_Blob_Response(
part_attribute_value=value[
request.value_offset : request.value_offset + part_size
]
)
if request.value_offset > len(value):
response = ATT_Error_Response(
request_opcode_in_error=request.op_code,
attribute_handle_in_error=request.attribute_handle,
error_code=ATT_INVALID_OFFSET_ERROR,
)
elif len(value) <= connection.att_mtu - 1:
response = ATT_Error_Response(
request_opcode_in_error=request.op_code,
attribute_handle_in_error=request.attribute_handle,
error_code=ATT_ATTRIBUTE_NOT_LONG_ERROR,
)
else:
part_size = min(
connection.att_mtu - 1, len(value) - request.value_offset
)
response = ATT_Read_Blob_Response(
part_attribute_value=value[
request.value_offset : request.value_offset + part_size
]
)
else:
response = ATT_Error_Response(
request_opcode_in_error=request.op_code,
@@ -746,7 +771,6 @@ class Server(EventEmitter):
if request.attribute_group_type not in (
GATT_PRIMARY_SERVICE_ATTRIBUTE_TYPE,
GATT_SECONDARY_SERVICE_ATTRIBUTE_TYPE,
GATT_INCLUDE_ATTRIBUTE_TYPE,
):
response = ATT_Error_Response(
request_opcode_in_error=request.op_code,
@@ -766,8 +790,10 @@ class Server(EventEmitter):
and attribute.handle <= request.ending_handle
and pdu_space_available
):
# Check the attribute value size
# No need to catch permission errors here, since these attributes
# must all be world-readable
attribute_value = attribute.read_value(connection)
# Check the attribute value size
max_attribute_size = min(connection.att_mtu - 6, 251)
if len(attribute_value) > max_attribute_size:
# We need to truncate

View File

@@ -1421,7 +1421,11 @@ class HCI_Constant:
# -----------------------------------------------------------------------------
class HCI_Error(ProtocolError):
def __init__(self, error_code):
super().__init__(error_code, 'hci', HCI_Constant.error_name(error_code))
super().__init__(
error_code,
error_namespace='hci',
error_name=HCI_Constant.error_name(error_code),
)
# -----------------------------------------------------------------------------

View File

@@ -276,7 +276,7 @@ class Host(AbortableEventEmitter):
def send_hci_packet(self, packet):
if self.snooper:
self.snooper.snoop(packet, Snooper.Direction.HOST_TO_CONTROLLER)
self.snooper.snoop(bytes(packet), Snooper.Direction.HOST_TO_CONTROLLER)
self.hci_sink.on_packet(packet.to_bytes())
@@ -425,7 +425,7 @@ class Host(AbortableEventEmitter):
logger.debug(f'{color("### CONTROLLER -> HOST", "green")}: {packet}')
if self.snooper:
self.snooper.snoop(packet, Snooper.Direction.CONTROLLER_TO_HOST)
self.snooper.snoop(bytes(packet), Snooper.Direction.CONTROLLER_TO_HOST)
# If the packet is a command, invoke the handler for this packet
if packet.hci_packet_type == HCI_COMMAND_PACKET:

View File

@@ -15,12 +15,21 @@
# -----------------------------------------------------------------------------
# Imports
# -----------------------------------------------------------------------------
from contextlib import contextmanager
from enum import IntEnum
import logging
import struct
import datetime
from typing import BinaryIO
from typing import BinaryIO, Generator
import os
from bumble.hci import HCI_Packet, HCI_COMMAND_PACKET, HCI_EVENT_PACKET
from bumble.hci import HCI_COMMAND_PACKET, HCI_EVENT_PACKET
# -----------------------------------------------------------------------------
# Logging
# -----------------------------------------------------------------------------
logger = logging.getLogger(__name__)
# -----------------------------------------------------------------------------
@@ -44,7 +53,7 @@ class Snooper:
HCI_BSCP = 1003
H5 = 1004
def snoop(self, hci_packet: HCI_Packet, direction: Direction) -> None:
def snoop(self, hci_packet: bytes, direction: Direction) -> None:
"""Snoop on an HCI packet."""
@@ -67,9 +76,10 @@ class BtSnooper(Snooper):
self.IDENTIFICATION_PATTERN + struct.pack('>LL', 1, self.DataLinkType.H4)
)
def snoop(self, hci_packet: HCI_Packet, direction: Snooper.Direction) -> None:
def snoop(self, hci_packet: bytes, direction: Snooper.Direction) -> None:
flags = int(direction)
if hci_packet.hci_packet_type in (HCI_EVENT_PACKET, HCI_COMMAND_PACKET):
packet_type = hci_packet[0]
if packet_type in (HCI_EVENT_PACKET, HCI_COMMAND_PACKET):
flags |= 0x10
# Compute the current timestamp
@@ -79,15 +89,82 @@ class BtSnooper(Snooper):
)
# Emit the record
packet_data = bytes(hci_packet)
self.output.write(
struct.pack(
'>IIIIQ',
len(packet_data), # Original Length
len(packet_data), # Included Length
len(hci_packet), # Original Length
len(hci_packet), # Included Length
flags, # Packet Flags
0, # Cumulative Drops
timestamp, # Timestamp
)
+ packet_data
+ hci_packet
)
# -----------------------------------------------------------------------------
_SNOOPER_INSTANCE_COUNT = 0
@contextmanager
def create_snooper(spec: str) -> Generator[Snooper, None, None]:
"""
Create a snooper given a specification string.
The general syntax for the specification string is:
<snooper-type>:<type-specific-arguments>
Supported snooper types are:
btsnoop
The syntax for the type-specific arguments for this type is:
<io-type>:<io-type-specific-arguments>
Supported I/O types are:
file
The type-specific arguments for this I/O type is a string that is converted
to a file path using the python `str.format()` string formatting. The log
records will be written to that file if it can be opened/created.
The keyword args that may be referenced by the string pattern are:
now: the value of `datetime.now()`
utcnow: the value of `datetime.utcnow()`
pid: the current process ID.
instance: the instance ID in the current process.
Examples:
btsnoop:file:my_btsnoop.log
btsnoop:file:/tmp/bumble_{now:%Y-%m-%d-%H:%M:%S}_{pid}.log
"""
if ':' not in spec:
raise ValueError('snooper type prefix missing')
snooper_type, snooper_args = spec.split(':', maxsplit=1)
if snooper_type == 'btsnoop':
if ':' not in snooper_args:
raise ValueError('I/O type for btsnoop snooper type missing')
io_type, io_name = snooper_args.split(':', maxsplit=1)
if io_type == 'file':
# Process the file name string pattern.
global _SNOOPER_INSTANCE_COUNT
file_path = io_name.format(
now=datetime.datetime.now(),
utcnow=datetime.datetime.utcnow(),
pid=os.getpid(),
instance=_SNOOPER_INSTANCE_COUNT,
)
# Open the file
logger.debug(f'Snoop file: {file_path}')
with open(file_path, 'wb') as snoop_file:
_SNOOPER_INSTANCE_COUNT += 1
yield BtSnooper(snoop_file)
_SNOOPER_INSTANCE_COUNT -= 1
return
raise ValueError(f'I/O type {io_type} not supported')
raise ValueError(f'snooper type {snooper_type} not found')

View File

@@ -1,4 +1,4 @@
# Copyright 2021-2022 Google LLC
# Copyright 2021-2023 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -15,10 +15,13 @@
# -----------------------------------------------------------------------------
# Imports
# -----------------------------------------------------------------------------
from contextlib import asynccontextmanager
import logging
import os
from .common import Transport, AsyncPipeSink
from .common import Transport, AsyncPipeSink, SnoopingTransport
from ..controller import Controller
from ..snoop import create_snooper
# -----------------------------------------------------------------------------
# Logging
@@ -26,14 +29,53 @@ from ..controller import Controller
logger = logging.getLogger(__name__)
# -----------------------------------------------------------------------------
def _wrap_transport(transport: Transport) -> Transport:
"""
Automatically wrap a Transport instance when a wrapping class can be inferred
from the environment.
If no wrapping class is applicable, the transport argument is returned as-is.
"""
# If BUMBLE_SNOOPER is set, try to automatically create a snooper.
if snooper_spec := os.getenv('BUMBLE_SNOOPER'):
try:
return SnoopingTransport.create_with(
transport, create_snooper(snooper_spec)
)
except Exception as exc:
logger.warning(f'Exception while creating snooper: {exc}')
return transport
# -----------------------------------------------------------------------------
async def open_transport(name: str) -> Transport:
'''
"""
Open a transport by name.
The name must be <type>:<parameters>
Where <parameters> depend on the type (and may be empty for some types).
The supported types are: serial,udp,tcp,pty,usb
'''
The supported types are:
* serial
* udp
* tcp-client
* tcp-server
* ws-client
* ws-server
* pty
* file
* vhci
* hci-socket
* usb
* pyusb
* android-emulator
"""
return _wrap_transport(await _open_transport(name))
# -----------------------------------------------------------------------------
async def _open_transport(name: str) -> Transport:
# pylint: disable=import-outside-toplevel
# pylint: disable=too-many-return-statements
@@ -107,7 +149,18 @@ async def open_transport(name: str) -> Transport:
# -----------------------------------------------------------------------------
async def open_transport_or_link(name):
async def open_transport_or_link(name: str) -> Transport:
"""
Open a transport or a link relay.
Args:
name:
Name of the transport or link relay to open.
When the name starts with "link-relay:", open a link relay (see RemoteLink
for details on what the arguments are).
For other namespaces, see `open_transport`.
"""
if name.startswith('link-relay:'):
from ..link import RemoteLink # lazy import
@@ -119,6 +172,6 @@ async def open_transport_or_link(name):
async def close(self):
link.close()
return LinkTransport(controller, AsyncPipeSink(controller))
return _wrap_transport(LinkTransport(controller, AsyncPipeSink(controller)))
return await open_transport(name)

View File

@@ -15,12 +15,16 @@
# -----------------------------------------------------------------------------
# Imports
# -----------------------------------------------------------------------------
from __future__ import annotations
import contextlib
import struct
import asyncio
import logging
from typing import ContextManager
from .. import hci
from ..colors import color
from ..snoop import Snooper
# -----------------------------------------------------------------------------
@@ -246,6 +250,20 @@ class StreamPacketSink:
# -----------------------------------------------------------------------------
class Transport:
"""
Base class for all transports.
A Transport represents a source and a sink together.
An instance must be closed by calling close() when no longer used. Instances
implement the ContextManager protocol so that they may be used in a `async with`
statement.
An instance is iterable. The iterator yields, in order, its source and sink, so
that it may be used with a convenient call syntax like:
async with create_transport() as (source, sink):
...
"""
def __init__(self, source, sink):
self.source = source
self.sink = sink
@@ -335,3 +353,60 @@ class PumpedTransport(Transport):
async def close(self):
await super().close()
await self.close_function()
# -----------------------------------------------------------------------------
class SnoopingTransport(Transport):
"""Transport wrapper that snoops on packets to/from a wrapped transport."""
@staticmethod
def create_with(
transport: Transport, snooper: ContextManager[Snooper]
) -> SnoopingTransport:
"""
Create an instance given a snooper that works as as context manager.
The returned instance will exit the snooper context when it is closed.
"""
with contextlib.ExitStack() as exit_stack:
return SnoopingTransport(
transport, exit_stack.enter_context(snooper), exit_stack.pop_all().close
)
raise RuntimeError('unexpected code path') # Satisfy the type checker
class Source:
def __init__(self, source, snooper):
self.source = source
self.snooper = snooper
self.sink = None
def set_packet_sink(self, sink):
self.sink = sink
self.source.set_packet_sink(self)
def on_packet(self, packet):
self.snooper.snoop(packet, Snooper.Direction.CONTROLLER_TO_HOST)
if self.sink:
self.sink.on_packet(packet)
class Sink:
def __init__(self, sink, snooper):
self.sink = sink
self.snooper = snooper
def on_packet(self, packet):
self.snooper.snoop(packet, Snooper.Direction.HOST_TO_CONTROLLER)
if self.sink:
self.sink.on_packet(packet)
def __init__(self, transport, snooper, close_snooper=None):
super().__init__(
self.Source(transport.source, snooper), self.Sink(transport.sink, snooper)
)
self.transport = transport
self.close_snooper = close_snooper
async def close(self):
await self.transport.close()
if self.close_snooper:
self.close_snooper()

View File

@@ -72,7 +72,7 @@ test =
development =
black == 22.10
invoke >= 1.7.3
mypy == 0.991
mypy == 1.1.1
nox >= 2022
pylint == 2.15.8
types-appdirs >= 1.4.3

View File

@@ -72,5 +72,6 @@ def test_parser_extensions():
# -----------------------------------------------------------------------------
test_parser()
test_parser_extensions()
if __name__ == '__main__':
test_parser()
test_parser_extensions()