Compare commits

..

15 Commits

Author SHA1 Message Date
Gilles Boccon-Gibod
fe28473ba8 Merge pull request #234 from zxzxwu/addr
Support address resolution offload
2023-08-08 21:30:13 -07:00
Gilles Boccon-Gibod
53d66bc74a Merge pull request #237 from marshallpierce/mp/company-ids
Faster company id table
2023-08-08 21:29:45 -07:00
Marshall Pierce
e2c1ad5342 Faster company id table
Following up on the [loose end from the initial
PR](https://github.com/google/bumble/pull/207#discussion_r1278015116),
we can avoid accessing the Python company id map at runtime by doing
code gen ahead of time.

Using an example to do the code gen avoids even the small build slowdown
from invoking the code gen logic in build.rs, but more importantly,
means that it's still a totally boring normal build that won't require
any IDE setup, etc, to work for everyone. Since the company ID list
changes rarely, and there's a test to ensure it always matches, this
seems like a good trade.
2023-08-04 10:12:52 -06:00
Josh Wu
6399c5fb04 Auto add device to resolving list after pairing 2023-08-03 20:51:00 +08:00
Josh Wu
784cf4f26a Add a flag to enable LE address resolution 2023-08-03 20:50:57 +08:00
Josh Wu
0301b1a999 Pandora: Configure identity address type 2023-08-02 11:31:07 -07:00
Lucas Abel
3ab2cd5e71 pandora: decrease all info logs to debug 2023-08-02 10:56:41 -07:00
uael
6ea669531a pandora: add tcp option to transport configuration
* Add a fallback to `tcp` when `transport` is not set.
* Default the `tcp` transport to the default rootcanal HCI address.
2023-08-01 08:51:12 -07:00
Josh Wu
cbbada4748 SMP: Delegate distributed address type 2023-08-01 08:38:03 -07:00
Gilles Boccon-Gibod
152b8d1233 Merge pull request #230 from google/gbg/hci-object-array
add support for field arrays in hci packet definitions
2023-08-01 07:44:31 -07:00
Gilles Boccon-Gibod
bdad225033 add support for field arrays in hci packet definitions 2023-07-30 22:19:10 -07:00
Gilles Boccon-Gibod
8eeb58e467 Merge pull request #207 from marshallpierce/mp/rust-poc
Proof-of-concept Rust wrapper
2023-07-28 20:14:23 -07:00
Marshall Pierce
91971433d2 PR feedback 2023-07-28 14:34:02 -06:00
Gilles Boccon-Gibod
a0a4bd457f Merge pull request #227 from google/gbg/py11
compatibility with python 11
2023-07-28 12:54:30 -07:00
Marshall Pierce
afb21220e2 Proof-of-concept Rust wrapper
This contains Rust wrappers around enough of the Python API to implement Rust versions of the `battery_client` and `run_scanner` examples. The goal is to gather feedback on the approach, and of course to show that it is possible.

The module structure mirrors that of the Python. The Rust API is not optimally Rust-y, but given the constraints of everything having to delegate to Python, it's at least usable.

Notably, this does not yet solve the packaging problem: users must have an appropriate virtualenv, libpython, etc. [PyOxidizer](https://github.com/indygreg/PyOxidizer) may be a viable path there.
2023-07-20 10:50:15 -06:00
38 changed files with 6836 additions and 332 deletions

View File

@@ -41,3 +41,30 @@ jobs:
run: |
inv build
inv build.mkdocs
build-rust:
runs-on: ubuntu-latest
strategy:
matrix:
python-version: [ "3.8", "3.9", "3.10" ]
fail-fast: false
steps:
- name: Check out from Git
uses: actions/checkout@v3
- name: Set up Python ${{ matrix.python-version }}
uses: actions/setup-python@v4
with:
python-version: ${{ matrix.python-version }}
- name: Install dependencies
run: |
python -m pip install --upgrade pip
python -m pip install ".[build,test,development,documentation]"
- name: Install Rust toolchain
uses: actions-rust-lang/setup-rust-toolchain@v1
with:
components: clippy,rustfmt
- name: Rust Lints
run: cd rust && cargo fmt --check && cargo clippy --all-targets -- --deny warnings
- name: Rust Build
run: cd rust && cargo build --all-targets
- name: Rust Tests
run: cd rust && cargo test

1
.gitignore vendored
View File

@@ -9,3 +9,4 @@ __pycache__
# generated by setuptools_scm
bumble/_version.py
.vscode/launch.json
/.idea

View File

@@ -86,6 +86,7 @@ from .hci import (
HCI_LE_Extended_Create_Connection_Command,
HCI_LE_Rand_Command,
HCI_LE_Read_PHY_Command,
HCI_LE_Set_Address_Resolution_Enable_Command,
HCI_LE_Set_Advertising_Data_Command,
HCI_LE_Set_Advertising_Enable_Command,
HCI_LE_Set_Advertising_Parameters_Command,
@@ -778,6 +779,7 @@ class DeviceConfiguration:
self.irk = bytes(16) # This really must be changed for any level of security
self.keystore = None
self.gatt_services: List[Dict[str, Any]] = []
self.address_resolution_offload = False
def load_from_dict(self, config: Dict[str, Any]) -> None:
# Load simple properties
@@ -1029,6 +1031,7 @@ class Device(CompositeEventEmitter):
self.discoverable = config.discoverable
self.connectable = config.connectable
self.classic_accept_any = config.classic_accept_any
self.address_resolution_offload = config.address_resolution_offload
for service in config.gatt_services:
characteristics = []
@@ -1256,31 +1259,16 @@ class Device(CompositeEventEmitter):
)
# Load the address resolving list
if self.keystore and self.host.supports_command(
HCI_LE_CLEAR_RESOLVING_LIST_COMMAND
):
await self.send_command(HCI_LE_Clear_Resolving_List_Command()) # type: ignore[call-arg]
if self.keystore:
await self.refresh_resolving_list()
resolving_keys = await self.keystore.get_resolving_keys()
for irk, address in resolving_keys:
await self.send_command(
HCI_LE_Add_Device_To_Resolving_List_Command(
peer_identity_address_type=address.address_type,
peer_identity_address=address,
peer_irk=irk,
local_irk=self.irk,
) # type: ignore[call-arg]
)
# Enable address resolution
# await self.send_command(
# HCI_LE_Set_Address_Resolution_Enable_Command(
# address_resolution_enable=1)
# )
# )
# Create a host-side address resolver
self.address_resolver = smp.AddressResolver(resolving_keys)
# Enable address resolution
if self.address_resolution_offload:
await self.send_command(
HCI_LE_Set_Address_Resolution_Enable_Command(
address_resolution_enable=1
) # type: ignore[call-arg]
)
if self.classic_enabled:
await self.send_command(
@@ -1310,6 +1298,26 @@ class Device(CompositeEventEmitter):
await self.host.flush()
self.powered_on = False
async def refresh_resolving_list(self) -> None:
assert self.keystore is not None
resolving_keys = await self.keystore.get_resolving_keys()
# Create a host-side address resolver
self.address_resolver = smp.AddressResolver(resolving_keys)
if self.address_resolution_offload:
await self.send_command(HCI_LE_Clear_Resolving_List_Command()) # type: ignore[call-arg]
for irk, address in resolving_keys:
await self.send_command(
HCI_LE_Add_Device_To_Resolving_List_Command(
peer_identity_address_type=address.address_type,
peer_identity_address=address,
peer_irk=irk,
local_irk=self.irk,
) # type: ignore[call-arg]
)
def supports_le_feature(self, feature):
return self.host.supports_le_feature(feature)

View File

@@ -1445,8 +1445,14 @@ class HCI_Object:
@staticmethod
def init_from_fields(hci_object, fields, values):
if isinstance(values, dict):
for field_name, _ in fields:
setattr(hci_object, field_name, values[field_name])
for field in fields:
if isinstance(field, list):
# The field is an array, up-level the array field names
for sub_field_name, _ in field:
setattr(hci_object, sub_field_name, values[sub_field_name])
else:
field_name = field[0]
setattr(hci_object, field_name, values[field_name])
else:
for field_name, field_value in zip(fields, values):
setattr(hci_object, field_name, field_value)
@@ -1456,133 +1462,161 @@ class HCI_Object:
parsed = HCI_Object.dict_from_bytes(data, offset, fields)
HCI_Object.init_from_fields(hci_object, parsed.keys(), parsed.values())
@staticmethod
def parse_field(data, offset, field_type):
# The field_type may be a dictionary with a mapper, parser, and/or size
if isinstance(field_type, dict):
if 'size' in field_type:
field_type = field_type['size']
elif 'parser' in field_type:
field_type = field_type['parser']
# Parse the field
if field_type == '*':
# The rest of the bytes
field_value = data[offset:]
return (field_value, len(field_value))
if field_type == 1:
# 8-bit unsigned
return (data[offset], 1)
if field_type == -1:
# 8-bit signed
return (struct.unpack_from('b', data, offset)[0], 1)
if field_type == 2:
# 16-bit unsigned
return (struct.unpack_from('<H', data, offset)[0], 2)
if field_type == '>2':
# 16-bit unsigned big-endian
return (struct.unpack_from('>H', data, offset)[0], 2)
if field_type == -2:
# 16-bit signed
return (struct.unpack_from('<h', data, offset)[0], 2)
if field_type == 3:
# 24-bit unsigned
padded = data[offset : offset + 3] + bytes([0])
return (struct.unpack('<I', padded)[0], 3)
if field_type == 4:
# 32-bit unsigned
return (struct.unpack_from('<I', data, offset)[0], 4)
if field_type == '>4':
# 32-bit unsigned big-endian
return (struct.unpack_from('>I', data, offset)[0], 4)
if isinstance(field_type, int) and 4 < field_type <= 256:
# Byte array (from 5 up to 256 bytes)
return (data[offset : offset + field_type], field_type)
if callable(field_type):
new_offset, field_value = field_type(data, offset)
return (field_value, new_offset - offset)
raise ValueError(f'unknown field type {field_type}')
@staticmethod
def dict_from_bytes(data, offset, fields):
result = collections.OrderedDict()
for (field_name, field_type) in fields:
# The field_type may be a dictionary with a mapper, parser, and/or size
if isinstance(field_type, dict):
if 'size' in field_type:
field_type = field_type['size']
elif 'parser' in field_type:
field_type = field_type['parser']
# Parse the field
if field_type == '*':
# The rest of the bytes
field_value = data[offset:]
offset += len(field_value)
elif field_type == 1:
# 8-bit unsigned
field_value = data[offset]
for field in fields:
if isinstance(field, list):
# This is an array field, starting with a 1-byte item count.
item_count = data[offset]
offset += 1
elif field_type == -1:
# 8-bit signed
field_value = struct.unpack_from('b', data, offset)[0]
offset += 1
elif field_type == 2:
# 16-bit unsigned
field_value = struct.unpack_from('<H', data, offset)[0]
offset += 2
elif field_type == '>2':
# 16-bit unsigned big-endian
field_value = struct.unpack_from('>H', data, offset)[0]
offset += 2
elif field_type == -2:
# 16-bit signed
field_value = struct.unpack_from('<h', data, offset)[0]
offset += 2
elif field_type == 3:
# 24-bit unsigned
padded = data[offset : offset + 3] + bytes([0])
field_value = struct.unpack('<I', padded)[0]
offset += 3
elif field_type == 4:
# 32-bit unsigned
field_value = struct.unpack_from('<I', data, offset)[0]
offset += 4
elif field_type == '>4':
# 32-bit unsigned big-endian
field_value = struct.unpack_from('>I', data, offset)[0]
offset += 4
elif isinstance(field_type, int) and 4 < field_type <= 256:
# Byte array (from 5 up to 256 bytes)
field_value = data[offset : offset + field_type]
offset += field_type
elif callable(field_type):
offset, field_value = field_type(data, offset)
else:
raise ValueError(f'unknown field type {field_type}')
for _ in range(item_count):
for sub_field_name, sub_field_type in field:
value, size = HCI_Object.parse_field(
data, offset, sub_field_type
)
result.setdefault(sub_field_name, []).append(value)
offset += size
continue
field_name, field_type = field
field_value, field_size = HCI_Object.parse_field(data, offset, field_type)
result[field_name] = field_value
offset += field_size
return result
@staticmethod
def serialize_field(field_value, field_type):
# The field_type may be a dictionary with a mapper, parser, serializer,
# and/or size
serializer = None
if isinstance(field_type, dict):
if 'serializer' in field_type:
serializer = field_type['serializer']
if 'size' in field_type:
field_type = field_type['size']
# Serialize the field
if serializer:
field_bytes = serializer(field_value)
elif field_type == 1:
# 8-bit unsigned
field_bytes = bytes([field_value])
elif field_type == -1:
# 8-bit signed
field_bytes = struct.pack('b', field_value)
elif field_type == 2:
# 16-bit unsigned
field_bytes = struct.pack('<H', field_value)
elif field_type == '>2':
# 16-bit unsigned big-endian
field_bytes = struct.pack('>H', field_value)
elif field_type == -2:
# 16-bit signed
field_bytes = struct.pack('<h', field_value)
elif field_type == 3:
# 24-bit unsigned
field_bytes = struct.pack('<I', field_value)[0:3]
elif field_type == 4:
# 32-bit unsigned
field_bytes = struct.pack('<I', field_value)
elif field_type == '>4':
# 32-bit unsigned big-endian
field_bytes = struct.pack('>I', field_value)
elif field_type == '*':
if isinstance(field_value, int):
if 0 <= field_value <= 255:
field_bytes = bytes([field_value])
else:
raise ValueError('value too large for *-typed field')
else:
field_bytes = bytes(field_value)
elif isinstance(field_value, (bytes, bytearray)) or hasattr(
field_value, 'to_bytes'
):
field_bytes = bytes(field_value)
if isinstance(field_type, int) and 4 < field_type <= 256:
# Truncate or pad with zeros if the field is too long or too short
if len(field_bytes) < field_type:
field_bytes += bytes(field_type - len(field_bytes))
elif len(field_bytes) > field_type:
field_bytes = field_bytes[:field_type]
else:
raise ValueError(f"don't know how to serialize type {type(field_value)}")
return field_bytes
@staticmethod
def dict_to_bytes(hci_object, fields):
result = bytearray()
for (field_name, field_type) in fields:
# The field_type may be a dictionary with a mapper, parser, serializer,
# and/or size
serializer = None
if isinstance(field_type, dict):
if 'serializer' in field_type:
serializer = field_type['serializer']
if 'size' in field_type:
field_type = field_type['size']
# Serialize the field
field_value = hci_object[field_name]
if serializer:
field_bytes = serializer(field_value)
elif field_type == 1:
# 8-bit unsigned
field_bytes = bytes([field_value])
elif field_type == -1:
# 8-bit signed
field_bytes = struct.pack('b', field_value)
elif field_type == 2:
# 16-bit unsigned
field_bytes = struct.pack('<H', field_value)
elif field_type == '>2':
# 16-bit unsigned big-endian
field_bytes = struct.pack('>H', field_value)
elif field_type == -2:
# 16-bit signed
field_bytes = struct.pack('<h', field_value)
elif field_type == 3:
# 24-bit unsigned
field_bytes = struct.pack('<I', field_value)[0:3]
elif field_type == 4:
# 32-bit unsigned
field_bytes = struct.pack('<I', field_value)
elif field_type == '>4':
# 32-bit unsigned big-endian
field_bytes = struct.pack('>I', field_value)
elif field_type == '*':
if isinstance(field_value, int):
if 0 <= field_value <= 255:
field_bytes = bytes([field_value])
else:
raise ValueError('value too large for *-typed field')
else:
field_bytes = bytes(field_value)
elif isinstance(field_value, (bytes, bytearray)) or hasattr(
field_value, 'to_bytes'
):
field_bytes = bytes(field_value)
if isinstance(field_type, int) and 4 < field_type <= 256:
# Truncate or Pad with zeros if the field is too long or too short
if len(field_bytes) < field_type:
field_bytes += bytes(field_type - len(field_bytes))
elif len(field_bytes) > field_type:
field_bytes = field_bytes[:field_type]
else:
raise ValueError(
f"don't know how to serialize type {type(field_value)}"
for field in fields:
if isinstance(field, list):
# The field is an array. The serialized form starts with a 1-byte
# item count. We use the length of the first array field as the
# array count, since all array fields have the same number of items.
item_count = len(hci_object[field[0][0]])
result += bytes([item_count]) + b''.join(
b''.join(
HCI_Object.serialize_field(
hci_object[sub_field_name][i], sub_field_type
)
for sub_field_name, sub_field_type in field
)
for i in range(item_count)
)
continue
result += field_bytes
(field_name, field_type) = field
result += HCI_Object.serialize_field(hci_object[field_name], field_type)
return bytes(result)
@@ -1617,48 +1651,73 @@ class HCI_Object:
return str(value)
@staticmethod
def format_fields(hci_object, keys, indentation='', value_mappers=None):
if not keys:
return ''
def stringify_field(
field_name, field_type, field_value, indentation, value_mappers
):
value_mapper = None
if isinstance(field_type, dict):
# Get the value mapper from the specifier
value_mapper = field_type.get('mapper')
# Measure the widest field name
max_field_name_length = max(
(len(key[0] if isinstance(key, tuple) else key) for key in keys)
# Check if there's a matching mapper passed
if value_mappers:
value_mapper = value_mappers.get(field_name, value_mapper)
# Map the value if we have a mapper
if value_mapper is not None:
field_value = value_mapper(field_value)
# Get the string representation of the value
return HCI_Object.format_field_value(
field_value, indentation=indentation + ' '
)
@staticmethod
def format_fields(hci_object, fields, indentation='', value_mappers=None):
if not fields:
return ''
# Build array of formatted key:value pairs
fields = []
for key in keys:
value_mapper = None
if isinstance(key, tuple):
# The key has an associated specifier
key, specifier = key
field_strings = []
for field in fields:
if isinstance(field, list):
for sub_field in field:
sub_field_name, sub_field_type = sub_field
item_count = len(hci_object[sub_field_name])
for i in range(item_count):
field_strings.append(
(
f'{sub_field_name}[{i}]',
HCI_Object.stringify_field(
sub_field_name,
sub_field_type,
hci_object[sub_field_name][i],
indentation,
value_mappers,
),
),
)
continue
# Get the value mapper from the specifier
if isinstance(specifier, dict):
value_mapper = specifier.get('mapper')
# Get the value for the field
value = hci_object[key]
# Check if there's a matching mapper passed
if value_mappers:
value_mapper = value_mappers.get(key, value_mapper)
# Map the value if we have a mapper
if value_mapper is not None:
value = value_mapper(value)
# Get the string representation of the value
value_str = HCI_Object.format_field_value(
value, indentation=indentation + ' '
field_name, field_type = field
field_value = hci_object[field_name]
field_strings.append(
(
field_name,
HCI_Object.stringify_field(
field_name, field_type, field_value, indentation, value_mappers
),
),
)
# Add the field to the formatted result
key_str = color(f'{key + ":":{1 + max_field_name_length}}', 'cyan')
fields.append(f'{indentation}{key_str} {value_str}')
return '\n'.join(fields)
# Measure the widest field name
max_field_name_length = max(len(s[0]) for s in field_strings)
sep = ':'
return '\n'.join(
f'{indentation}'
f'{color(f"{field_name + sep:{1 + max_field_name_length}}", "cyan")} {field_value}'
for field_name, field_value in field_strings
)
def __bytes__(self):
return self.to_bytes()
@@ -3769,9 +3828,7 @@ class HCI_LE_Set_Extended_Advertising_Parameters_Command(HCI_Command):
'advertising_data',
{
'parser': HCI_Object.parse_length_prefixed_bytes,
'serializer': functools.partial(
HCI_Object.serialize_length_prefixed_bytes
),
'serializer': HCI_Object.serialize_length_prefixed_bytes,
},
),
]
@@ -3819,9 +3876,7 @@ class HCI_LE_Set_Extended_Advertising_Data_Command(HCI_Command):
'scan_response_data',
{
'parser': HCI_Object.parse_length_prefixed_bytes,
'serializer': functools.partial(
HCI_Object.serialize_length_prefixed_bytes
),
'serializer': HCI_Object.serialize_length_prefixed_bytes,
},
),
]
@@ -3849,73 +3904,21 @@ class HCI_LE_Set_Extended_Scan_Response_Data_Command(HCI_Command):
# -----------------------------------------------------------------------------
@HCI_Command.command(fields=None)
@HCI_Command.command(
[
('enable', 1),
[
('advertising_handles', 1),
('durations', 2),
('max_extended_advertising_events', 1),
],
]
)
class HCI_LE_Set_Extended_Advertising_Enable_Command(HCI_Command):
'''
See Bluetooth spec @ 7.8.56 LE Set Extended Advertising Enable Command
'''
@classmethod
def from_parameters(cls, parameters):
enable = parameters[0]
num_sets = parameters[1]
advertising_handles = []
durations = []
max_extended_advertising_events = []
offset = 2
for _ in range(num_sets):
advertising_handles.append(parameters[offset])
durations.append(struct.unpack_from('<H', parameters, offset + 1)[0])
max_extended_advertising_events.append(parameters[offset + 3])
offset += 4
return cls(
enable, advertising_handles, durations, max_extended_advertising_events
)
def __init__(
self, enable, advertising_handles, durations, max_extended_advertising_events
):
super().__init__(HCI_LE_SET_EXTENDED_ADVERTISING_ENABLE_COMMAND)
self.enable = enable
self.advertising_handles = advertising_handles
self.durations = durations
self.max_extended_advertising_events = max_extended_advertising_events
self.parameters = bytes([enable, len(advertising_handles)]) + b''.join(
[
struct.pack(
'<BHB',
advertising_handles[i],
durations[i],
max_extended_advertising_events[i],
)
for i in range(len(advertising_handles))
]
)
def __str__(self):
fields = [('enable:', self.enable)]
for i, advertising_handle in enumerate(self.advertising_handles):
fields.append(
(f'advertising_handle[{i}]: ', advertising_handle)
)
fields.append((f'duration[{i}]: ', self.durations[i]))
fields.append(
(
f'max_extended_advertising_events[{i}]:',
self.max_extended_advertising_events[i],
)
)
return (
color(self.name, 'green')
+ ':\n'
+ '\n'.join(
[color(field[0], 'cyan') + ' ' + str(field[1]) for field in fields]
)
)
# -----------------------------------------------------------------------------
@HCI_Command.command(
@@ -4066,7 +4069,10 @@ class HCI_LE_Set_Extended_Scan_Parameters_Command(HCI_Command):
color(self.name, 'green')
+ ':\n'
+ '\n'.join(
[color(field[0], 'cyan') + ' ' + str(field[1]) for field in fields]
[
color(' ' + field[0], 'cyan') + ' ' + str(field[1])
for field in fields
]
)
)
@@ -4242,7 +4248,10 @@ class HCI_LE_Extended_Create_Connection_Command(HCI_Command):
color(self.name, 'green')
+ ':\n'
+ '\n'.join(
[color(field[0], 'cyan') + ' ' + str(field[1]) for field in fields]
[
color(' ' + field[0], 'cyan') + ' ' + str(field[1])
for field in fields
]
)
)
@@ -5205,7 +5214,7 @@ class HCI_Number_Of_Completed_Packets_Event(HCI_Event):
def __str__(self):
lines = [
color(self.name, 'magenta') + ':',
color(' number_of_handles: ', 'cyan')
color(' number_of_handles: ', 'cyan')
+ f'{len(self.connection_handles)}',
]
for i, connection_handle in enumerate(self.connection_handles):

View File

@@ -19,6 +19,7 @@ import enum
from typing import Optional, Tuple
from .hci import (
Address,
HCI_NO_INPUT_NO_OUTPUT_IO_CAPABILITY,
HCI_DISPLAY_ONLY_IO_CAPABILITY,
HCI_DISPLAY_YES_NO_IO_CAPABILITY,
@@ -168,21 +169,28 @@ class PairingDelegate:
class PairingConfig:
"""Configuration for the Pairing protocol."""
class AddressType(enum.IntEnum):
PUBLIC = Address.PUBLIC_DEVICE_ADDRESS
RANDOM = Address.RANDOM_DEVICE_ADDRESS
def __init__(
self,
sc: bool = True,
mitm: bool = True,
bonding: bool = True,
delegate: Optional[PairingDelegate] = None,
identity_address_type: Optional[AddressType] = None,
) -> None:
self.sc = sc
self.mitm = mitm
self.bonding = bonding
self.delegate = delegate or PairingDelegate()
self.identity_address_type = identity_address_type
def __str__(self) -> str:
return (
f'PairingConfig(sc={self.sc}, '
f'mitm={self.mitm}, bonding={self.bonding}, '
f'identity_address_type={self.identity_address_type}, '
f'delegate[{self.delegate.io_capability}])'
)

View File

@@ -12,7 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
from bumble.pairing import PairingDelegate
from bumble.pairing import PairingConfig, PairingDelegate
from dataclasses import dataclass
from typing import Any, Dict
@@ -20,6 +20,7 @@ from typing import Any, Dict
@dataclass
class Config:
io_capability: PairingDelegate.IoCapability = PairingDelegate.NO_OUTPUT_NO_INPUT
identity_address_type: PairingConfig.AddressType = PairingConfig.AddressType.RANDOM
pairing_sc_enable: bool = True
pairing_mitm_enable: bool = True
pairing_bonding_enable: bool = True
@@ -35,6 +36,12 @@ class Config:
'io_capability', 'no_output_no_input'
).upper()
self.io_capability = getattr(PairingDelegate, io_capability_name)
identity_address_type_name: str = config.get(
'identity_address_type', 'random'
).upper()
self.identity_address_type = getattr(
PairingConfig.AddressType, identity_address_type_name
)
self.pairing_sc_enable = config.get('pairing_sc_enable', True)
self.pairing_mitm_enable = config.get('pairing_mitm_enable', True)
self.pairing_bonding_enable = config.get('pairing_bonding_enable', True)

View File

@@ -34,6 +34,10 @@ from bumble.sdp import (
from typing import Any, Dict, List, Optional
# Default rootcanal HCI TCP address
ROOTCANAL_HCI_ADDRESS = "localhost:6402"
class PandoraDevice:
"""
Small wrapper around a Bumble device and it's HCI transport.
@@ -53,7 +57,9 @@ class PandoraDevice:
def __init__(self, config: Dict[str, Any]) -> None:
self.config = config
self.device = _make_device(config)
self._hci_name = config.get('transport', '')
self._hci_name = config.get(
'transport', f"tcp-client:{config.get('tcp', ROOTCANAL_HCI_ADDRESS)}"
)
self._hci = None
@property

View File

@@ -112,7 +112,7 @@ class HostService(HostServicer):
async def FactoryReset(
self, request: empty_pb2.Empty, context: grpc.ServicerContext
) -> empty_pb2.Empty:
self.log.info('FactoryReset')
self.log.debug('FactoryReset')
# delete all bonds
if self.device.keystore is not None:
@@ -126,7 +126,7 @@ class HostService(HostServicer):
async def Reset(
self, request: empty_pb2.Empty, context: grpc.ServicerContext
) -> empty_pb2.Empty:
self.log.info('Reset')
self.log.debug('Reset')
# clear service.
self.waited_connections.clear()
@@ -139,7 +139,7 @@ class HostService(HostServicer):
async def ReadLocalAddress(
self, request: empty_pb2.Empty, context: grpc.ServicerContext
) -> ReadLocalAddressResponse:
self.log.info('ReadLocalAddress')
self.log.debug('ReadLocalAddress')
return ReadLocalAddressResponse(
address=bytes(reversed(bytes(self.device.public_address)))
)
@@ -152,7 +152,7 @@ class HostService(HostServicer):
address = Address(
bytes(reversed(request.address)), address_type=Address.PUBLIC_DEVICE_ADDRESS
)
self.log.info(f"Connect to {address}")
self.log.debug(f"Connect to {address}")
try:
connection = await self.device.connect(
@@ -167,7 +167,7 @@ class HostService(HostServicer):
return ConnectResponse(connection_already_exists=empty_pb2.Empty())
raise e
self.log.info(f"Connect to {address} done (handle={connection.handle})")
self.log.debug(f"Connect to {address} done (handle={connection.handle})")
cookie = any_pb2.Any(value=connection.handle.to_bytes(4, 'big'))
return ConnectResponse(connection=Connection(cookie=cookie))
@@ -186,7 +186,7 @@ class HostService(HostServicer):
if address in (Address.NIL, Address.ANY):
raise ValueError('Invalid address')
self.log.info(f"WaitConnection from {address}...")
self.log.debug(f"WaitConnection from {address}...")
connection = self.device.find_connection_by_bd_addr(
address, transport=BT_BR_EDR_TRANSPORT
@@ -201,7 +201,7 @@ class HostService(HostServicer):
# save connection has waited and respond.
self.waited_connections.add(id(connection))
self.log.info(
self.log.debug(
f"WaitConnection from {address} done (handle={connection.handle})"
)
@@ -216,7 +216,7 @@ class HostService(HostServicer):
if address in (Address.NIL, Address.ANY):
raise ValueError('Invalid address')
self.log.info(f"ConnectLE to {address}...")
self.log.debug(f"ConnectLE to {address}...")
try:
connection = await self.device.connect(
@@ -233,7 +233,7 @@ class HostService(HostServicer):
return ConnectLEResponse(connection_already_exists=empty_pb2.Empty())
raise e
self.log.info(f"ConnectLE to {address} done (handle={connection.handle})")
self.log.debug(f"ConnectLE to {address} done (handle={connection.handle})")
cookie = any_pb2.Any(value=connection.handle.to_bytes(4, 'big'))
return ConnectLEResponse(connection=Connection(cookie=cookie))
@@ -243,12 +243,12 @@ class HostService(HostServicer):
self, request: DisconnectRequest, context: grpc.ServicerContext
) -> empty_pb2.Empty:
connection_handle = int.from_bytes(request.connection.cookie.value, 'big')
self.log.info(f"Disconnect: {connection_handle}")
self.log.debug(f"Disconnect: {connection_handle}")
self.log.info("Disconnecting...")
self.log.debug("Disconnecting...")
if connection := self.device.lookup_connection(connection_handle):
await connection.disconnect(HCI_REMOTE_USER_TERMINATED_CONNECTION_ERROR)
self.log.info("Disconnected")
self.log.debug("Disconnected")
return empty_pb2.Empty()
@@ -257,7 +257,7 @@ class HostService(HostServicer):
self, request: WaitDisconnectionRequest, context: grpc.ServicerContext
) -> empty_pb2.Empty:
connection_handle = int.from_bytes(request.connection.cookie.value, 'big')
self.log.info(f"WaitDisconnection: {connection_handle}")
self.log.debug(f"WaitDisconnection: {connection_handle}")
if connection := self.device.lookup_connection(connection_handle):
disconnection_future: asyncio.Future[
@@ -270,7 +270,7 @@ class HostService(HostServicer):
connection.on('disconnection', on_disconnection)
try:
await disconnection_future
self.log.info("Disconnected")
self.log.debug("Disconnected")
finally:
connection.remove_listener('disconnection', on_disconnection) # type: ignore
@@ -378,7 +378,7 @@ class HostService(HostServicer):
try:
while True:
if not self.device.is_advertising:
self.log.info('Advertise')
self.log.debug('Advertise')
await self.device.start_advertising(
target=target,
advertising_type=advertising_type,
@@ -393,10 +393,10 @@ class HostService(HostServicer):
bumble.device.Connection
] = asyncio.get_running_loop().create_future()
self.log.info('Wait for LE connection...')
self.log.debug('Wait for LE connection...')
connection = await pending_connection
self.log.info(
self.log.debug(
f"Advertise: Connected to {connection.peer_address} (handle={connection.handle})"
)
@@ -410,7 +410,7 @@ class HostService(HostServicer):
self.device.remove_listener('connection', on_connection) # type: ignore
try:
self.log.info('Stop advertising')
self.log.debug('Stop advertising')
await self.device.abort_on('flush', self.device.stop_advertising())
except:
pass
@@ -423,7 +423,7 @@ class HostService(HostServicer):
if request.phys:
raise NotImplementedError("TODO: add support for `request.phys`")
self.log.info('Scan')
self.log.debug('Scan')
scan_queue: asyncio.Queue[Advertisement] = asyncio.Queue()
handler = self.device.on('advertisement', scan_queue.put_nowait)
@@ -470,7 +470,7 @@ class HostService(HostServicer):
finally:
self.device.remove_listener('advertisement', handler) # type: ignore
try:
self.log.info('Stop scanning')
self.log.debug('Stop scanning')
await self.device.abort_on('flush', self.device.stop_scanning())
except:
pass
@@ -479,7 +479,7 @@ class HostService(HostServicer):
async def Inquiry(
self, request: empty_pb2.Empty, context: grpc.ServicerContext
) -> AsyncGenerator[InquiryResponse, None]:
self.log.info('Inquiry')
self.log.debug('Inquiry')
inquiry_queue: asyncio.Queue[
Optional[Tuple[Address, int, AdvertisingData, int]]
@@ -510,7 +510,7 @@ class HostService(HostServicer):
self.device.remove_listener('inquiry_complete', complete_handler) # type: ignore
self.device.remove_listener('inquiry_result', result_handler) # type: ignore
try:
self.log.info('Stop inquiry')
self.log.debug('Stop inquiry')
await self.device.abort_on('flush', self.device.stop_discovery())
except:
pass
@@ -519,7 +519,7 @@ class HostService(HostServicer):
async def SetDiscoverabilityMode(
self, request: SetDiscoverabilityModeRequest, context: grpc.ServicerContext
) -> empty_pb2.Empty:
self.log.info("SetDiscoverabilityMode")
self.log.debug("SetDiscoverabilityMode")
await self.device.set_discoverable(request.mode != NOT_DISCOVERABLE)
return empty_pb2.Empty()
@@ -527,7 +527,7 @@ class HostService(HostServicer):
async def SetConnectabilityMode(
self, request: SetConnectabilityModeRequest, context: grpc.ServicerContext
) -> empty_pb2.Empty:
self.log.info("SetConnectabilityMode")
self.log.debug("SetConnectabilityMode")
await self.device.set_connectable(request.mode != NOT_CONNECTABLE)
return empty_pb2.Empty()

View File

@@ -99,7 +99,7 @@ class PairingDelegate(BasePairingDelegate):
return ev
async def confirm(self, auto: bool = False) -> bool:
self.log.info(
self.log.debug(
f"Pairing event: `just_works` (io_capability: {self.io_capability})"
)
@@ -114,7 +114,7 @@ class PairingDelegate(BasePairingDelegate):
return answer.confirm
async def compare_numbers(self, number: int, digits: int = 6) -> bool:
self.log.info(
self.log.debug(
f"Pairing event: `numeric_comparison` (io_capability: {self.io_capability})"
)
@@ -129,7 +129,7 @@ class PairingDelegate(BasePairingDelegate):
return answer.confirm
async def get_number(self) -> Optional[int]:
self.log.info(
self.log.debug(
f"Pairing event: `passkey_entry_request` (io_capability: {self.io_capability})"
)
@@ -146,7 +146,7 @@ class PairingDelegate(BasePairingDelegate):
return answer.passkey
async def get_string(self, max_length: int) -> Optional[str]:
self.log.info(
self.log.debug(
f"Pairing event: `pin_code_request` (io_capability: {self.io_capability})"
)
@@ -177,7 +177,7 @@ class PairingDelegate(BasePairingDelegate):
):
return
self.log.info(
self.log.debug(
f"Pairing event: `passkey_entry_notification` (io_capability: {self.io_capability})"
)
@@ -232,6 +232,7 @@ class SecurityService(SecurityServicer):
sc=config.pairing_sc_enable,
mitm=config.pairing_mitm_enable,
bonding=config.pairing_bonding_enable,
identity_address_type=config.identity_address_type,
delegate=PairingDelegate(
connection,
self,
@@ -247,7 +248,7 @@ class SecurityService(SecurityServicer):
async def OnPairing(
self, request: AsyncIterator[PairingEventAnswer], context: grpc.ServicerContext
) -> AsyncGenerator[PairingEvent, None]:
self.log.info('OnPairing')
self.log.debug('OnPairing')
if self.event_queue is not None:
raise RuntimeError('already streaming pairing events')
@@ -273,7 +274,7 @@ class SecurityService(SecurityServicer):
self, request: SecureRequest, context: grpc.ServicerContext
) -> SecureResponse:
connection_handle = int.from_bytes(request.connection.cookie.value, 'big')
self.log.info(f"Secure: {connection_handle}")
self.log.debug(f"Secure: {connection_handle}")
connection = self.device.lookup_connection(connection_handle)
assert connection
@@ -291,7 +292,7 @@ class SecurityService(SecurityServicer):
# trigger pairing if needed
if self.need_pairing(connection, level):
try:
self.log.info('Pair...')
self.log.debug('Pair...')
if (
connection.transport == BT_LE_TRANSPORT
@@ -309,7 +310,7 @@ class SecurityService(SecurityServicer):
else:
await connection.pair()
self.log.info('Paired')
self.log.debug('Paired')
except asyncio.CancelledError:
self.log.warning("Connection died during encryption")
return SecureResponse(connection_died=empty_pb2.Empty())
@@ -320,9 +321,9 @@ class SecurityService(SecurityServicer):
# trigger authentication if needed
if self.need_authentication(connection, level):
try:
self.log.info('Authenticate...')
self.log.debug('Authenticate...')
await connection.authenticate()
self.log.info('Authenticated')
self.log.debug('Authenticated')
except asyncio.CancelledError:
self.log.warning("Connection died during authentication")
return SecureResponse(connection_died=empty_pb2.Empty())
@@ -333,9 +334,9 @@ class SecurityService(SecurityServicer):
# trigger encryption if needed
if self.need_encryption(connection, level):
try:
self.log.info('Encrypt...')
self.log.debug('Encrypt...')
await connection.encrypt()
self.log.info('Encrypted')
self.log.debug('Encrypted')
except asyncio.CancelledError:
self.log.warning("Connection died during encryption")
return SecureResponse(connection_died=empty_pb2.Empty())
@@ -353,7 +354,7 @@ class SecurityService(SecurityServicer):
self, request: WaitSecurityRequest, context: grpc.ServicerContext
) -> WaitSecurityResponse:
connection_handle = int.from_bytes(request.connection.cookie.value, 'big')
self.log.info(f"WaitSecurity: {connection_handle}")
self.log.debug(f"WaitSecurity: {connection_handle}")
connection = self.device.lookup_connection(connection_handle)
assert connection
@@ -390,7 +391,7 @@ class SecurityService(SecurityServicer):
def set_failure(name: str) -> Callable[..., None]:
def wrapper(*args: Any) -> None:
self.log.info(f'Wait for security: error `{name}`: {args}')
self.log.debug(f'Wait for security: error `{name}`: {args}')
wait_for_security.set_result(name)
return wrapper
@@ -398,13 +399,13 @@ class SecurityService(SecurityServicer):
def try_set_success(*_: Any) -> None:
assert connection
if self.reached_security_level(connection, level):
self.log.info('Wait for security: done')
self.log.debug('Wait for security: done')
wait_for_security.set_result('success')
def on_encryption_change(*_: Any) -> None:
assert connection
if self.reached_security_level(connection, level):
self.log.info('Wait for security: done')
self.log.debug('Wait for security: done')
wait_for_security.set_result('success')
elif (
connection.transport == BT_BR_EDR_TRANSPORT
@@ -432,7 +433,7 @@ class SecurityService(SecurityServicer):
if self.reached_security_level(connection, level):
return WaitSecurityResponse(success=empty_pb2.Empty())
self.log.info('Wait for security...')
self.log.debug('Wait for security...')
kwargs = {}
kwargs[await wait_for_security] = empty_pb2.Empty()
@@ -442,12 +443,12 @@ class SecurityService(SecurityServicer):
# wait for `authenticate` to finish if any
if authenticate_task is not None:
self.log.info('Wait for authentication...')
self.log.debug('Wait for authentication...')
try:
await authenticate_task # type: ignore
except:
pass
self.log.info('Authenticated')
self.log.debug('Authenticated')
return WaitSecurityResponse(**kwargs)
@@ -503,7 +504,7 @@ class SecurityStorageService(SecurityStorageServicer):
self, request: IsBondedRequest, context: grpc.ServicerContext
) -> wrappers_pb2.BoolValue:
address = utils.address_from_request(request, request.WhichOneof("address"))
self.log.info(f"IsBonded: {address}")
self.log.debug(f"IsBonded: {address}")
if self.device.keystore is not None:
is_bonded = await self.device.keystore.get(str(address)) is not None
@@ -517,7 +518,7 @@ class SecurityStorageService(SecurityStorageServicer):
self, request: DeleteBondRequest, context: grpc.ServicerContext
) -> empty_pb2.Empty:
address = utils.address_from_request(request, request.WhichOneof("address"))
self.log.info(f"DeleteBond: {address}")
self.log.debug(f"DeleteBond: {address}")
if self.device.keystore is not None:
with suppress(KeyError):

View File

@@ -993,6 +993,19 @@ class Session:
)
)
def send_identity_address_command(self) -> None:
identity_address = {
None: self.connection.self_address,
Address.PUBLIC_DEVICE_ADDRESS: self.manager.device.public_address,
Address.RANDOM_DEVICE_ADDRESS: self.manager.device.random_address,
}[self.pairing_config.identity_address_type]
self.send_command(
SMP_Identity_Address_Information_Command(
addr_type=identity_address.address_type,
bd_addr=identity_address,
)
)
def start_encryption(self, key: bytes) -> None:
# We can now encrypt the connection with the short term key, so that we can
# distribute the long term and/or other keys over an encrypted connection
@@ -1016,6 +1029,7 @@ class Session:
self.ltk = crypto.h6(ilk, b'brle')
def distribute_keys(self) -> None:
# Distribute the keys as required
if self.is_initiator:
# CTKD: Derive LTK from LinkKey
@@ -1045,12 +1059,7 @@ class Session:
identity_resolving_key=self.manager.device.irk
)
)
self.send_command(
SMP_Identity_Address_Information_Command(
addr_type=self.connection.self_address.address_type,
bd_addr=self.connection.self_address,
)
)
self.send_identity_address_command()
# Distribute CSRK
csrk = bytes(16) # FIXME: testing
@@ -1094,12 +1103,7 @@ class Session:
identity_resolving_key=self.manager.device.irk
)
)
self.send_command(
SMP_Identity_Address_Information_Command(
addr_type=self.connection.self_address.address_type,
bd_addr=self.connection.self_address,
)
)
self.send_identity_address_command()
# Distribute CSRK
csrk = bytes(16) # FIXME: testing
@@ -1268,7 +1272,7 @@ class Session:
keys.link_key = PairingKeys.Key(
value=self.link_key, authenticated=authenticated
)
self.manager.on_pairing(self, peer_address, keys)
await self.manager.on_pairing(self, peer_address, keys)
def on_pairing_failure(self, reason: int) -> None:
logger.warning(f'pairing failure ({error_name(reason)})')
@@ -1823,20 +1827,13 @@ class Manager(EventEmitter):
def on_session_start(self, session: Session) -> None:
self.device.on_pairing_start(session.connection)
def on_pairing(
async def on_pairing(
self, session: Session, identity_address: Optional[Address], keys: PairingKeys
) -> None:
# Store the keys in the key store
if self.device.keystore and identity_address is not None:
async def store_keys():
try:
assert self.device.keystore
await self.device.keystore.update(str(identity_address), keys)
except Exception as error:
logger.warning(f'!!! error while storing keys: {error}')
self.device.abort_on('flush', store_keys())
await self.device.keystore.update(str(identity_address), keys)
await self.device.refresh_resolving_list()
# Notify the device
self.device.on_pairing(session.connection, identity_address, keys, session.sc)

2
rust/.gitignore vendored Normal file
View File

@@ -0,0 +1,2 @@
/target
/.idea

7
rust/CHANGELOG.md Normal file
View File

@@ -0,0 +1,7 @@
# Next
- Code-gen company ID table
# 0.1.0
- Initial release

1194
rust/Cargo.lock generated Normal file

File diff suppressed because it is too large Load Diff

56
rust/Cargo.toml Normal file
View File

@@ -0,0 +1,56 @@
[package]
name = "bumble"
description = "Rust API for the Bumble Bluetooth stack"
version = "0.1.0"
edition = "2021"
license = "Apache-2.0"
homepage = "https://google.github.io/bumble/index.html"
repository = "https://github.com/google/bumble"
documentation = "https://docs.rs/crate/bumble"
authors = ["Marshall Pierce <marshallpierce@google.com>"]
keywords = ["bluetooth", "ble"]
categories = ["api-bindings", "network-programming"]
rust-version = "1.69.0"
[dependencies]
pyo3 = { version = "0.18.3", features = ["macros"] }
pyo3-asyncio = { version = "0.18.0", features = ["tokio-runtime"] }
tokio = { version = "1.28.2" }
nom = "7.1.3"
strum = "0.25.0"
strum_macros = "0.25.0"
hex = "0.4.3"
itertools = "0.11.0"
lazy_static = "1.4.0"
thiserror = "1.0.41"
anyhow = { version = "1.0.71", optional = true }
[dev-dependencies]
tokio = { version = "1.28.2", features = ["full"] }
tempfile = "3.6.0"
nix = "0.26.2"
anyhow = "1.0.71"
pyo3 = { version = "0.18.3", features = ["macros", "anyhow"] }
pyo3-asyncio = { version = "0.18.0", features = ["tokio-runtime", "attributes", "testing"] }
clap = { version = "4.3.3", features = ["derive"] }
owo-colors = "3.5.0"
log = "0.4.19"
env_logger = "0.10.0"
rusb = "0.9.2"
rand = "0.8.5"
[[bin]]
name = "gen-assigned-numbers"
path = "tools/gen_assigned_numbers.rs"
required-features = ["bumble-dev-tools"]
# test entry point that uses pyo3_asyncio's test harness
[[test]]
name = "pytests"
path = "pytests/pytests.rs"
harness = false
[features]
anyhow = ["pyo3/anyhow"]
pyo3-asyncio-attributes = ["pyo3-asyncio/attributes"]
bumble-dev-tools = ["dep:anyhow"]

56
rust/README.md Normal file
View File

@@ -0,0 +1,56 @@
# What is this?
Rust wrappers around the [Bumble](https://github.com/google/bumble) Python API.
Method calls are mapped to the equivalent Python, and return types adapted where
relevant.
See the `examples` directory for usage.
# Usage
Set up a virtualenv for Bumble, or otherwise have an isolated Python environment
for Bumble and its dependencies.
Due to Python being
[picky about how its sys path is set up](https://github.com/PyO3/pyo3/issues/1741,
it's necessary to explicitly point to the virtualenv's `site-packages`. Use
suitable virtualenv paths as appropriate for your OS, as seen here running
the `battery_client` example:
```
PYTHONPATH=..:~/.virtualenvs/bumble/lib/python3.10/site-packages/ \
cargo run --example battery_client -- \
--transport android-netsim --target-addr F0:F1:F2:F3:F4:F5
```
Run the corresponding `battery_server` Python example, and launch an emulator in
Android Studio (currently, Canary is required) to run netsim.
# Development
Run the tests:
```
PYTHONPATH=.. cargo test
```
Check lints:
```
cargo clippy --all-targets
```
## Code gen
To have the fastest startup while keeping the build simple, code gen for
assigned numbers is done with the `gen_assigned_numbers` tool. It should
be re-run whenever the Python assigned numbers are changed. To ensure that the
generated code is kept up to date, the Rust data is compared to the Python
in tests at `pytests/assigned_numbers.rs`.
To regenerate the assigned number tables based on the Python codebase:
```
PYTHONPATH=.. cargo run --bin gen-assigned-numbers --features bumble-dev-tools
```

View File

@@ -0,0 +1,112 @@
// Copyright 2023 Google LLC
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//! Counterpart to the Python example `battery_server.py`.
//!
//! Start an Android emulator from Android Studio, or otherwise have netsim running.
//!
//! Run the server from the project root:
//! ```
//! PYTHONPATH=. python examples/battery_server.py \
//! examples/device1.json android-netsim
//! ```
//!
//! Then run this example from the `rust` directory:
//!
//! ```
//! PYTHONPATH=..:/path/to/virtualenv/site-packages/ \
//! cargo run --example battery_client -- \
//! --transport android-netsim \
//! --target-addr F0:F1:F2:F3:F4:F5
//! ```
use bumble::wrapper::{
device::{Device, Peer},
profile::BatteryServiceProxy,
transport::Transport,
PyObjectExt,
};
use clap::Parser as _;
use log::info;
use owo_colors::OwoColorize;
use pyo3::prelude::*;
#[pyo3_asyncio::tokio::main]
async fn main() -> PyResult<()> {
env_logger::builder()
.filter_level(log::LevelFilter::Info)
.init();
let cli = Cli::parse();
let transport = Transport::open(cli.transport).await?;
let device = Device::with_hci(
"Bumble",
"F0:F1:F2:F3:F4:F5",
transport.source()?,
transport.sink()?,
)?;
device.power_on().await?;
let conn = device.connect(&cli.target_addr).await?;
let mut peer = Peer::new(conn)?;
for mut s in peer.discover_services().await? {
s.discover_characteristics().await?;
}
let battery_service = peer
.create_service_proxy::<BatteryServiceProxy>()?
.ok_or(anyhow::anyhow!("No battery service found"))?;
let mut battery_level_char = battery_service
.battery_level()?
.ok_or(anyhow::anyhow!("No battery level characteristic"))?;
info!(
"{} {}",
"Initial Battery Level:".green(),
battery_level_char
.read_value()
.await?
.extract_with_gil::<u32>()?
);
battery_level_char
.subscribe(|_py, args| {
info!(
"{} {:?}",
"Battery level update:".green(),
args.get_item(0)?.extract::<u32>()?,
);
Ok(())
})
.await?;
// wait until user kills the process
tokio::signal::ctrl_c().await?;
Ok(())
}
#[derive(clap::Parser)]
#[command(author, version, about, long_about = None)]
struct Cli {
/// Bumble transport spec.
///
/// <https://google.github.io/bumble/transports/index.html>
#[arg(long)]
transport: String,
/// Address to connect to
#[arg(long)]
target_addr: String,
}

View File

@@ -0,0 +1,98 @@
// Copyright 2023 Google LLC
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
use anyhow::anyhow;
use bumble::{
adv::{AdvertisementDataBuilder, CommonDataType},
wrapper::{
device::Device,
logging::{bumble_env_logging_level, py_logging_basic_config},
transport::Transport,
},
};
use clap::Parser as _;
use pyo3::PyResult;
use rand::Rng;
use std::path;
#[pyo3_asyncio::tokio::main]
async fn main() -> PyResult<()> {
env_logger::builder()
.filter_level(log::LevelFilter::Info)
.init();
let cli = Cli::parse();
if cli.log_hci {
py_logging_basic_config(bumble_env_logging_level("DEBUG"))?;
}
let transport = Transport::open(cli.transport).await?;
let mut device = Device::from_config_file_with_hci(
&cli.device_config,
transport.source()?,
transport.sink()?,
)?;
let mut adv_data = AdvertisementDataBuilder::new();
adv_data
.append(
CommonDataType::CompleteLocalName,
"Bumble from Rust".as_bytes(),
)
.map_err(|e| anyhow!(e))?;
// Randomized TX power
adv_data
.append(
CommonDataType::TxPowerLevel,
&[rand::thread_rng().gen_range(-100_i8..=20) as u8],
)
.map_err(|e| anyhow!(e))?;
device.set_advertising_data(adv_data)?;
device.power_on().await?;
println!("Advertising...");
device.start_advertising(true).await?;
// wait until user kills the process
tokio::signal::ctrl_c().await?;
println!("Stopping...");
device.stop_advertising().await?;
Ok(())
}
#[derive(clap::Parser)]
#[command(author, version, about, long_about = None)]
struct Cli {
/// Bumble device config.
///
/// See, for instance, `examples/device1.json` in the Python project.
#[arg(long)]
device_config: path::PathBuf,
/// Bumble transport spec.
///
/// <https://google.github.io/bumble/transports/index.html>
#[arg(long)]
transport: String,
/// Log HCI commands
#[arg(long)]
log_hci: bool,
}

185
rust/examples/scanner.rs Normal file
View File

@@ -0,0 +1,185 @@
// Copyright 2023 Google LLC
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//! Counterpart to the Python example `run_scanner.py`.
//!
//! Device deduplication is done here rather than relying on the controller's filtering to provide
//! for additional features, like the ability to make deduplication time-bounded.
use bumble::{
adv::CommonDataType,
wrapper::{
core::AdvertisementDataUnit, device::Device, hci::AddressType, transport::Transport,
},
};
use clap::Parser as _;
use itertools::Itertools;
use owo_colors::{OwoColorize, Style};
use pyo3::PyResult;
use std::{
collections,
sync::{Arc, Mutex},
time,
};
#[pyo3_asyncio::tokio::main]
async fn main() -> PyResult<()> {
env_logger::builder()
.filter_level(log::LevelFilter::Info)
.init();
let cli = Cli::parse();
let transport = Transport::open(cli.transport).await?;
let mut device = Device::with_hci(
"Bumble",
"F0:F1:F2:F3:F4:F5",
transport.source()?,
transport.sink()?,
)?;
// in practice, devices can send multiple advertisements from the same address, so we keep
// track of a timestamp for each set of data
let seen_advertisements = Arc::new(Mutex::new(collections::HashMap::<
Vec<u8>,
collections::HashMap<Vec<AdvertisementDataUnit>, time::Instant>,
>::new()));
let seen_adv_clone = seen_advertisements.clone();
device.on_advertisement(move |_py, adv| {
let rssi = adv.rssi()?;
let data_units = adv.data()?.data_units()?;
let addr = adv.address()?;
let show_adv = if cli.filter_duplicates {
let addr_bytes = addr.as_le_bytes()?;
let mut seen_adv_cache = seen_adv_clone.lock().unwrap();
let expiry_duration = time::Duration::from_secs(cli.dedup_expiry_secs);
let advs_from_addr = seen_adv_cache
.entry(addr_bytes)
.or_insert_with(collections::HashMap::new);
// we expect cache hits to be the norm, so we do a separate lookup to avoid cloning
// on every lookup with entry()
let show = if let Some(prev) = advs_from_addr.get_mut(&data_units) {
let expired = prev.elapsed() > expiry_duration;
*prev = time::Instant::now();
expired
} else {
advs_from_addr.insert(data_units.clone(), time::Instant::now());
true
};
// clean out anything we haven't seen in a while
advs_from_addr.retain(|_, instant| instant.elapsed() <= expiry_duration);
show
} else {
true
};
if !show_adv {
return Ok(());
}
let addr_style = if adv.is_connectable()? {
Style::new().yellow()
} else {
Style::new().red()
};
let (type_style, qualifier) = match adv.address()?.address_type()? {
AddressType::PublicIdentity | AddressType::PublicDevice => (Style::new().cyan(), ""),
_ => {
if addr.is_static()? {
(Style::new().green(), "(static)")
} else if addr.is_resolvable()? {
(Style::new().magenta(), "(resolvable)")
} else {
(Style::new().default_color(), "")
}
}
};
println!(
">>> {} [{:?}] {qualifier}:\n RSSI: {}",
addr.as_hex()?.style(addr_style),
addr.address_type()?.style(type_style),
rssi,
);
data_units.into_iter().for_each(|(code, data)| {
let matching = CommonDataType::for_type_code(code).collect::<Vec<_>>();
let code_str = if matching.is_empty() {
format!("0x{}", hex::encode_upper([code.into()]))
} else {
matching
.iter()
.map(|t| format!("{}", t))
.join(" / ")
.blue()
.to_string()
};
// use the first matching type's formatted data, if any
let data_str = matching
.iter()
.filter_map(|t| {
t.format_data(&data).map(|formatted| {
format!(
"{} {}",
formatted,
format!("(raw: 0x{})", hex::encode_upper(&data)).dimmed()
)
})
})
.next()
.unwrap_or_else(|| format!("0x{}", hex::encode_upper(&data)));
println!(" [{}]: {}", code_str, data_str)
});
Ok(())
})?;
device.power_on().await?;
// do our own dedup
device.start_scanning(false).await?;
// wait until user kills the process
tokio::signal::ctrl_c().await?;
Ok(())
}
#[derive(clap::Parser)]
#[command(author, version, about, long_about = None)]
struct Cli {
/// Bumble transport spec.
///
/// <https://google.github.io/bumble/transports/index.html>
#[arg(long)]
transport: String,
/// Filter duplicate advertisements
#[arg(long, default_value_t = false)]
filter_duplicates: bool,
/// How long before a deduplicated advertisement that hasn't been seen in a while is considered
/// fresh again, in seconds
#[arg(long, default_value_t = 10, requires = "filter_duplicates")]
dedup_expiry_secs: u64,
}

342
rust/examples/usb_probe.rs Normal file
View File

@@ -0,0 +1,342 @@
// Copyright 2023 Google LLC
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//! Rust version of the Python `usb_probe.py`.
//!
//! This tool lists all the USB devices, with details about each device.
//! For each device, the different possible Bumble transport strings that can
//! refer to it are listed. If the device is known to be a Bluetooth HCI device,
//! its identifier is printed in reverse colors, and the transport names in cyan color.
//! For other devices, regardless of their type, the transport names are printed
//! in red. Whether that device is actually a Bluetooth device or not depends on
//! whether it is a Bluetooth device that uses a non-standard Class, or some other
//! type of device (there's no way to tell).
use clap::Parser as _;
use itertools::Itertools as _;
use owo_colors::{OwoColorize, Style};
use rusb::{Device, DeviceDescriptor, Direction, TransferType, UsbContext};
use std::{
collections::{HashMap, HashSet},
time::Duration,
};
const USB_DEVICE_CLASS_DEVICE: u8 = 0x00;
const USB_DEVICE_CLASS_WIRELESS_CONTROLLER: u8 = 0xE0;
const USB_DEVICE_SUBCLASS_RF_CONTROLLER: u8 = 0x01;
const USB_DEVICE_PROTOCOL_BLUETOOTH_PRIMARY_CONTROLLER: u8 = 0x01;
fn main() -> anyhow::Result<()> {
let cli = Cli::parse();
let mut bt_dev_count = 0;
let mut device_serials_by_id: HashMap<(u16, u16), HashSet<String>> = HashMap::new();
for device in rusb::devices()?.iter() {
let device_desc = device.device_descriptor().unwrap();
let class_info = ClassInfo::from(&device_desc);
let handle = device.open()?;
let timeout = Duration::from_secs(1);
// some devices don't have languages
let lang = handle
.read_languages(timeout)
.ok()
.and_then(|langs| langs.into_iter().next());
let serial = lang.and_then(|l| {
handle
.read_serial_number_string(l, &device_desc, timeout)
.ok()
});
let mfg = lang.and_then(|l| {
handle
.read_manufacturer_string(l, &device_desc, timeout)
.ok()
});
let product = lang.and_then(|l| handle.read_product_string(l, &device_desc, timeout).ok());
let is_hci = is_bluetooth_hci(&device, &device_desc)?;
let addr_style = if is_hci {
bt_dev_count += 1;
Style::new().black().on_yellow()
} else {
Style::new().yellow().on_black()
};
let mut transport_names = Vec::new();
let basic_transport_name = format!(
"usb:{:04X}:{:04X}",
device_desc.vendor_id(),
device_desc.product_id()
);
if is_hci {
transport_names.push(format!("usb:{}", bt_dev_count - 1));
}
let device_id = (device_desc.vendor_id(), device_desc.product_id());
if !device_serials_by_id.contains_key(&device_id) {
transport_names.push(basic_transport_name.clone());
} else {
transport_names.push(format!(
"{}#{}",
basic_transport_name,
device_serials_by_id
.get(&device_id)
.map(|serials| serials.len())
.unwrap_or(0)
))
}
if let Some(s) = &serial {
if !device_serials_by_id
.get(&device_id)
.map(|serials| serials.contains(s))
.unwrap_or(false)
{
transport_names.push(format!("{}/{}", basic_transport_name, s))
}
}
println!(
"{}",
format!(
"ID {:04X}:{:04X}",
device_desc.vendor_id(),
device_desc.product_id()
)
.style(addr_style)
);
if !transport_names.is_empty() {
let style = if is_hci {
Style::new().cyan()
} else {
Style::new().red()
};
println!(
"{:26}{}",
" Bumble Transport Names:".blue(),
transport_names.iter().map(|n| n.style(style)).join(" or ")
)
}
println!(
"{:26}{:03}/{:03}",
" Bus/Device:".green(),
device.bus_number(),
device.address()
);
println!(
"{:26}{}",
" Class:".green(),
class_info.formatted_class_name()
);
println!(
"{:26}{}",
" Subclass/Protocol:".green(),
class_info.formatted_subclass_protocol()
);
if let Some(s) = serial {
println!("{:26}{}", " Serial:".green(), s);
device_serials_by_id
.entry(device_id)
.or_insert(HashSet::new())
.insert(s);
}
if let Some(m) = mfg {
println!("{:26}{}", " Manufacturer:".green(), m);
}
if let Some(p) = product {
println!("{:26}{}", " Product:".green(), p);
}
if cli.verbose {
print_device_details(&device, &device_desc)?;
}
println!();
}
Ok(())
}
fn is_bluetooth_hci<T: UsbContext>(
device: &Device<T>,
device_desc: &DeviceDescriptor,
) -> rusb::Result<bool> {
if device_desc.class_code() == USB_DEVICE_CLASS_WIRELESS_CONTROLLER
&& device_desc.sub_class_code() == USB_DEVICE_SUBCLASS_RF_CONTROLLER
&& device_desc.protocol_code() == USB_DEVICE_PROTOCOL_BLUETOOTH_PRIMARY_CONTROLLER
{
Ok(true)
} else if device_desc.class_code() == USB_DEVICE_CLASS_DEVICE {
for i in 0..device_desc.num_configurations() {
for interface in device.config_descriptor(i)?.interfaces() {
for d in interface.descriptors() {
if d.class_code() == USB_DEVICE_CLASS_WIRELESS_CONTROLLER
&& d.sub_class_code() == USB_DEVICE_SUBCLASS_RF_CONTROLLER
&& d.protocol_code() == USB_DEVICE_PROTOCOL_BLUETOOTH_PRIMARY_CONTROLLER
{
return Ok(true);
}
}
}
}
Ok(false)
} else {
Ok(false)
}
}
fn print_device_details<T: UsbContext>(
device: &Device<T>,
device_desc: &DeviceDescriptor,
) -> anyhow::Result<()> {
for i in 0..device_desc.num_configurations() {
println!(" Configuration {}", i + 1);
for interface in device.config_descriptor(i)?.interfaces() {
let interface_descriptors: Vec<_> = interface.descriptors().collect();
for d in &interface_descriptors {
let class_info =
ClassInfo::new(d.class_code(), d.sub_class_code(), d.protocol_code());
println!(
" Interface: {}{} ({}, {})",
interface.number(),
if interface_descriptors.len() > 1 {
format!("/{}", d.setting_number())
} else {
String::new()
},
class_info.formatted_class_name(),
class_info.formatted_subclass_protocol()
);
for e in d.endpoint_descriptors() {
println!(
" Endpoint {:#04X}: {} {}",
e.address(),
match e.transfer_type() {
TransferType::Control => "CONTROL",
TransferType::Isochronous => "ISOCHRONOUS",
TransferType::Bulk => "BULK",
TransferType::Interrupt => "INTERRUPT",
},
match e.direction() {
Direction::In => "IN",
Direction::Out => "OUT",
}
)
}
}
}
}
Ok(())
}
struct ClassInfo {
class: u8,
sub_class: u8,
protocol: u8,
}
impl ClassInfo {
fn new(class: u8, sub_class: u8, protocol: u8) -> Self {
Self {
class,
sub_class,
protocol,
}
}
fn class_name(&self) -> Option<&str> {
match self.class {
0x00 => Some("Device"),
0x01 => Some("Audio"),
0x02 => Some("Communications and CDC Control"),
0x03 => Some("Human Interface Device"),
0x05 => Some("Physical"),
0x06 => Some("Still Imaging"),
0x07 => Some("Printer"),
0x08 => Some("Mass Storage"),
0x09 => Some("Hub"),
0x0A => Some("CDC Data"),
0x0B => Some("Smart Card"),
0x0D => Some("Content Security"),
0x0E => Some("Video"),
0x0F => Some("Personal Healthcare"),
0x10 => Some("Audio/Video"),
0x11 => Some("Billboard"),
0x12 => Some("USB Type-C Bridge"),
0x3C => Some("I3C"),
0xDC => Some("Diagnostic"),
USB_DEVICE_CLASS_WIRELESS_CONTROLLER => Some("Wireless Controller"),
0xEF => Some("Miscellaneous"),
0xFE => Some("Application Specific"),
0xFF => Some("Vendor Specific"),
_ => None,
}
}
fn protocol_name(&self) -> Option<&str> {
match self.class {
USB_DEVICE_CLASS_WIRELESS_CONTROLLER => match self.sub_class {
0x01 => match self.protocol {
0x01 => Some("Bluetooth"),
0x02 => Some("UWB"),
0x03 => Some("Remote NDIS"),
0x04 => Some("Bluetooth AMP"),
_ => None,
},
_ => None,
},
_ => None,
}
}
fn formatted_class_name(&self) -> String {
self.class_name()
.map(|s| s.to_string())
.unwrap_or_else(|| format!("{:#04X}", self.class))
}
fn formatted_subclass_protocol(&self) -> String {
format!(
"{}/{}{}",
self.sub_class,
self.protocol,
self.protocol_name()
.map(|s| format!(" [{}]", s))
.unwrap_or_else(String::new)
)
}
}
impl From<&DeviceDescriptor> for ClassInfo {
fn from(value: &DeviceDescriptor) -> Self {
Self::new(
value.class_code(),
value.sub_class_code(),
value.protocol_code(),
)
}
}
#[derive(clap::Parser)]
#[command(author, version, about, long_about = None)]
struct Cli {
/// Show additional info for each USB device
#[arg(long, default_value_t = false)]
verbose: bool,
}

View File

@@ -0,0 +1,30 @@
use bumble::wrapper::{self, core::Uuid16};
use pyo3::{intern, prelude::*, types::PyDict};
use std::collections;
#[pyo3_asyncio::tokio::test]
async fn company_ids_matches_python() -> PyResult<()> {
let ids_from_python = Python::with_gil(|py| {
PyModule::import(py, intern!(py, "bumble.company_ids"))?
.getattr(intern!(py, "COMPANY_IDENTIFIERS"))?
.downcast::<PyDict>()?
.into_iter()
.map(|(k, v)| {
Ok((
Uuid16::from_be_bytes(k.extract::<u16>()?.to_be_bytes()),
v.str()?.to_str()?.to_string(),
))
})
.collect::<PyResult<collections::HashMap<_, _>>>()
})?;
assert_eq!(
wrapper::assigned_numbers::COMPANY_IDS
.iter()
.map(|(id, name)| (*id, name.to_string()))
.collect::<collections::HashMap<_, _>>(),
ids_from_python,
"Company ids do not match -- re-run gen_assigned_numbers?"
);
Ok(())
}

21
rust/pytests/pytests.rs Normal file
View File

@@ -0,0 +1,21 @@
// Copyright 2023 Google LLC
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#[pyo3_asyncio::tokio::main]
async fn main() -> pyo3::PyResult<()> {
pyo3_asyncio::testing::main().await
}
mod assigned_numbers;
mod wrapper;

31
rust/pytests/wrapper.rs Normal file
View File

@@ -0,0 +1,31 @@
// Copyright 2023 Google LLC
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
use bumble::wrapper::transport::Transport;
use nix::sys::stat::Mode;
use pyo3::PyResult;
#[pyo3_asyncio::tokio::test]
async fn fifo_transport_can_open() -> PyResult<()> {
let dir = tempfile::tempdir().unwrap();
let mut fifo = dir.path().to_path_buf();
fifo.push("bumble-transport-fifo");
nix::unistd::mkfifo(&fifo, Mode::S_IRWXU).unwrap();
let mut t = Transport::open(format!("file:{}", fifo.to_str().unwrap())).await?;
t.close().await?;
Ok(())
}

446
rust/src/adv.rs Normal file
View File

@@ -0,0 +1,446 @@
//! BLE advertisements.
use crate::wrapper::assigned_numbers::{COMPANY_IDS, SERVICE_IDS};
use crate::wrapper::core::{Uuid128, Uuid16, Uuid32};
use itertools::Itertools;
use nom::{combinator, multi, number};
use std::fmt;
use strum::IntoEnumIterator;
/// The numeric code for a common data type.
///
/// For known types, see [CommonDataType], or use this type directly for non-assigned codes.
#[derive(PartialEq, Eq, Debug, Clone, Copy, Hash)]
pub struct CommonDataTypeCode(u8);
impl From<CommonDataType> for CommonDataTypeCode {
fn from(value: CommonDataType) -> Self {
let byte = match value {
CommonDataType::Flags => 0x01,
CommonDataType::IncompleteListOf16BitServiceClassUuids => 0x02,
CommonDataType::CompleteListOf16BitServiceClassUuids => 0x03,
CommonDataType::IncompleteListOf32BitServiceClassUuids => 0x04,
CommonDataType::CompleteListOf32BitServiceClassUuids => 0x05,
CommonDataType::IncompleteListOf128BitServiceClassUuids => 0x06,
CommonDataType::CompleteListOf128BitServiceClassUuids => 0x07,
CommonDataType::ShortenedLocalName => 0x08,
CommonDataType::CompleteLocalName => 0x09,
CommonDataType::TxPowerLevel => 0x0A,
CommonDataType::ClassOfDevice => 0x0D,
CommonDataType::SimplePairingHashC192 => 0x0E,
CommonDataType::SimplePairingRandomizerR192 => 0x0F,
// These two both really have type code 0x10! D:
CommonDataType::DeviceId => 0x10,
CommonDataType::SecurityManagerTkValue => 0x10,
CommonDataType::SecurityManagerOutOfBandFlags => 0x11,
CommonDataType::PeripheralConnectionIntervalRange => 0x12,
CommonDataType::ListOf16BitServiceSolicitationUuids => 0x14,
CommonDataType::ListOf128BitServiceSolicitationUuids => 0x15,
CommonDataType::ServiceData16BitUuid => 0x16,
CommonDataType::PublicTargetAddress => 0x17,
CommonDataType::RandomTargetAddress => 0x18,
CommonDataType::Appearance => 0x19,
CommonDataType::AdvertisingInterval => 0x1A,
CommonDataType::LeBluetoothDeviceAddress => 0x1B,
CommonDataType::LeRole => 0x1C,
CommonDataType::SimplePairingHashC256 => 0x1D,
CommonDataType::SimplePairingRandomizerR256 => 0x1E,
CommonDataType::ListOf32BitServiceSolicitationUuids => 0x1F,
CommonDataType::ServiceData32BitUuid => 0x20,
CommonDataType::ServiceData128BitUuid => 0x21,
CommonDataType::LeSecureConnectionsConfirmationValue => 0x22,
CommonDataType::LeSecureConnectionsRandomValue => 0x23,
CommonDataType::Uri => 0x24,
CommonDataType::IndoorPositioning => 0x25,
CommonDataType::TransportDiscoveryData => 0x26,
CommonDataType::LeSupportedFeatures => 0x27,
CommonDataType::ChannelMapUpdateIndication => 0x28,
CommonDataType::PbAdv => 0x29,
CommonDataType::MeshMessage => 0x2A,
CommonDataType::MeshBeacon => 0x2B,
CommonDataType::BigInfo => 0x2C,
CommonDataType::BroadcastCode => 0x2D,
CommonDataType::ResolvableSetIdentifier => 0x2E,
CommonDataType::AdvertisingIntervalLong => 0x2F,
CommonDataType::ThreeDInformationData => 0x3D,
CommonDataType::ManufacturerSpecificData => 0xFF,
};
Self(byte)
}
}
impl From<u8> for CommonDataTypeCode {
fn from(value: u8) -> Self {
Self(value)
}
}
impl From<CommonDataTypeCode> for u8 {
fn from(value: CommonDataTypeCode) -> Self {
value.0
}
}
/// Data types for assigned type codes.
///
/// See Bluetooth Assigned Numbers § 2.3
#[derive(Debug, Clone, Copy, PartialEq, Eq, strum_macros::EnumIter)]
#[allow(missing_docs)]
pub enum CommonDataType {
Flags,
IncompleteListOf16BitServiceClassUuids,
CompleteListOf16BitServiceClassUuids,
IncompleteListOf32BitServiceClassUuids,
CompleteListOf32BitServiceClassUuids,
IncompleteListOf128BitServiceClassUuids,
CompleteListOf128BitServiceClassUuids,
ShortenedLocalName,
CompleteLocalName,
TxPowerLevel,
ClassOfDevice,
SimplePairingHashC192,
SimplePairingRandomizerR192,
DeviceId,
SecurityManagerTkValue,
SecurityManagerOutOfBandFlags,
PeripheralConnectionIntervalRange,
ListOf16BitServiceSolicitationUuids,
ListOf128BitServiceSolicitationUuids,
ServiceData16BitUuid,
PublicTargetAddress,
RandomTargetAddress,
Appearance,
AdvertisingInterval,
LeBluetoothDeviceAddress,
LeRole,
SimplePairingHashC256,
SimplePairingRandomizerR256,
ListOf32BitServiceSolicitationUuids,
ServiceData32BitUuid,
ServiceData128BitUuid,
LeSecureConnectionsConfirmationValue,
LeSecureConnectionsRandomValue,
Uri,
IndoorPositioning,
TransportDiscoveryData,
LeSupportedFeatures,
ChannelMapUpdateIndication,
PbAdv,
MeshMessage,
MeshBeacon,
BigInfo,
BroadcastCode,
ResolvableSetIdentifier,
AdvertisingIntervalLong,
ThreeDInformationData,
ManufacturerSpecificData,
}
impl CommonDataType {
/// Iterate over the zero, one, or more matching types for the provided code.
///
/// `0x10` maps to both Device Id and Security Manager TK Value, so multiple matching types
/// may exist for a single code.
pub fn for_type_code(code: CommonDataTypeCode) -> impl Iterator<Item = CommonDataType> {
Self::iter().filter(move |t| CommonDataTypeCode::from(*t) == code)
}
/// Apply type-specific human-oriented formatting to data, if any is applicable
pub fn format_data(&self, data: &[u8]) -> Option<String> {
match self {
Self::Flags => Some(Flags::matching(data).map(|f| format!("{:?}", f)).join(",")),
Self::CompleteListOf16BitServiceClassUuids
| Self::IncompleteListOf16BitServiceClassUuids
| Self::ListOf16BitServiceSolicitationUuids => {
combinator::complete(multi::many0(Uuid16::parse_le))(data)
.map(|(_res, uuids)| {
uuids
.into_iter()
.map(|uuid| {
SERVICE_IDS
.get(&uuid)
.map(|name| format!("{:?} ({name})", uuid))
.unwrap_or_else(|| format!("{:?}", uuid))
})
.join(", ")
})
.ok()
}
Self::CompleteListOf32BitServiceClassUuids
| Self::IncompleteListOf32BitServiceClassUuids
| Self::ListOf32BitServiceSolicitationUuids => {
combinator::complete(multi::many0(Uuid32::parse))(data)
.map(|(_res, uuids)| uuids.into_iter().map(|u| format!("{:?}", u)).join(", "))
.ok()
}
Self::CompleteListOf128BitServiceClassUuids
| Self::IncompleteListOf128BitServiceClassUuids
| Self::ListOf128BitServiceSolicitationUuids => {
combinator::complete(multi::many0(Uuid128::parse_le))(data)
.map(|(_res, uuids)| uuids.into_iter().map(|u| format!("{:?}", u)).join(", "))
.ok()
}
Self::ServiceData16BitUuid => Uuid16::parse_le(data)
.map(|(rem, uuid)| {
format!(
"service={:?}, data={}",
SERVICE_IDS
.get(&uuid)
.map(|name| format!("{:?} ({name})", uuid))
.unwrap_or_else(|| format!("{:?}", uuid)),
hex::encode_upper(rem)
)
})
.ok(),
Self::ServiceData32BitUuid => Uuid32::parse(data)
.map(|(rem, uuid)| format!("service={:?}, data={}", uuid, hex::encode_upper(rem)))
.ok(),
Self::ServiceData128BitUuid => Uuid128::parse_le(data)
.map(|(rem, uuid)| format!("service={:?}, data={}", uuid, hex::encode_upper(rem)))
.ok(),
Self::ShortenedLocalName | Self::CompleteLocalName => {
std::str::from_utf8(data).ok().map(|s| format!("\"{}\"", s))
}
Self::TxPowerLevel => {
let (_, tx) =
combinator::complete(number::complete::i8::<_, nom::error::Error<_>>)(data)
.ok()?;
Some(tx.to_string())
}
Self::ManufacturerSpecificData => {
let (rem, id) = Uuid16::parse_le(data).ok()?;
Some(format!(
"company={}, data=0x{}",
COMPANY_IDS
.get(&id)
.map(|s| s.to_string())
.unwrap_or_else(|| format!("{:?}", id)),
hex::encode_upper(rem)
))
}
_ => None,
}
}
}
impl fmt::Display for CommonDataType {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
match self {
CommonDataType::Flags => write!(f, "Flags"),
CommonDataType::IncompleteListOf16BitServiceClassUuids => {
write!(f, "Incomplete List of 16-bit Service Class UUIDs")
}
CommonDataType::CompleteListOf16BitServiceClassUuids => {
write!(f, "Complete List of 16-bit Service Class UUIDs")
}
CommonDataType::IncompleteListOf32BitServiceClassUuids => {
write!(f, "Incomplete List of 32-bit Service Class UUIDs")
}
CommonDataType::CompleteListOf32BitServiceClassUuids => {
write!(f, "Complete List of 32-bit Service Class UUIDs")
}
CommonDataType::ListOf16BitServiceSolicitationUuids => {
write!(f, "List of 16-bit Service Solicitation UUIDs")
}
CommonDataType::ListOf32BitServiceSolicitationUuids => {
write!(f, "List of 32-bit Service Solicitation UUIDs")
}
CommonDataType::ListOf128BitServiceSolicitationUuids => {
write!(f, "List of 128-bit Service Solicitation UUIDs")
}
CommonDataType::IncompleteListOf128BitServiceClassUuids => {
write!(f, "Incomplete List of 128-bit Service Class UUIDs")
}
CommonDataType::CompleteListOf128BitServiceClassUuids => {
write!(f, "Complete List of 128-bit Service Class UUIDs")
}
CommonDataType::ShortenedLocalName => write!(f, "Shortened Local Name"),
CommonDataType::CompleteLocalName => write!(f, "Complete Local Name"),
CommonDataType::TxPowerLevel => write!(f, "TX Power Level"),
CommonDataType::ClassOfDevice => write!(f, "Class of Device"),
CommonDataType::SimplePairingHashC192 => {
write!(f, "Simple Pairing Hash C-192")
}
CommonDataType::SimplePairingHashC256 => {
write!(f, "Simple Pairing Hash C 256")
}
CommonDataType::SimplePairingRandomizerR192 => {
write!(f, "Simple Pairing Randomizer R-192")
}
CommonDataType::SimplePairingRandomizerR256 => {
write!(f, "Simple Pairing Randomizer R 256")
}
CommonDataType::DeviceId => write!(f, "Device Id"),
CommonDataType::SecurityManagerTkValue => {
write!(f, "Security Manager TK Value")
}
CommonDataType::SecurityManagerOutOfBandFlags => {
write!(f, "Security Manager Out of Band Flags")
}
CommonDataType::PeripheralConnectionIntervalRange => {
write!(f, "Peripheral Connection Interval Range")
}
CommonDataType::ServiceData16BitUuid => {
write!(f, "Service Data 16-bit UUID")
}
CommonDataType::ServiceData32BitUuid => {
write!(f, "Service Data 32-bit UUID")
}
CommonDataType::ServiceData128BitUuid => {
write!(f, "Service Data 128-bit UUID")
}
CommonDataType::PublicTargetAddress => write!(f, "Public Target Address"),
CommonDataType::RandomTargetAddress => write!(f, "Random Target Address"),
CommonDataType::Appearance => write!(f, "Appearance"),
CommonDataType::AdvertisingInterval => write!(f, "Advertising Interval"),
CommonDataType::LeBluetoothDeviceAddress => {
write!(f, "LE Bluetooth Device Address")
}
CommonDataType::LeRole => write!(f, "LE Role"),
CommonDataType::LeSecureConnectionsConfirmationValue => {
write!(f, "LE Secure Connections Confirmation Value")
}
CommonDataType::LeSecureConnectionsRandomValue => {
write!(f, "LE Secure Connections Random Value")
}
CommonDataType::LeSupportedFeatures => write!(f, "LE Supported Features"),
CommonDataType::Uri => write!(f, "URI"),
CommonDataType::IndoorPositioning => write!(f, "Indoor Positioning"),
CommonDataType::TransportDiscoveryData => {
write!(f, "Transport Discovery Data")
}
CommonDataType::ChannelMapUpdateIndication => {
write!(f, "Channel Map Update Indication")
}
CommonDataType::PbAdv => write!(f, "PB-ADV"),
CommonDataType::MeshMessage => write!(f, "Mesh Message"),
CommonDataType::MeshBeacon => write!(f, "Mesh Beacon"),
CommonDataType::BigInfo => write!(f, "BIGIInfo"),
CommonDataType::BroadcastCode => write!(f, "Broadcast Code"),
CommonDataType::ResolvableSetIdentifier => {
write!(f, "Resolvable Set Identifier")
}
CommonDataType::AdvertisingIntervalLong => {
write!(f, "Advertising Interval Long")
}
CommonDataType::ThreeDInformationData => write!(f, "3D Information Data"),
CommonDataType::ManufacturerSpecificData => {
write!(f, "Manufacturer Specific Data")
}
}
}
}
/// Accumulates advertisement data to broadcast on a [crate::wrapper::device::Device].
#[derive(Debug, Clone, Default)]
pub struct AdvertisementDataBuilder {
encoded_data: Vec<u8>,
}
impl AdvertisementDataBuilder {
/// Returns a new, empty instance.
pub fn new() -> Self {
Self {
encoded_data: Vec::new(),
}
}
/// Append advertising data to the builder.
///
/// Returns an error if the data cannot be appended.
pub fn append(
&mut self,
type_code: impl Into<CommonDataTypeCode>,
data: &[u8],
) -> Result<(), AdvertisementDataBuilderError> {
self.encoded_data.push(
data.len()
.try_into()
.ok()
.and_then(|len: u8| len.checked_add(1))
.ok_or(AdvertisementDataBuilderError::DataTooLong)?,
);
self.encoded_data.push(type_code.into().0);
self.encoded_data.extend_from_slice(data);
Ok(())
}
pub(crate) fn into_bytes(self) -> Vec<u8> {
self.encoded_data
}
}
/// Errors that can occur when building advertisement data with [AdvertisementDataBuilder].
#[derive(Debug, PartialEq, Eq, thiserror::Error)]
pub enum AdvertisementDataBuilderError {
/// The provided adv data is too long to be encoded
#[error("Data too long")]
DataTooLong,
}
#[derive(PartialEq, Eq, strum_macros::EnumIter)]
#[allow(missing_docs)]
/// Features in the Flags AD
pub enum Flags {
LeLimited,
LeDiscoverable,
NoBrEdr,
BrEdrController,
BrEdrHost,
}
impl fmt::Debug for Flags {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
write!(f, "{}", self.short_name())
}
}
impl Flags {
/// Iterates over the flags that are present in the provided `flags` bytes.
pub fn matching(flags: &[u8]) -> impl Iterator<Item = Self> + '_ {
// The encoding is not clear from the spec: do we look at the first byte? or the last?
// In practice it's only one byte.
let first_byte = flags.first().unwrap_or(&0_u8);
Self::iter().filter(move |f| {
let mask = match f {
Flags::LeLimited => 0x01_u8,
Flags::LeDiscoverable => 0x02,
Flags::NoBrEdr => 0x04,
Flags::BrEdrController => 0x08,
Flags::BrEdrHost => 0x10,
};
mask & first_byte > 0
})
}
/// An abbreviated form of the flag name.
///
/// See [Flags::name] for the full name.
pub fn short_name(&self) -> &'static str {
match self {
Flags::LeLimited => "LE Limited",
Flags::LeDiscoverable => "LE General",
Flags::NoBrEdr => "No BR/EDR",
Flags::BrEdrController => "BR/EDR C",
Flags::BrEdrHost => "BR/EDR H",
}
}
/// The human-readable name of the flag.
///
/// See [Flags::short_name] for a shorter string for use if compactness is important.
pub fn name(&self) -> &'static str {
match self {
Flags::LeLimited => "LE Limited Discoverable Mode",
Flags::LeDiscoverable => "LE General Discoverable Mode",
Flags::NoBrEdr => "BR/EDR Not Supported",
Flags::BrEdrController => "Simultaneous LE and BR/EDR (Controller)",
Flags::BrEdrHost => "Simultaneous LE and BR/EDR (Host)",
}
}
}

31
rust/src/lib.rs Normal file
View File

@@ -0,0 +1,31 @@
// Copyright 2023 Google LLC
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//! Rust API for [Bumble](https://github.com/google/bumble).
//!
//! Bumble is a userspace Bluetooth stack that works with more or less anything that uses HCI. This
//! could be physical Bluetooth USB dongles, netsim, HCI proxied over a network from some device
//! elsewhere, etc.
//!
//! It also does not restrict what you can do with Bluetooth the way that OS Bluetooth APIs
//! typically do, making it good for prototyping, experimentation, test tools, etc.
//!
//! Bumble is primarily written in Python. Rust types that wrap the Python API, which is currently
//! the bulk of the code, are in the [wrapper] module.
#![deny(missing_docs, unsafe_code)]
pub mod wrapper;
pub mod adv;

File diff suppressed because it is too large Load Diff

View File

@@ -0,0 +1,21 @@
// Copyright 2023 Google LLC
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//! Assigned numbers from the Bluetooth spec.
mod company_ids;
mod services;
pub use company_ids::COMPANY_IDS;
pub use services::SERVICE_IDS;

View File

@@ -0,0 +1,82 @@
// Copyright 2023 Google LLC
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//! Assigned service IDs
use crate::wrapper::core::Uuid16;
use lazy_static::lazy_static;
use std::collections;
lazy_static! {
/// Assigned service IDs
pub static ref SERVICE_IDS: collections::HashMap<Uuid16, &'static str> = [
(0x1800_u16, "Generic Access"),
(0x1801, "Generic Attribute"),
(0x1802, "Immediate Alert"),
(0x1803, "Link Loss"),
(0x1804, "TX Power"),
(0x1805, "Current Time"),
(0x1806, "Reference Time Update"),
(0x1807, "Next DST Change"),
(0x1808, "Glucose"),
(0x1809, "Health Thermometer"),
(0x180A, "Device Information"),
(0x180D, "Heart Rate"),
(0x180E, "Phone Alert Status"),
(0x180F, "Battery"),
(0x1810, "Blood Pressure"),
(0x1811, "Alert Notification"),
(0x1812, "Human Interface Device"),
(0x1813, "Scan Parameters"),
(0x1814, "Running Speed and Cadence"),
(0x1815, "Automation IO"),
(0x1816, "Cycling Speed and Cadence"),
(0x1818, "Cycling Power"),
(0x1819, "Location and Navigation"),
(0x181A, "Environmental Sensing"),
(0x181B, "Body Composition"),
(0x181C, "User Data"),
(0x181D, "Weight Scale"),
(0x181E, "Bond Management"),
(0x181F, "Continuous Glucose Monitoring"),
(0x1820, "Internet Protocol Support"),
(0x1821, "Indoor Positioning"),
(0x1822, "Pulse Oximeter"),
(0x1823, "HTTP Proxy"),
(0x1824, "Transport Discovery"),
(0x1825, "Object Transfer"),
(0x1826, "Fitness Machine"),
(0x1827, "Mesh Provisioning"),
(0x1828, "Mesh Proxy"),
(0x1829, "Reconnection Configuration"),
(0x183A, "Insulin Delivery"),
(0x183B, "Binary Sensor"),
(0x183C, "Emergency Configuration"),
(0x183E, "Physical Activity Monitor"),
(0x1843, "Audio Input Control"),
(0x1844, "Volume Control"),
(0x1845, "Volume Offset Control"),
(0x1846, "Coordinated Set Identification Service"),
(0x1847, "Device Time"),
(0x1848, "Media Control Service"),
(0x1849, "Generic Media Control Service"),
(0x184A, "Constant Tone Extension"),
(0x184B, "Telephone Bearer Service"),
(0x184C, "Generic Telephone Bearer Service"),
(0x184D, "Microphone Control"),
]
.into_iter()
.map(|(num, name)| (Uuid16::from_le_bytes(num.to_le_bytes()), name))
.collect();
}

196
rust/src/wrapper/core.rs Normal file
View File

@@ -0,0 +1,196 @@
// Copyright 2023 Google LLC
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//! Core types
use crate::adv::CommonDataTypeCode;
use lazy_static::lazy_static;
use nom::{bytes, combinator};
use pyo3::{intern, PyObject, PyResult, Python};
use std::fmt;
lazy_static! {
static ref BASE_UUID: [u8; 16] = hex::decode("0000000000001000800000805F9B34FB")
.unwrap()
.try_into()
.unwrap();
}
/// A type code and data pair from an advertisement
pub type AdvertisementDataUnit = (CommonDataTypeCode, Vec<u8>);
/// Contents of an advertisement
pub struct AdvertisingData(pub(crate) PyObject);
impl AdvertisingData {
/// Data units in the advertisement contents
pub fn data_units(&self) -> PyResult<Vec<AdvertisementDataUnit>> {
Python::with_gil(|py| {
let list = self.0.getattr(py, intern!(py, "ad_structures"))?;
list.as_ref(py)
.iter()?
.collect::<Result<Vec<_>, _>>()?
.into_iter()
.map(|tuple| {
let type_code = tuple
.call_method1(intern!(py, "__getitem__"), (0,))?
.extract::<u8>()?
.into();
let data = tuple
.call_method1(intern!(py, "__getitem__"), (1,))?
.extract::<Vec<u8>>()?;
Ok((type_code, data))
})
.collect::<Result<Vec<_>, _>>()
})
}
}
/// 16-bit UUID
#[derive(PartialEq, Eq, Hash, Clone, Copy)]
pub struct Uuid16 {
/// Big-endian bytes
uuid: [u8; 2],
}
impl Uuid16 {
/// Construct a UUID from little-endian bytes
pub fn from_le_bytes(mut bytes: [u8; 2]) -> Self {
bytes.reverse();
Self::from_be_bytes(bytes)
}
/// Construct a UUID from big-endian bytes
pub fn from_be_bytes(bytes: [u8; 2]) -> Self {
Self { uuid: bytes }
}
/// The UUID in big-endian bytes form
pub fn as_be_bytes(&self) -> [u8; 2] {
self.uuid
}
/// The UUID in little-endian bytes form
pub fn as_le_bytes(&self) -> [u8; 2] {
let mut uuid = self.uuid;
uuid.reverse();
uuid
}
pub(crate) fn parse_le(input: &[u8]) -> nom::IResult<&[u8], Self> {
combinator::map_res(bytes::complete::take(2_usize), |b: &[u8]| {
b.try_into().map(|mut uuid: [u8; 2]| {
uuid.reverse();
Self { uuid }
})
})(input)
}
}
impl fmt::Debug for Uuid16 {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
write!(f, "UUID-16:{}", hex::encode_upper(self.uuid))
}
}
/// 32-bit UUID
#[derive(PartialEq, Eq, Hash)]
pub struct Uuid32 {
/// Big-endian bytes
uuid: [u8; 4],
}
impl Uuid32 {
/// The UUID in big-endian bytes form
pub fn as_bytes(&self) -> [u8; 4] {
self.uuid
}
pub(crate) fn parse(input: &[u8]) -> nom::IResult<&[u8], Self> {
combinator::map_res(bytes::complete::take(4_usize), |b: &[u8]| {
b.try_into().map(|mut uuid: [u8; 4]| {
uuid.reverse();
Self { uuid }
})
})(input)
}
}
impl fmt::Debug for Uuid32 {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
write!(f, "UUID-32:{}", hex::encode_upper(self.uuid))
}
}
impl From<Uuid16> for Uuid32 {
fn from(value: Uuid16) -> Self {
let mut uuid = [0; 4];
uuid[2..].copy_from_slice(&value.uuid);
Self { uuid }
}
}
/// 128-bit UUID
#[derive(PartialEq, Eq, Hash)]
pub struct Uuid128 {
/// Big-endian bytes
uuid: [u8; 16],
}
impl Uuid128 {
/// The UUID in big-endian bytes form
pub fn as_bytes(&self) -> [u8; 16] {
self.uuid
}
pub(crate) fn parse_le(input: &[u8]) -> nom::IResult<&[u8], Self> {
combinator::map_res(bytes::complete::take(16_usize), |b: &[u8]| {
b.try_into().map(|mut uuid: [u8; 16]| {
uuid.reverse();
Self { uuid }
})
})(input)
}
}
impl fmt::Debug for Uuid128 {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
write!(
f,
"{}-{}-{}-{}-{}",
hex::encode_upper(&self.uuid[..4]),
hex::encode_upper(&self.uuid[4..6]),
hex::encode_upper(&self.uuid[6..8]),
hex::encode_upper(&self.uuid[8..10]),
hex::encode_upper(&self.uuid[10..])
)
}
}
impl From<Uuid16> for Uuid128 {
fn from(value: Uuid16) -> Self {
let mut uuid = *BASE_UUID;
uuid[2..4].copy_from_slice(&value.uuid);
Self { uuid }
}
}
impl From<Uuid32> for Uuid128 {
fn from(value: Uuid32) -> Self {
let mut uuid = *BASE_UUID;
uuid[..4].copy_from_slice(&value.uuid);
Self { uuid }
}
}

248
rust/src/wrapper/device.rs Normal file
View File

@@ -0,0 +1,248 @@
// Copyright 2023 Google LLC
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//! Devices and connections to them
use crate::{
adv::AdvertisementDataBuilder,
wrapper::{
core::AdvertisingData,
gatt_client::{ProfileServiceProxy, ServiceProxy},
hci::Address,
transport::{Sink, Source},
ClosureCallback,
},
};
use pyo3::types::PyDict;
use pyo3::{intern, types::PyModule, PyObject, PyResult, Python, ToPyObject};
use std::path;
/// A device that can send/receive HCI frames.
#[derive(Clone)]
pub struct Device(PyObject);
impl Device {
/// Create a Device per the provided file configured to communicate with a controller through an HCI source/sink
pub fn from_config_file_with_hci(
device_config: &path::Path,
source: Source,
sink: Sink,
) -> PyResult<Self> {
Python::with_gil(|py| {
PyModule::import(py, intern!(py, "bumble.device"))?
.getattr(intern!(py, "Device"))?
.call_method1(
intern!(py, "from_config_file_with_hci"),
(device_config, source.0, sink.0),
)
.map(|any| Self(any.into()))
})
}
/// Create a Device configured to communicate with a controller through an HCI source/sink
pub fn with_hci(name: &str, address: &str, source: Source, sink: Sink) -> PyResult<Self> {
Python::with_gil(|py| {
PyModule::import(py, intern!(py, "bumble.device"))?
.getattr(intern!(py, "Device"))?
.call_method1(intern!(py, "with_hci"), (name, address, source.0, sink.0))
.map(|any| Self(any.into()))
})
}
/// Turn the device on
pub async fn power_on(&self) -> PyResult<()> {
Python::with_gil(|py| {
self.0
.call_method0(py, intern!(py, "power_on"))
.and_then(|coroutine| pyo3_asyncio::tokio::into_future(coroutine.as_ref(py)))
})?
.await
.map(|_| ())
}
/// Connect to a peer
pub async fn connect(&self, peer_addr: &str) -> PyResult<Connection> {
Python::with_gil(|py| {
self.0
.call_method1(py, intern!(py, "connect"), (peer_addr,))
.and_then(|coroutine| pyo3_asyncio::tokio::into_future(coroutine.as_ref(py)))
})?
.await
.map(Connection)
}
/// Start scanning
pub async fn start_scanning(&self, filter_duplicates: bool) -> PyResult<()> {
Python::with_gil(|py| {
let kwargs = PyDict::new(py);
kwargs.set_item("filter_duplicates", filter_duplicates)?;
self.0
.call_method(py, intern!(py, "start_scanning"), (), Some(kwargs))
.and_then(|coroutine| pyo3_asyncio::tokio::into_future(coroutine.as_ref(py)))
})?
.await
.map(|_| ())
}
/// Register a callback to be called for each advertisement
pub fn on_advertisement(
&mut self,
callback: impl Fn(Python, Advertisement) -> PyResult<()> + Send + 'static,
) -> PyResult<()> {
let boxed = ClosureCallback::new(move |py, args, _kwargs| {
callback(py, Advertisement(args.get_item(0)?.into()))
});
Python::with_gil(|py| {
self.0
.call_method1(py, intern!(py, "add_listener"), ("advertisement", boxed))
})
.map(|_| ())
}
/// Set the advertisement data to be used when [Device::start_advertising] is called.
pub fn set_advertising_data(&mut self, adv_data: AdvertisementDataBuilder) -> PyResult<()> {
Python::with_gil(|py| {
self.0.setattr(
py,
intern!(py, "advertising_data"),
adv_data.into_bytes().as_slice(),
)
})
.map(|_| ())
}
/// Start advertising the data set with [Device.set_advertisement].
pub async fn start_advertising(&mut self, auto_restart: bool) -> PyResult<()> {
Python::with_gil(|py| {
let kwargs = PyDict::new(py);
kwargs.set_item("auto_restart", auto_restart)?;
self.0
.call_method(py, intern!(py, "start_advertising"), (), Some(kwargs))
.and_then(|coroutine| pyo3_asyncio::tokio::into_future(coroutine.as_ref(py)))
})?
.await
.map(|_| ())
}
/// Stop advertising.
pub async fn stop_advertising(&mut self) -> PyResult<()> {
Python::with_gil(|py| {
self.0
.call_method0(py, intern!(py, "stop_advertising"))
.and_then(|coroutine| pyo3_asyncio::tokio::into_future(coroutine.as_ref(py)))
})?
.await
.map(|_| ())
}
}
/// A connection to a remote device.
pub struct Connection(PyObject);
/// The other end of a connection
pub struct Peer(PyObject);
impl Peer {
/// Wrap a [Connection] in a Peer
pub fn new(conn: Connection) -> PyResult<Self> {
Python::with_gil(|py| {
PyModule::import(py, intern!(py, "bumble.device"))?
.getattr(intern!(py, "Peer"))?
.call1((conn.0,))
.map(|obj| Self(obj.into()))
})
}
/// Populates the peer's cache of services.
///
/// Returns the discovered services.
pub async fn discover_services(&mut self) -> PyResult<Vec<ServiceProxy>> {
Python::with_gil(|py| {
self.0
.call_method0(py, intern!(py, "discover_services"))
.and_then(|coroutine| pyo3_asyncio::tokio::into_future(coroutine.as_ref(py)))
})?
.await
.and_then(|list| {
Python::with_gil(|py| {
list.as_ref(py)
.iter()?
.map(|r| r.map(|h| ServiceProxy(h.to_object(py))))
.collect()
})
})
}
/// Returns a snapshot of the Services currently in the peer's cache
pub fn services(&self) -> PyResult<Vec<ServiceProxy>> {
Python::with_gil(|py| {
self.0
.getattr(py, intern!(py, "services"))?
.as_ref(py)
.iter()?
.map(|r| r.map(|h| ServiceProxy(h.to_object(py))))
.collect()
})
}
/// Build a [ProfileServiceProxy] for the specified type.
/// [Peer::discover_services] or some other means of populating the Peer's service cache must be
/// called first, or the required service won't be found.
pub fn create_service_proxy<P: ProfileServiceProxy>(&self) -> PyResult<Option<P>> {
Python::with_gil(|py| {
let module = py.import(P::PROXY_CLASS_MODULE)?;
let class = module.getattr(P::PROXY_CLASS_NAME)?;
self.0
.call_method1(py, intern!(py, "create_service_proxy"), (class,))
.map(|obj| {
if obj.is_none(py) {
None
} else {
Some(P::wrap(obj))
}
})
})
}
}
/// A BLE advertisement
pub struct Advertisement(PyObject);
impl Advertisement {
/// Address that sent the advertisement
pub fn address(&self) -> PyResult<Address> {
Python::with_gil(|py| self.0.getattr(py, intern!(py, "address")).map(Address))
}
/// Returns true if the advertisement is connectable
pub fn is_connectable(&self) -> PyResult<bool> {
Python::with_gil(|py| {
self.0
.getattr(py, intern!(py, "is_connectable"))?
.extract::<bool>(py)
})
}
/// RSSI of the advertisement
pub fn rssi(&self) -> PyResult<i8> {
Python::with_gil(|py| self.0.getattr(py, intern!(py, "rssi"))?.extract::<i8>(py))
}
/// Data in the advertisement
pub fn data(&self) -> PyResult<AdvertisingData> {
Python::with_gil(|py| self.0.getattr(py, intern!(py, "data")).map(AdvertisingData))
}
}

View File

@@ -0,0 +1,79 @@
// Copyright 2023 Google LLC
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//! GATT client support
use crate::wrapper::ClosureCallback;
use pyo3::types::PyTuple;
use pyo3::{intern, PyObject, PyResult, Python};
/// A GATT service on a remote device
pub struct ServiceProxy(pub(crate) PyObject);
impl ServiceProxy {
/// Discover the characteristics in this service.
///
/// Populates an internal cache of characteristics in this service.
pub async fn discover_characteristics(&mut self) -> PyResult<()> {
Python::with_gil(|py| {
self.0
.call_method0(py, intern!(py, "discover_characteristics"))
.and_then(|coroutine| pyo3_asyncio::tokio::into_future(coroutine.as_ref(py)))
})?
.await
.map(|_| ())
}
}
/// A GATT characteristic on a remote device
pub struct CharacteristicProxy(pub(crate) PyObject);
impl CharacteristicProxy {
/// Subscribe to changes to the characteristic, executing `callback` for each new value
pub async fn subscribe(
&mut self,
callback: impl Fn(Python, &PyTuple) -> PyResult<()> + Send + 'static,
) -> PyResult<()> {
let boxed = ClosureCallback::new(move |py, args, _kwargs| callback(py, args));
Python::with_gil(|py| {
self.0
.call_method1(py, intern!(py, "subscribe"), (boxed,))
.and_then(|obj| pyo3_asyncio::tokio::into_future(obj.as_ref(py)))
})?
.await
.map(|_| ())
}
/// Read the current value of the characteristic
pub async fn read_value(&self) -> PyResult<PyObject> {
Python::with_gil(|py| {
self.0
.call_method0(py, intern!(py, "read_value"))
.and_then(|obj| pyo3_asyncio::tokio::into_future(obj.as_ref(py)))
})?
.await
}
}
/// Equivalent to the Python `ProfileServiceProxy`.
pub trait ProfileServiceProxy {
/// The module containing the proxy class
const PROXY_CLASS_MODULE: &'static str;
/// The module class name
const PROXY_CLASS_NAME: &'static str;
/// Wrap a PyObject in the Rust wrapper type
fn wrap(obj: PyObject) -> Self;
}

112
rust/src/wrapper/hci.rs Normal file
View File

@@ -0,0 +1,112 @@
// Copyright 2023 Google LLC
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//! HCI
use itertools::Itertools as _;
use pyo3::{exceptions::PyException, intern, types::PyModule, PyErr, PyObject, PyResult, Python};
/// A Bluetooth address
pub struct Address(pub(crate) PyObject);
impl Address {
/// The type of address
pub fn address_type(&self) -> PyResult<AddressType> {
Python::with_gil(|py| {
let addr_type = self
.0
.getattr(py, intern!(py, "address_type"))?
.extract::<u32>(py)?;
let module = PyModule::import(py, intern!(py, "bumble.hci"))?;
let klass = module.getattr(intern!(py, "Address"))?;
if addr_type
== klass
.getattr(intern!(py, "PUBLIC_DEVICE_ADDRESS"))?
.extract::<u32>()?
{
Ok(AddressType::PublicDevice)
} else if addr_type
== klass
.getattr(intern!(py, "RANDOM_DEVICE_ADDRESS"))?
.extract::<u32>()?
{
Ok(AddressType::RandomDevice)
} else if addr_type
== klass
.getattr(intern!(py, "PUBLIC_IDENTITY_ADDRESS"))?
.extract::<u32>()?
{
Ok(AddressType::PublicIdentity)
} else if addr_type
== klass
.getattr(intern!(py, "RANDOM_IDENTITY_ADDRESS"))?
.extract::<u32>()?
{
Ok(AddressType::RandomIdentity)
} else {
Err(PyErr::new::<PyException, _>("Invalid address type"))
}
})
}
/// True if the address is static
pub fn is_static(&self) -> PyResult<bool> {
Python::with_gil(|py| {
self.0
.getattr(py, intern!(py, "is_static"))?
.extract::<bool>(py)
})
}
/// True if the address is resolvable
pub fn is_resolvable(&self) -> PyResult<bool> {
Python::with_gil(|py| {
self.0
.getattr(py, intern!(py, "is_resolvable"))?
.extract::<bool>(py)
})
}
/// Address bytes in _little-endian_ format
pub fn as_le_bytes(&self) -> PyResult<Vec<u8>> {
Python::with_gil(|py| {
self.0
.call_method0(py, intern!(py, "to_bytes"))?
.extract::<Vec<u8>>(py)
})
}
/// Address bytes as big-endian colon-separated hex
pub fn as_hex(&self) -> PyResult<String> {
self.as_le_bytes().map(|bytes| {
bytes
.into_iter()
.rev()
.map(|byte| hex::encode_upper([byte]))
.join(":")
})
}
}
/// BT address types
#[allow(missing_docs)]
#[derive(PartialEq, Eq, Debug)]
pub enum AddressType {
PublicDevice,
RandomDevice,
PublicIdentity,
RandomIdentity,
}

View File

@@ -0,0 +1,27 @@
//! Bumble & Python logging
use pyo3::types::PyDict;
use pyo3::{intern, types::PyModule, PyResult, Python};
use std::env;
/// Returns the uppercased contents of the `BUMBLE_LOGLEVEL` env var, or `default` if it is not present or not UTF-8.
///
/// The result could be passed to [py_logging_basic_config] to configure Python's logging
/// accordingly.
pub fn bumble_env_logging_level(default: impl Into<String>) -> String {
env::var("BUMBLE_LOGLEVEL")
.unwrap_or_else(|_| default.into())
.to_ascii_uppercase()
}
/// Call `logging.basicConfig` with the provided logging level
pub fn py_logging_basic_config(log_level: impl Into<String>) -> PyResult<()> {
Python::with_gil(|py| {
let kwargs = PyDict::new(py);
kwargs.set_item("level", log_level.into())?;
PyModule::import(py, intern!(py, "logging"))?
.call_method(intern!(py, "basicConfig"), (), Some(kwargs))
.map(|_| ())
})
}

92
rust/src/wrapper/mod.rs Normal file
View File

@@ -0,0 +1,92 @@
// Copyright 2023 Google LLC
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//! Types that wrap the Python API.
//!
//! Because mutability, aliasing, etc is all hidden behind Python, the normal Rust rules about
//! only one mutable reference to one piece of memory, etc, may not hold since using `&mut self`
//! instead of `&self` is only guided by inspection of the Python source, not the compiler.
//!
//! The modules are generally structured to mirror the Python equivalents.
// Re-exported to make it easy for users to depend on the same `PyObject`, etc
pub use pyo3;
use pyo3::{
prelude::*,
types::{PyDict, PyTuple},
};
pub use pyo3_asyncio;
pub mod assigned_numbers;
pub mod core;
pub mod device;
pub mod gatt_client;
pub mod hci;
pub mod logging;
pub mod profile;
pub mod transport;
/// Convenience extensions to [PyObject]
pub trait PyObjectExt {
/// Get a GIL-bound reference
fn gil_ref<'py>(&'py self, py: Python<'py>) -> &'py PyAny;
/// Extract any [FromPyObject] implementation from this value
fn extract_with_gil<T>(&self) -> PyResult<T>
where
T: for<'a> FromPyObject<'a>,
{
Python::with_gil(|py| self.gil_ref(py).extract::<T>())
}
}
impl PyObjectExt for PyObject {
fn gil_ref<'py>(&'py self, py: Python<'py>) -> &'py PyAny {
self.as_ref(py)
}
}
/// Wrapper to make Rust closures ([Fn] implementations) callable from Python.
///
/// The Python callable form returns a Python `None`.
#[pyclass(name = "SubscribeCallback")]
pub(crate) struct ClosureCallback {
// can't use generics in a pyclass, so have to box
#[allow(clippy::type_complexity)]
callback: Box<dyn Fn(Python, &PyTuple, Option<&PyDict>) -> PyResult<()> + Send + 'static>,
}
impl ClosureCallback {
/// Create a new callback around the provided closure
pub fn new(
callback: impl Fn(Python, &PyTuple, Option<&PyDict>) -> PyResult<()> + Send + 'static,
) -> Self {
Self {
callback: Box::new(callback),
}
}
}
#[pymethods]
impl ClosureCallback {
#[pyo3(signature = (*args, **kwargs))]
fn __call__(
&self,
py: Python<'_>,
args: &PyTuple,
kwargs: Option<&PyDict>,
) -> PyResult<Py<PyAny>> {
(self.callback)(py, args, kwargs).map(|_| py.None())
}
}

View File

@@ -0,0 +1,47 @@
// Copyright 2023 Google LLC
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//! GATT profiles
use crate::wrapper::gatt_client::{CharacteristicProxy, ProfileServiceProxy};
use pyo3::{intern, PyObject, PyResult, Python};
/// Exposes the battery GATT service
pub struct BatteryServiceProxy(PyObject);
impl BatteryServiceProxy {
/// Get the battery level, if available
pub fn battery_level(&self) -> PyResult<Option<CharacteristicProxy>> {
Python::with_gil(|py| {
self.0
.getattr(py, intern!(py, "battery_level"))
.map(|level| {
if level.is_none(py) {
None
} else {
Some(CharacteristicProxy(level))
}
})
})
}
}
impl ProfileServiceProxy for BatteryServiceProxy {
const PROXY_CLASS_MODULE: &'static str = "bumble.profiles.battery_service";
const PROXY_CLASS_NAME: &'static str = "BatteryServiceProxy";
fn wrap(obj: PyObject) -> Self {
Self(obj)
}
}

View File

@@ -0,0 +1,72 @@
// Copyright 2023 Google LLC
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//! HCI packet transport
use pyo3::{intern, types::PyModule, PyObject, PyResult, Python};
/// A source/sink pair for HCI packet I/O.
///
/// See <https://google.github.io/bumble/transports/index.html>.
pub struct Transport(PyObject);
impl Transport {
/// Open a new Transport for the provided spec, e.g. `"usb:0"` or `"android-netsim"`.
pub async fn open(transport_spec: impl Into<String>) -> PyResult<Self> {
Python::with_gil(|py| {
PyModule::import(py, intern!(py, "bumble.transport"))?
.call_method1(intern!(py, "open_transport"), (transport_spec.into(),))
.and_then(pyo3_asyncio::tokio::into_future)
})?
.await
.map(Self)
}
/// Close the transport.
pub async fn close(&mut self) -> PyResult<()> {
Python::with_gil(|py| {
self.0
.call_method0(py, intern!(py, "close"))
.and_then(|coroutine| pyo3_asyncio::tokio::into_future(coroutine.as_ref(py)))
})?
.await
.map(|_| ())
}
/// Returns the source half of the transport.
pub fn source(&self) -> PyResult<Source> {
Python::with_gil(|py| self.0.getattr(py, intern!(py, "source"))).map(Source)
}
/// Returns the sink half of the transport.
pub fn sink(&self) -> PyResult<Sink> {
Python::with_gil(|py| self.0.getattr(py, intern!(py, "sink"))).map(Sink)
}
}
impl Drop for Transport {
fn drop(&mut self) {
// can't await in a Drop impl, but we can at least spawn a task to do it
let obj = self.0.clone();
tokio::spawn(async move { Self(obj).close().await });
}
}
/// The source side of a [Transport].
#[derive(Clone)]
pub struct Source(pub(crate) PyObject);
/// The sink side of a [Transport].
#[derive(Clone)]
pub struct Sink(pub(crate) PyObject);

View File

@@ -0,0 +1,97 @@
// Copyright 2023 Google LLC
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//! This tool generates Rust code with assigned number tables from the equivalent Python.
use pyo3::{
intern,
types::{PyDict, PyModule},
PyResult, Python,
};
use std::{collections, env, fs, path};
fn main() -> anyhow::Result<()> {
pyo3::prepare_freethreaded_python();
let mut dir = path::Path::new(&env::var("CARGO_MANIFEST_DIR")?).to_path_buf();
dir.push("src/wrapper/assigned_numbers");
company_ids(&dir)?;
Ok(())
}
fn company_ids(base_dir: &path::Path) -> anyhow::Result<()> {
let mut sorted_ids = load_company_ids()?.into_iter().collect::<Vec<_>>();
sorted_ids.sort_by_key(|(id, _name)| *id);
let mut contents = String::new();
contents.push_str(LICENSE_HEADER);
contents.push_str("\n\n");
contents.push_str(
"// auto-generated by gen_assigned_numbers, do not edit
use crate::wrapper::core::Uuid16;
use lazy_static::lazy_static;
use std::collections;
lazy_static! {
/// Assigned company IDs
pub static ref COMPANY_IDS: collections::HashMap<Uuid16, &'static str> = [
",
);
for (id, name) in sorted_ids {
contents.push_str(&format!(" ({id}_u16, r#\"{name}\"#),\n"))
}
contents.push_str(
" ]
.into_iter()
.map(|(id, name)| (Uuid16::from_be_bytes(id.to_be_bytes()), name))
.collect();
}
",
);
let mut company_ids = base_dir.to_path_buf();
company_ids.push("company_ids.rs");
fs::write(&company_ids, contents)?;
Ok(())
}
fn load_company_ids() -> PyResult<collections::HashMap<u16, String>> {
Python::with_gil(|py| {
PyModule::import(py, intern!(py, "bumble.company_ids"))?
.getattr(intern!(py, "COMPANY_IDENTIFIERS"))?
.downcast::<PyDict>()?
.into_iter()
.map(|(k, v)| Ok((k.extract::<u16>()?, v.str()?.to_str()?.to_string())))
.collect::<PyResult<collections::HashMap<_, _>>>()
})
}
const LICENSE_HEADER: &str = r#"// Copyright 2023 Google LLC
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License."#;

View File

@@ -46,6 +46,7 @@ from bumble.hci import (
HCI_LE_Set_Advertising_Parameters_Command,
HCI_LE_Set_Default_PHY_Command,
HCI_LE_Set_Event_Mask_Command,
HCI_LE_Set_Extended_Advertising_Enable_Command,
HCI_LE_Set_Extended_Scan_Parameters_Command,
HCI_LE_Set_Random_Address_Command,
HCI_LE_Set_Scan_Enable_Command,
@@ -422,6 +423,25 @@ def test_HCI_LE_Set_Extended_Scan_Parameters_Command():
basic_check(command)
# -----------------------------------------------------------------------------
def test_HCI_LE_Set_Extended_Advertising_Enable_Command():
command = HCI_Packet.from_bytes(
bytes.fromhex('0139200e010301050008020600090307000a')
)
assert command.enable == 1
assert command.advertising_handles == [1, 2, 3]
assert command.durations == [5, 6, 7]
assert command.max_extended_advertising_events == [8, 9, 10]
command = HCI_LE_Set_Extended_Advertising_Enable_Command(
enable=1,
advertising_handles=[1, 2, 3],
durations=[5, 6, 7],
max_extended_advertising_events=[8, 9, 10],
)
basic_check(command)
# -----------------------------------------------------------------------------
def test_address():
a = Address('C4:F2:17:1A:1D:BB')
@@ -478,6 +498,7 @@ def run_test_commands():
test_HCI_LE_Read_Remote_Features_Command()
test_HCI_LE_Set_Default_PHY_Command()
test_HCI_LE_Set_Extended_Scan_Parameters_Command()
test_HCI_LE_Set_Extended_Advertising_Enable_Command()
# -----------------------------------------------------------------------------

View File

@@ -68,13 +68,16 @@ class TwoDevices:
),
]
self.paired = [None, None]
self.paired = [
asyncio.get_event_loop().create_future(),
asyncio.get_event_loop().create_future(),
]
def on_connection(self, which, connection):
self.connections[which] = connection
def on_paired(self, which, keys):
self.paired[which] = keys
def on_paired(self, which: int, keys: PairingKeys):
self.paired[which].set_result(keys)
# -----------------------------------------------------------------------------
@@ -323,8 +326,8 @@ async def _test_self_smp_with_configs(pairing_config1, pairing_config2):
# Pair
await two_devices.devices[0].pair(connection)
assert connection.is_encrypted
assert two_devices.paired[0] is not None
assert two_devices.paired[1] is not None
assert await two_devices.paired[0] is not None
assert await two_devices.paired[1] is not None
# -----------------------------------------------------------------------------
@@ -527,16 +530,12 @@ async def test_self_smp_over_classic():
two_devices.connections[0].encryption = 1
two_devices.connections[1].encryption = 1
paired = [
asyncio.get_event_loop().create_future(),
asyncio.get_event_loop().create_future(),
]
def on_pairing(which: int, keys: PairingKeys):
paired[which].set_result(keys)
two_devices.connections[0].on('pairing', lambda keys: on_pairing(0, keys))
two_devices.connections[1].on('pairing', lambda keys: on_pairing(1, keys))
two_devices.connections[0].on(
'pairing', lambda keys: two_devices.on_paired(0, keys)
)
two_devices.connections[1].on(
'pairing', lambda keys: two_devices.on_paired(1, keys)
)
# Mock SMP
with patch('bumble.smp.Session', spec=True) as MockSmpSession:
@@ -547,7 +546,7 @@ async def test_self_smp_over_classic():
# Start CTKD
await two_devices.connections[0].pair()
await asyncio.gather(*paired)
await asyncio.gather(*two_devices.paired)
# Phase 2 commands should not be invoked
MockSmpSession.send_pairing_confirm_command.assert_not_called()
@@ -556,6 +555,26 @@ async def test_self_smp_over_classic():
MockSmpSession.send_pairing_random_command.assert_not_called()
# -----------------------------------------------------------------------------
@pytest.mark.asyncio
async def test_self_smp_public_address():
pairing_config = PairingConfig(
mitm=True,
sc=True,
bonding=True,
identity_address_type=PairingConfig.AddressType.PUBLIC,
delegate=PairingDelegate(
PairingDelegate.IoCapability.DISPLAY_OUTPUT_AND_YES_NO_INPUT,
PairingDelegate.KeyDistribution.DISTRIBUTE_ENCRYPTION_KEY
| PairingDelegate.KeyDistribution.DISTRIBUTE_IDENTITY_KEY
| PairingDelegate.KeyDistribution.DISTRIBUTE_SIGNING_KEY
| PairingDelegate.KeyDistribution.DISTRIBUTE_LINK_KEY,
),
)
await _test_self_smp_with_configs(pairing_config, pairing_config)
# -----------------------------------------------------------------------------
async def run_test_self():
await test_self_connection()
@@ -565,6 +584,7 @@ async def run_test_self():
await test_self_smp_reject()
await test_self_smp_wrong_pin()
await test_self_smp_over_classic()
await test_self_smp_public_address()
# -----------------------------------------------------------------------------