Merge branch 'google:main' into bumble_hid_device

This commit is contained in:
Snehal Karnataki
2023-11-30 04:50:05 +00:00
committed by GitHub
11 changed files with 429 additions and 95 deletions

View File

@@ -21,6 +21,7 @@
"cccds", "cccds",
"cmac", "cmac",
"CONNECTIONLESS", "CONNECTIONLESS",
"csip",
"csrcs", "csrcs",
"datagram", "datagram",
"DATALINK", "DATALINK",
@@ -45,6 +46,7 @@
"NONCONN", "NONCONN",
"OXIMETER", "OXIMETER",
"popleft", "popleft",
"PRAND",
"protobuf", "protobuf",
"psms", "psms",
"pyee", "pyee",
@@ -56,6 +58,7 @@
"SEID", "SEID",
"seids", "seids",
"SERV", "SERV",
"SIRK",
"ssrc", "ssrc",
"strerror", "strerror",
"subband", "subband",

View File

@@ -23,6 +23,7 @@ import asyncio
import logging import logging
from contextlib import asynccontextmanager, AsyncExitStack from contextlib import asynccontextmanager, AsyncExitStack
from dataclasses import dataclass from dataclasses import dataclass
from collections.abc import Iterable
from typing import ( from typing import (
Any, Any,
Callable, Callable,
@@ -32,6 +33,7 @@ from typing import (
Optional, Optional,
Tuple, Tuple,
Type, Type,
TypeVar,
Set, Set,
Union, Union,
cast, cast,
@@ -440,8 +442,11 @@ class LePhyOptions:
# ----------------------------------------------------------------------------- # -----------------------------------------------------------------------------
_PROXY_CLASS = TypeVar('_PROXY_CLASS', bound=gatt_client.ProfileServiceProxy)
class Peer: class Peer:
def __init__(self, connection): def __init__(self, connection: Connection) -> None:
self.connection = connection self.connection = connection
# Create a GATT client for the connection # Create a GATT client for the connection
@@ -449,77 +454,113 @@ class Peer:
connection.gatt_client = self.gatt_client connection.gatt_client = self.gatt_client
@property @property
def services(self): def services(self) -> List[gatt_client.ServiceProxy]:
return self.gatt_client.services return self.gatt_client.services
async def request_mtu(self, mtu): async def request_mtu(self, mtu: int) -> int:
mtu = await self.gatt_client.request_mtu(mtu) mtu = await self.gatt_client.request_mtu(mtu)
self.connection.emit('connection_att_mtu_update') self.connection.emit('connection_att_mtu_update')
return mtu return mtu
async def discover_service(self, uuid): async def discover_service(
self, uuid: Union[core.UUID, str]
) -> List[gatt_client.ServiceProxy]:
return await self.gatt_client.discover_service(uuid) return await self.gatt_client.discover_service(uuid)
async def discover_services(self, uuids=()): async def discover_services(
self, uuids: Iterable[core.UUID] = ()
) -> List[gatt_client.ServiceProxy]:
return await self.gatt_client.discover_services(uuids) return await self.gatt_client.discover_services(uuids)
async def discover_included_services(self, service): async def discover_included_services(
self, service: gatt_client.ServiceProxy
) -> List[gatt_client.ServiceProxy]:
return await self.gatt_client.discover_included_services(service) return await self.gatt_client.discover_included_services(service)
async def discover_characteristics(self, uuids=(), service=None): async def discover_characteristics(
self,
uuids: Iterable[Union[core.UUID, str]] = (),
service: Optional[gatt_client.ServiceProxy] = None,
) -> List[gatt_client.CharacteristicProxy]:
return await self.gatt_client.discover_characteristics( return await self.gatt_client.discover_characteristics(
uuids=uuids, service=service uuids=uuids, service=service
) )
async def discover_descriptors( async def discover_descriptors(
self, characteristic=None, start_handle=None, end_handle=None self,
characteristic: Optional[gatt_client.CharacteristicProxy] = None,
start_handle: Optional[int] = None,
end_handle: Optional[int] = None,
): ):
return await self.gatt_client.discover_descriptors( return await self.gatt_client.discover_descriptors(
characteristic, start_handle, end_handle characteristic, start_handle, end_handle
) )
async def discover_attributes(self): async def discover_attributes(self) -> List[gatt_client.AttributeProxy]:
return await self.gatt_client.discover_attributes() return await self.gatt_client.discover_attributes()
async def subscribe(self, characteristic, subscriber=None, prefer_notify=True): async def subscribe(
self,
characteristic: gatt_client.CharacteristicProxy,
subscriber: Optional[Callable[[bytes], Any]] = None,
prefer_notify: bool = True,
) -> None:
return await self.gatt_client.subscribe( return await self.gatt_client.subscribe(
characteristic, subscriber, prefer_notify characteristic, subscriber, prefer_notify
) )
async def unsubscribe(self, characteristic, subscriber=None): async def unsubscribe(
self,
characteristic: gatt_client.CharacteristicProxy,
subscriber: Optional[Callable[[bytes], Any]] = None,
) -> None:
return await self.gatt_client.unsubscribe(characteristic, subscriber) return await self.gatt_client.unsubscribe(characteristic, subscriber)
async def read_value(self, attribute): async def read_value(
self, attribute: Union[int, gatt_client.AttributeProxy]
) -> bytes:
return await self.gatt_client.read_value(attribute) return await self.gatt_client.read_value(attribute)
async def write_value(self, attribute, value, with_response=False): async def write_value(
self,
attribute: Union[int, gatt_client.AttributeProxy],
value: bytes,
with_response: bool = False,
) -> None:
return await self.gatt_client.write_value(attribute, value, with_response) return await self.gatt_client.write_value(attribute, value, with_response)
async def read_characteristics_by_uuid(self, uuid, service=None): async def read_characteristics_by_uuid(
self, uuid: core.UUID, service: Optional[gatt_client.ServiceProxy] = None
) -> List[bytes]:
return await self.gatt_client.read_characteristics_by_uuid(uuid, service) return await self.gatt_client.read_characteristics_by_uuid(uuid, service)
def get_services_by_uuid(self, uuid): def get_services_by_uuid(self, uuid: core.UUID) -> List[gatt_client.ServiceProxy]:
return self.gatt_client.get_services_by_uuid(uuid) return self.gatt_client.get_services_by_uuid(uuid)
def get_characteristics_by_uuid(self, uuid, service=None): def get_characteristics_by_uuid(
self, uuid: core.UUID, service: Optional[gatt_client.ServiceProxy] = None
) -> List[gatt_client.CharacteristicProxy]:
return self.gatt_client.get_characteristics_by_uuid(uuid, service) return self.gatt_client.get_characteristics_by_uuid(uuid, service)
def create_service_proxy(self, proxy_class): def create_service_proxy(self, proxy_class: Type[_PROXY_CLASS]) -> _PROXY_CLASS:
return proxy_class.from_client(self.gatt_client) return cast(_PROXY_CLASS, proxy_class.from_client(self.gatt_client))
async def discover_service_and_create_proxy(self, proxy_class): async def discover_service_and_create_proxy(
self, proxy_class: Type[_PROXY_CLASS]
) -> Optional[_PROXY_CLASS]:
# Discover the first matching service and its characteristics # Discover the first matching service and its characteristics
services = await self.discover_service(proxy_class.SERVICE_CLASS.UUID) services = await self.discover_service(proxy_class.SERVICE_CLASS.UUID)
if services: if services:
service = services[0] service = services[0]
await service.discover_characteristics() await service.discover_characteristics()
return self.create_service_proxy(proxy_class) return self.create_service_proxy(proxy_class)
return None
async def sustain(self, timeout=None): async def sustain(self, timeout: Optional[float] = None) -> None:
await self.connection.sustain(timeout) await self.connection.sustain(timeout)
# [Classic only] # [Classic only]
async def request_name(self): async def request_name(self) -> str:
return await self.connection.request_remote_name() return await self.connection.request_remote_name()
async def __aenter__(self): async def __aenter__(self):
@@ -532,7 +573,7 @@ class Peer:
async def __aexit__(self, exc_type, exc_value, traceback): async def __aexit__(self, exc_type, exc_value, traceback):
pass pass
def __str__(self): def __str__(self) -> str:
return f'{self.connection.peer_address} as {self.connection.role_name}' return f'{self.connection.peer_address} as {self.connection.role_name}'
@@ -732,7 +773,7 @@ class Connection(CompositeEventEmitter):
async def switch_role(self, role: int) -> None: async def switch_role(self, role: int) -> None:
return await self.device.switch_role(self, role) return await self.device.switch_role(self, role)
async def sustain(self, timeout=None): async def sustain(self, timeout: Optional[float] = None) -> None:
"""Idles the current task waiting for a disconnect or timeout""" """Idles the current task waiting for a disconnect or timeout"""
abort = asyncio.get_running_loop().create_future() abort = asyncio.get_running_loop().create_future()

View File

@@ -93,30 +93,35 @@ GATT_RECONNECTION_CONFIGURATION_SERVICE = UUID.from_16_bits(0x1829, 'Reconne
GATT_INSULIN_DELIVERY_SERVICE = UUID.from_16_bits(0x183A, 'Insulin Delivery') GATT_INSULIN_DELIVERY_SERVICE = UUID.from_16_bits(0x183A, 'Insulin Delivery')
GATT_BINARY_SENSOR_SERVICE = UUID.from_16_bits(0x183B, 'Binary Sensor') GATT_BINARY_SENSOR_SERVICE = UUID.from_16_bits(0x183B, 'Binary Sensor')
GATT_EMERGENCY_CONFIGURATION_SERVICE = UUID.from_16_bits(0x183C, 'Emergency Configuration') GATT_EMERGENCY_CONFIGURATION_SERVICE = UUID.from_16_bits(0x183C, 'Emergency Configuration')
GATT_AUTHORIZATION_CONTROL_SERVICE = UUID.from_16_bits(0x183D, 'Authorization Control')
GATT_PHYSICAL_ACTIVITY_MONITOR_SERVICE = UUID.from_16_bits(0x183E, 'Physical Activity Monitor') GATT_PHYSICAL_ACTIVITY_MONITOR_SERVICE = UUID.from_16_bits(0x183E, 'Physical Activity Monitor')
GATT_ELAPSED_TIME_SERVICE = UUID.from_16_bits(0x183F, 'Elapsed Time')
GATT_GENERIC_HEALTH_SENSOR_SERVICE = UUID.from_16_bits(0x1840, 'Generic Health Sensor')
GATT_AUDIO_INPUT_CONTROL_SERVICE = UUID.from_16_bits(0x1843, 'Audio Input Control') GATT_AUDIO_INPUT_CONTROL_SERVICE = UUID.from_16_bits(0x1843, 'Audio Input Control')
GATT_VOLUME_CONTROL_SERVICE = UUID.from_16_bits(0x1844, 'Volume Control') GATT_VOLUME_CONTROL_SERVICE = UUID.from_16_bits(0x1844, 'Volume Control')
GATT_VOLUME_OFFSET_CONTROL_SERVICE = UUID.from_16_bits(0x1845, 'Volume Offset Control') GATT_VOLUME_OFFSET_CONTROL_SERVICE = UUID.from_16_bits(0x1845, 'Volume Offset Control')
GATT_COORDINATED_SET_IDENTIFICATION_SERVICE = UUID.from_16_bits(0x1846, 'Coordinated Set Identification Service') GATT_COORDINATED_SET_IDENTIFICATION_SERVICE = UUID.from_16_bits(0x1846, 'Coordinated Set Identification')
GATT_DEVICE_TIME_SERVICE = UUID.from_16_bits(0x1847, 'Device Time') GATT_DEVICE_TIME_SERVICE = UUID.from_16_bits(0x1847, 'Device Time')
# LE Audio Services GATT_MEDIA_CONTROL_SERVICE = UUID.from_16_bits(0x1848, 'Media Control')
GATT_MEDIA_CONTROL_SERVICE = UUID.from_16_bits(0x1848, 'Media Control Service') GATT_GENERIC_MEDIA_CONTROL_SERVICE = UUID.from_16_bits(0x1849, 'Generic Media Control')
GATT_GENERIC_MEDIA_CONTROL_SERVICE = UUID.from_16_bits(0x1849, 'Generic Media Control Service') GATT_CONSTANT_TONE_EXTENSION_SERVICE = UUID.from_16_bits(0x184A, 'Constant Tone Extension')
GATT_CONSTANT_TONE_EXTENSION_SERVICE = UUID.from_16_bits(0x184A, 'Constant Tone Extension Service') GATT_TELEPHONE_BEARER_SERVICE = UUID.from_16_bits(0x184B, 'Telephone Bearer')
GATT_TELEPHONE_BEARER_SERVICE = UUID.from_16_bits(0x184B, 'Telephone Bearer Service') GATT_GENERIC_TELEPHONE_BEARER_SERVICE = UUID.from_16_bits(0x184C, 'Generic Telephone Bearer')
GATT_GENERIC_TELEPHONE_BEARER_SERVICE = UUID.from_16_bits(0x184C, 'Generic Telephone Bearer Service') GATT_MICROPHONE_CONTROL_SERVICE = UUID.from_16_bits(0x184D, 'Microphone Control')
GATT_MICROPHONE_CONTROL_SERVICE = UUID.from_16_bits(0x184D, 'Microphone Control Service') GATT_AUDIO_STREAM_CONTROL_SERVICE = UUID.from_16_bits(0x184E, 'Audio Stream Control')
GATT_AUDIO_STREAM_CONTROL_SERVICE = UUID.from_16_bits(0x184E, 'Audio Stream Control Service') GATT_BROADCAST_AUDIO_SCAN_SERVICE = UUID.from_16_bits(0x184F, 'Broadcast Audio Scan')
GATT_BROADCAST_AUDIO_SCAN_SERVICE = UUID.from_16_bits(0x184F, 'Broadcast Audio Scan Service') GATT_PUBLISHED_AUDIO_CAPABILITIES_SERVICE = UUID.from_16_bits(0x1850, 'Published Audio Capabilities')
GATT_PUBLISHED_AUDIO_CAPABILITIES_SERVICE = UUID.from_16_bits(0x1850, 'Published Audio Capabilities Service') GATT_BASIC_AUDIO_ANNOUNCEMENT_SERVICE = UUID.from_16_bits(0x1851, 'Basic Audio Announcement')
GATT_BASIC_AUDIO_ANNOUNCEMENT_SERVICE = UUID.from_16_bits(0x1851, 'Basic Audio Announcement Service') GATT_BROADCAST_AUDIO_ANNOUNCEMENT_SERVICE = UUID.from_16_bits(0x1852, 'Broadcast Audio Announcement')
GATT_BROADCAST_AUDIO_ANNOUNCEMENT_SERVICE = UUID.from_16_bits(0x1852, 'Broadcast Audio Announcement Service') GATT_COMMON_AUDIO_SERVICE = UUID.from_16_bits(0x1853, 'Common Audio')
GATT_COMMON_AUDIO_SERVICE = UUID.from_16_bits(0x1853, 'Common Audio Service') GATT_HEARING_ACCESS_SERVICE = UUID.from_16_bits(0x1854, 'Hearing Access')
GATT_HEARING_ACCESS_SERVICE = UUID.from_16_bits(0x1854, 'Hearing Access Service') GATT_TELEPHONY_AND_MEDIA_AUDIO_SERVICE = UUID.from_16_bits(0x1855, 'Telephony and Media Audio')
GATT_TELEPHONY_AND_MEDIA_AUDIO_SERVICE = UUID.from_16_bits(0x1855, 'Telephony and Media Audio Service') GATT_PUBLIC_BROADCAST_ANNOUNCEMENT_SERVICE = UUID.from_16_bits(0x1856, 'Public Broadcast Announcement')
GATT_PUBLIC_BROADCAST_ANNOUNCEMENT_SERVICE = UUID.from_16_bits(0x1856, 'Public Broadcast Announcement Service') GATT_ELECTRONIC_SHELF_LABEL_SERVICE = UUID.from_16_bits(0X1857, 'Electronic Shelf Label')
GATT_GAMING_AUDIO_SERVICE = UUID.from_16_bits(0x1858, 'Gaming Audio')
GATT_MESH_PROXY_SOLICITATION_SERVICE = UUID.from_16_bits(0x1859, 'Mesh Audio Solicitation')
# Types # Attribute Types
GATT_PRIMARY_SERVICE_ATTRIBUTE_TYPE = UUID.from_16_bits(0x2800, 'Primary Service') GATT_PRIMARY_SERVICE_ATTRIBUTE_TYPE = UUID.from_16_bits(0x2800, 'Primary Service')
GATT_SECONDARY_SERVICE_ATTRIBUTE_TYPE = UUID.from_16_bits(0x2801, 'Secondary Service') GATT_SECONDARY_SERVICE_ATTRIBUTE_TYPE = UUID.from_16_bits(0x2801, 'Secondary Service')
GATT_INCLUDE_ATTRIBUTE_TYPE = UUID.from_16_bits(0x2802, 'Include') GATT_INCLUDE_ATTRIBUTE_TYPE = UUID.from_16_bits(0x2802, 'Include')
@@ -139,6 +144,8 @@ GATT_ENVIRONMENTAL_SENSING_MEASUREMENT_DESCRIPTOR = UUID.from_16_bits(0x290C,
GATT_ENVIRONMENTAL_SENSING_TRIGGER_DESCRIPTOR = UUID.from_16_bits(0x290D, 'Environmental Sensing Trigger Setting') GATT_ENVIRONMENTAL_SENSING_TRIGGER_DESCRIPTOR = UUID.from_16_bits(0x290D, 'Environmental Sensing Trigger Setting')
GATT_TIME_TRIGGER_DESCRIPTOR = UUID.from_16_bits(0x290E, 'Time Trigger Setting') GATT_TIME_TRIGGER_DESCRIPTOR = UUID.from_16_bits(0x290E, 'Time Trigger Setting')
GATT_COMPLETE_BR_EDR_TRANSPORT_BLOCK_DATA_DESCRIPTOR = UUID.from_16_bits(0x290F, 'Complete BR-EDR Transport Block Data') GATT_COMPLETE_BR_EDR_TRANSPORT_BLOCK_DATA_DESCRIPTOR = UUID.from_16_bits(0x290F, 'Complete BR-EDR Transport Block Data')
GATT_OBSERVATION_SCHEDULE_DESCRIPTOR = UUID.from_16_bits(0x290F, 'Observation Schedule')
GATT_VALID_RANGE_AND_ACCURACY_DESCRIPTOR = UUID.from_16_bits(0x290F, 'Valid Range And Accuracy')
# Device Information Service # Device Information Service
GATT_SYSTEM_ID_CHARACTERISTIC = UUID.from_16_bits(0x2A23, 'System ID') GATT_SYSTEM_ID_CHARACTERISTIC = UUID.from_16_bits(0x2A23, 'System ID')
@@ -166,6 +173,9 @@ GATT_HEART_RATE_CONTROL_POINT_CHARACTERISTIC = UUID.from_16_bits(0x2A39, 'Heart
# Battery Service # Battery Service
GATT_BATTERY_LEVEL_CHARACTERISTIC = UUID.from_16_bits(0x2A19, 'Battery Level') GATT_BATTERY_LEVEL_CHARACTERISTIC = UUID.from_16_bits(0x2A19, 'Battery Level')
# Telephony And Media Audio Service (TMAS)
GATT_TMAP_ROLE_CHARACTERISTIC = UUID.from_16_bits(0x2B51, 'TMAP Role')
# Audio Input Control Service (AICS) # Audio Input Control Service (AICS)
GATT_AUDIO_INPUT_STATE_CHARACTERISTIC = UUID.from_16_bits(0x2B77, 'Audio Input State') GATT_AUDIO_INPUT_STATE_CHARACTERISTIC = UUID.from_16_bits(0x2B77, 'Audio Input State')
GATT_GAIN_SETTINGS_ATTRIBUTE_CHARACTERISTIC = UUID.from_16_bits(0x2B78, 'Gain Settings Attribute') GATT_GAIN_SETTINGS_ATTRIBUTE_CHARACTERISTIC = UUID.from_16_bits(0x2B78, 'Gain Settings Attribute')
@@ -274,6 +284,9 @@ GATT_BOOT_KEYBOARD_INPUT_REPORT_CHARACTERISTIC = UUID.from_16_bi
GATT_CURRENT_TIME_CHARACTERISTIC = UUID.from_16_bits(0x2A2B, 'Current Time') GATT_CURRENT_TIME_CHARACTERISTIC = UUID.from_16_bits(0x2A2B, 'Current Time')
GATT_BOOT_KEYBOARD_OUTPUT_REPORT_CHARACTERISTIC = UUID.from_16_bits(0x2A32, 'Boot Keyboard Output Report') GATT_BOOT_KEYBOARD_OUTPUT_REPORT_CHARACTERISTIC = UUID.from_16_bits(0x2A32, 'Boot Keyboard Output Report')
GATT_CENTRAL_ADDRESS_RESOLUTION__CHARACTERISTIC = UUID.from_16_bits(0x2AA6, 'Central Address Resolution') GATT_CENTRAL_ADDRESS_RESOLUTION__CHARACTERISTIC = UUID.from_16_bits(0x2AA6, 'Central Address Resolution')
GATT_CLIENT_SUPPORTED_FEATURES_CHARACTERISTIC = UUID.from_16_bits(0x2B29, 'Client Supported Features')
GATT_DATABASE_HASH_CHARACTERISTIC = UUID.from_16_bits(0x2B2A, 'Database Hash')
GATT_SERVER_SUPPORTED_FEATURES_CHARACTERISTIC = UUID.from_16_bits(0x2B3A, 'Server Supported Features')
# fmt: on # fmt: on
# pylint: enable=line-too-long # pylint: enable=line-too-long

View File

@@ -38,6 +38,7 @@ from typing import (
Any, Any,
Iterable, Iterable,
Type, Type,
Set,
TYPE_CHECKING, TYPE_CHECKING,
) )
@@ -128,7 +129,7 @@ class ServiceProxy(AttributeProxy):
included_services: List[ServiceProxy] included_services: List[ServiceProxy]
@staticmethod @staticmethod
def from_client(service_class, client, service_uuid): def from_client(service_class, client: Client, service_uuid: UUID):
# The service and its characteristics are considered to have already been # The service and its characteristics are considered to have already been
# discovered # discovered
services = client.get_services_by_uuid(service_uuid) services = client.get_services_by_uuid(service_uuid)
@@ -246,8 +247,12 @@ class ProfileServiceProxy:
class Client: class Client:
services: List[ServiceProxy] services: List[ServiceProxy]
cached_values: Dict[int, Tuple[datetime, bytes]] cached_values: Dict[int, Tuple[datetime, bytes]]
notification_subscribers: Dict[int, Callable[[bytes], Any]] notification_subscribers: Dict[
indication_subscribers: Dict[int, Callable[[bytes], Any]] int, Set[Union[CharacteristicProxy, Callable[[bytes], Any]]]
]
indication_subscribers: Dict[
int, Set[Union[CharacteristicProxy, Callable[[bytes], Any]]]
]
pending_response: Optional[asyncio.futures.Future[ATT_PDU]] pending_response: Optional[asyncio.futures.Future[ATT_PDU]]
pending_request: Optional[ATT_PDU] pending_request: Optional[ATT_PDU]
@@ -682,8 +687,8 @@ class Client:
async def discover_descriptors( async def discover_descriptors(
self, self,
characteristic: Optional[CharacteristicProxy] = None, characteristic: Optional[CharacteristicProxy] = None,
start_handle=None, start_handle: Optional[int] = None,
end_handle=None, end_handle: Optional[int] = None,
) -> List[DescriptorProxy]: ) -> List[DescriptorProxy]:
''' '''
See Vol 3, Part G - 4.7.1 Discover All Characteristic Descriptors See Vol 3, Part G - 4.7.1 Discover All Characteristic Descriptors
@@ -789,7 +794,12 @@ class Client:
return attributes return attributes
async def subscribe(self, characteristic, subscriber=None, prefer_notify=True): async def subscribe(
self,
characteristic: CharacteristicProxy,
subscriber: Optional[Callable[[bytes], Any]] = None,
prefer_notify: bool = True,
) -> None:
# If we haven't already discovered the descriptors for this characteristic, # If we haven't already discovered the descriptors for this characteristic,
# do it now # do it now
if not characteristic.descriptors_discovered: if not characteristic.descriptors_discovered:
@@ -833,7 +843,11 @@ class Client:
await self.write_value(cccd, struct.pack('<H', bits), with_response=True) await self.write_value(cccd, struct.pack('<H', bits), with_response=True)
async def unsubscribe(self, characteristic, subscriber=None): async def unsubscribe(
self,
characteristic: CharacteristicProxy,
subscriber: Optional[Callable[[bytes], Any]] = None,
) -> None:
# If we haven't already discovered the descriptors for this characteristic, # If we haven't already discovered the descriptors for this characteristic,
# do it now # do it now
if not characteristic.descriptors_discovered: if not characteristic.descriptors_discovered:
@@ -853,7 +867,7 @@ class Client:
self.notification_subscribers, self.notification_subscribers,
self.indication_subscribers, self.indication_subscribers,
): ):
subscribers = subscriber_set.get(characteristic.handle, []) subscribers = subscriber_set.get(characteristic.handle, set())
if subscriber in subscribers: if subscriber in subscribers:
subscribers.remove(subscriber) subscribers.remove(subscriber)
@@ -871,7 +885,7 @@ class Client:
async def read_value( async def read_value(
self, attribute: Union[int, AttributeProxy], no_long_read: bool = False self, attribute: Union[int, AttributeProxy], no_long_read: bool = False
) -> Any: ) -> bytes:
''' '''
See Vol 3, Part G - 4.8.1 Read Characteristic Value See Vol 3, Part G - 4.8.1 Read Characteristic Value
@@ -1067,7 +1081,7 @@ class Client:
def on_att_handle_value_notification(self, notification): def on_att_handle_value_notification(self, notification):
# Call all subscribers # Call all subscribers
subscribers = self.notification_subscribers.get( subscribers = self.notification_subscribers.get(
notification.attribute_handle, [] notification.attribute_handle, set()
) )
if not subscribers: if not subscribers:
logger.warning('!!! received notification with no subscriber') logger.warning('!!! received notification with no subscriber')
@@ -1081,7 +1095,9 @@ class Client:
def on_att_handle_value_indication(self, indication): def on_att_handle_value_indication(self, indication):
# Call all subscribers # Call all subscribers
subscribers = self.indication_subscribers.get(indication.attribute_handle, []) subscribers = self.indication_subscribers.get(
indication.attribute_handle, set()
)
if not subscribers: if not subscribers:
logger.warning('!!! received indication with no subscriber') logger.warning('!!! received indication with no subscriber')

View File

@@ -5296,6 +5296,10 @@ class HCI_Disconnection_Complete_Event(HCI_Event):
See Bluetooth spec @ 7.7.5 Disconnection Complete Event See Bluetooth spec @ 7.7.5 Disconnection Complete Event
''' '''
status: int
connection_handle: int
reason: int
# ----------------------------------------------------------------------------- # -----------------------------------------------------------------------------
@HCI_Event.event([('status', STATUS_SPEC), ('connection_handle', 2)]) @HCI_Event.event([('status', STATUS_SPEC), ('connection_handle', 2)])

View File

@@ -15,30 +15,39 @@
# ----------------------------------------------------------------------------- # -----------------------------------------------------------------------------
# Imports # Imports
# ----------------------------------------------------------------------------- # -----------------------------------------------------------------------------
from __future__ import annotations
from collections.abc import Callable, MutableMapping
from typing import cast, Any
import logging import logging
from .colors import color from bumble import avdtp
from .att import ATT_CID, ATT_PDU from bumble.colors import color
from .smp import SMP_CID, SMP_Command from bumble.att import ATT_CID, ATT_PDU
from .core import name_or_number from bumble.smp import SMP_CID, SMP_Command
from .l2cap import ( from bumble.core import name_or_number
from bumble.l2cap import (
L2CAP_PDU, L2CAP_PDU,
L2CAP_CONNECTION_REQUEST, L2CAP_CONNECTION_REQUEST,
L2CAP_CONNECTION_RESPONSE, L2CAP_CONNECTION_RESPONSE,
L2CAP_SIGNALING_CID, L2CAP_SIGNALING_CID,
L2CAP_LE_SIGNALING_CID, L2CAP_LE_SIGNALING_CID,
L2CAP_Control_Frame, L2CAP_Control_Frame,
L2CAP_Connection_Request,
L2CAP_Connection_Response, L2CAP_Connection_Response,
) )
from .hci import ( from bumble.hci import (
HCI_EVENT_PACKET, HCI_EVENT_PACKET,
HCI_ACL_DATA_PACKET, HCI_ACL_DATA_PACKET,
HCI_DISCONNECTION_COMPLETE_EVENT, HCI_DISCONNECTION_COMPLETE_EVENT,
HCI_AclDataPacketAssembler, HCI_AclDataPacketAssembler,
HCI_Packet,
HCI_Event,
HCI_AclDataPacket,
HCI_Disconnection_Complete_Event,
) )
from .rfcomm import RFCOMM_Frame, RFCOMM_PSM from bumble.rfcomm import RFCOMM_Frame, RFCOMM_PSM
from .sdp import SDP_PDU, SDP_PSM from bumble.sdp import SDP_PDU, SDP_PSM
from .avdtp import MessageAssembler as AVDTP_MessageAssembler, AVDTP_PSM
# ----------------------------------------------------------------------------- # -----------------------------------------------------------------------------
# Logging # Logging
@@ -50,23 +59,25 @@ logger = logging.getLogger(__name__)
PSM_NAMES = { PSM_NAMES = {
RFCOMM_PSM: 'RFCOMM', RFCOMM_PSM: 'RFCOMM',
SDP_PSM: 'SDP', SDP_PSM: 'SDP',
AVDTP_PSM: 'AVDTP' avdtp.AVDTP_PSM: 'AVDTP',
# TODO: add more PSM values
} }
# ----------------------------------------------------------------------------- # -----------------------------------------------------------------------------
class PacketTracer: class PacketTracer:
class AclStream: class AclStream:
def __init__(self, analyzer): psms: MutableMapping[int, int]
peer: PacketTracer.AclStream
avdtp_assemblers: MutableMapping[int, avdtp.MessageAssembler]
def __init__(self, analyzer: PacketTracer.Analyzer) -> None:
self.analyzer = analyzer self.analyzer = analyzer
self.packet_assembler = HCI_AclDataPacketAssembler(self.on_acl_pdu) self.packet_assembler = HCI_AclDataPacketAssembler(self.on_acl_pdu)
self.avdtp_assemblers = {} # AVDTP assemblers, by source_cid self.avdtp_assemblers = {} # AVDTP assemblers, by source_cid
self.psms = {} # PSM, by source_cid self.psms = {} # PSM, by source_cid
self.peer = None # ACL stream in the other direction
# pylint: disable=too-many-nested-blocks # pylint: disable=too-many-nested-blocks
def on_acl_pdu(self, pdu): def on_acl_pdu(self, pdu: bytes) -> None:
l2cap_pdu = L2CAP_PDU.from_bytes(pdu) l2cap_pdu = L2CAP_PDU.from_bytes(pdu)
if l2cap_pdu.cid == ATT_CID: if l2cap_pdu.cid == ATT_CID:
@@ -81,26 +92,30 @@ class PacketTracer:
# Check if this signals a new channel # Check if this signals a new channel
if control_frame.code == L2CAP_CONNECTION_REQUEST: if control_frame.code == L2CAP_CONNECTION_REQUEST:
self.psms[control_frame.source_cid] = control_frame.psm connection_request = cast(L2CAP_Connection_Request, control_frame)
self.psms[connection_request.source_cid] = connection_request.psm
elif control_frame.code == L2CAP_CONNECTION_RESPONSE: elif control_frame.code == L2CAP_CONNECTION_RESPONSE:
connection_response = cast(L2CAP_Connection_Response, control_frame)
if ( if (
control_frame.result connection_response.result
== L2CAP_Connection_Response.CONNECTION_SUCCESSFUL == L2CAP_Connection_Response.CONNECTION_SUCCESSFUL
): ):
if self.peer: if self.peer:
if psm := self.peer.psms.get(control_frame.source_cid): if psm := self.peer.psms.get(
connection_response.source_cid
):
# Found a pending connection # Found a pending connection
self.psms[control_frame.destination_cid] = psm self.psms[connection_response.destination_cid] = psm
# For AVDTP connections, create a packet assembler for # For AVDTP connections, create a packet assembler for
# each direction # each direction
if psm == AVDTP_PSM: if psm == avdtp.AVDTP_PSM:
self.avdtp_assemblers[ self.avdtp_assemblers[
control_frame.source_cid connection_response.source_cid
] = AVDTP_MessageAssembler(self.on_avdtp_message) ] = avdtp.MessageAssembler(self.on_avdtp_message)
self.peer.avdtp_assemblers[ self.peer.avdtp_assemblers[
control_frame.destination_cid connection_response.destination_cid
] = AVDTP_MessageAssembler( ] = avdtp.MessageAssembler(
self.peer.on_avdtp_message self.peer.on_avdtp_message
) )
@@ -113,7 +128,7 @@ class PacketTracer:
elif psm == RFCOMM_PSM: elif psm == RFCOMM_PSM:
rfcomm_frame = RFCOMM_Frame.from_bytes(l2cap_pdu.payload) rfcomm_frame = RFCOMM_Frame.from_bytes(l2cap_pdu.payload)
self.analyzer.emit(rfcomm_frame) self.analyzer.emit(rfcomm_frame)
elif psm == AVDTP_PSM: elif psm == avdtp.AVDTP_PSM:
self.analyzer.emit( self.analyzer.emit(
f'{color("L2CAP", "green")} [CID={l2cap_pdu.cid}, ' f'{color("L2CAP", "green")} [CID={l2cap_pdu.cid}, '
f'PSM=AVDTP]: {l2cap_pdu.payload.hex()}' f'PSM=AVDTP]: {l2cap_pdu.payload.hex()}'
@@ -130,22 +145,26 @@ class PacketTracer:
else: else:
self.analyzer.emit(l2cap_pdu) self.analyzer.emit(l2cap_pdu)
def on_avdtp_message(self, transaction_label, message): def on_avdtp_message(
self, transaction_label: int, message: avdtp.Message
) -> None:
self.analyzer.emit( self.analyzer.emit(
f'{color("AVDTP", "green")} [{transaction_label}] {message}' f'{color("AVDTP", "green")} [{transaction_label}] {message}'
) )
def feed_packet(self, packet): def feed_packet(self, packet: HCI_AclDataPacket) -> None:
self.packet_assembler.feed_packet(packet) self.packet_assembler.feed_packet(packet)
class Analyzer: class Analyzer:
def __init__(self, label, emit_message): acl_streams: MutableMapping[int, PacketTracer.AclStream]
peer: PacketTracer.Analyzer
def __init__(self, label: str, emit_message: Callable[..., None]) -> None:
self.label = label self.label = label
self.emit_message = emit_message self.emit_message = emit_message
self.acl_streams = {} # ACL streams, by connection handle self.acl_streams = {} # ACL streams, by connection handle
self.peer = None # Analyzer in the other direction
def start_acl_stream(self, connection_handle): def start_acl_stream(self, connection_handle: int) -> PacketTracer.AclStream:
logger.info( logger.info(
f'[{self.label}] +++ Creating ACL stream for connection ' f'[{self.label}] +++ Creating ACL stream for connection '
f'0x{connection_handle:04X}' f'0x{connection_handle:04X}'
@@ -160,7 +179,7 @@ class PacketTracer:
return stream return stream
def end_acl_stream(self, connection_handle): def end_acl_stream(self, connection_handle: int) -> None:
if connection_handle in self.acl_streams: if connection_handle in self.acl_streams:
logger.info( logger.info(
f'[{self.label}] --- Removing ACL stream for connection ' f'[{self.label}] --- Removing ACL stream for connection '
@@ -171,23 +190,29 @@ class PacketTracer:
# Let the other forwarder know so it can cleanup its stream as well # Let the other forwarder know so it can cleanup its stream as well
self.peer.end_acl_stream(connection_handle) self.peer.end_acl_stream(connection_handle)
def on_packet(self, packet): def on_packet(self, packet: HCI_Packet) -> None:
self.emit(packet) self.emit(packet)
if packet.hci_packet_type == HCI_ACL_DATA_PACKET: if packet.hci_packet_type == HCI_ACL_DATA_PACKET:
acl_packet = cast(HCI_AclDataPacket, packet)
# Look for an existing stream for this handle, create one if it is the # Look for an existing stream for this handle, create one if it is the
# first ACL packet for that connection handle # first ACL packet for that connection handle
if (stream := self.acl_streams.get(packet.connection_handle)) is None: if (
stream = self.start_acl_stream(packet.connection_handle) stream := self.acl_streams.get(acl_packet.connection_handle)
stream.feed_packet(packet) ) is None:
stream = self.start_acl_stream(acl_packet.connection_handle)
stream.feed_packet(acl_packet)
elif packet.hci_packet_type == HCI_EVENT_PACKET: elif packet.hci_packet_type == HCI_EVENT_PACKET:
if packet.event_code == HCI_DISCONNECTION_COMPLETE_EVENT: event_packet = cast(HCI_Event, packet)
self.end_acl_stream(packet.connection_handle) if event_packet.event_code == HCI_DISCONNECTION_COMPLETE_EVENT:
self.end_acl_stream(
cast(HCI_Disconnection_Complete_Event, packet).connection_handle
)
def emit(self, message): def emit(self, message: Any) -> None:
self.emit_message(f'[{self.label}] {message}') self.emit_message(f'[{self.label}] {message}')
def trace(self, packet, direction=0): def trace(self, packet: HCI_Packet, direction: int = 0) -> None:
if direction == 0: if direction == 0:
self.host_to_controller_analyzer.on_packet(packet) self.host_to_controller_analyzer.on_packet(packet)
else: else:
@@ -195,10 +220,10 @@ class PacketTracer:
def __init__( def __init__(
self, self,
host_to_controller_label=color('HOST->CONTROLLER', 'blue'), host_to_controller_label: str = color('HOST->CONTROLLER', 'blue'),
controller_to_host_label=color('CONTROLLER->HOST', 'cyan'), controller_to_host_label: str = color('CONTROLLER->HOST', 'cyan'),
emit_message=logger.info, emit_message: Callable[..., None] = logger.info,
): ) -> None:
self.host_to_controller_analyzer = PacketTracer.Analyzer( self.host_to_controller_analyzer = PacketTracer.Analyzer(
host_to_controller_label, emit_message host_to_controller_label, emit_message
) )

View File

@@ -391,6 +391,9 @@ class L2CAP_Connection_Request(L2CAP_Control_Frame):
See Bluetooth spec @ Vol 3, Part A - 4.2 CONNECTION REQUEST See Bluetooth spec @ Vol 3, Part A - 4.2 CONNECTION REQUEST
''' '''
psm: int
source_cid: int
@staticmethod @staticmethod
def parse_psm(data: bytes, offset: int = 0) -> Tuple[int, int]: def parse_psm(data: bytes, offset: int = 0) -> Tuple[int, int]:
psm_length = 2 psm_length = 2
@@ -432,6 +435,11 @@ class L2CAP_Connection_Response(L2CAP_Control_Frame):
See Bluetooth spec @ Vol 3, Part A - 4.3 CONNECTION RESPONSE See Bluetooth spec @ Vol 3, Part A - 4.3 CONNECTION RESPONSE
''' '''
source_cid: int
destination_cid: int
status: int
result: int
CONNECTION_SUCCESSFUL = 0x0000 CONNECTION_SUCCESSFUL = 0x0000
CONNECTION_PENDING = 0x0001 CONNECTION_PENDING = 0x0001
CONNECTION_REFUSED_PSM_NOT_SUPPORTED = 0x0002 CONNECTION_REFUSED_PSM_NOT_SUPPORTED = 0x0002

147
bumble/profiles/csip.py Normal file
View File

@@ -0,0 +1,147 @@
# Copyright 2021-2023 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# -----------------------------------------------------------------------------
# Imports
# -----------------------------------------------------------------------------
from __future__ import annotations
import enum
import struct
from typing import Optional
from bumble import gatt
from bumble import gatt_client
# -----------------------------------------------------------------------------
# Constants
# -----------------------------------------------------------------------------
class SirkType(enum.IntEnum):
'''Coordinated Set Identification Service - 5.1 Set Identity Resolving Key.'''
ENCRYPTED = 0x00
PLAINTEXT = 0x01
class MemberLock(enum.IntEnum):
'''Coordinated Set Identification Service - 5.3 Set Member Lock.'''
UNLOCKED = 0x01
LOCKED = 0x02
# -----------------------------------------------------------------------------
# Utils
# -----------------------------------------------------------------------------
# TODO: Implement RSI Generator
# -----------------------------------------------------------------------------
# Server
# -----------------------------------------------------------------------------
class CoordinatedSetIdentificationService(gatt.TemplateService):
UUID = gatt.GATT_COORDINATED_SET_IDENTIFICATION_SERVICE
set_identity_resolving_key_characteristic: gatt.Characteristic
coordinated_set_size_characteristic: Optional[gatt.Characteristic] = None
set_member_lock_characteristic: Optional[gatt.Characteristic] = None
set_member_rank_characteristic: Optional[gatt.Characteristic] = None
def __init__(
self,
set_identity_resolving_key: bytes,
coordinated_set_size: Optional[int] = None,
set_member_lock: Optional[MemberLock] = None,
set_member_rank: Optional[int] = None,
) -> None:
characteristics = []
self.set_identity_resolving_key_characteristic = gatt.Characteristic(
uuid=gatt.GATT_SET_IDENTITY_RESOLVING_KEY_CHARACTERISTIC,
properties=gatt.Characteristic.Properties.READ
| gatt.Characteristic.Properties.NOTIFY,
permissions=gatt.Characteristic.Permissions.READABLE,
# TODO: Implement encrypted SIRK reader.
value=struct.pack('B', SirkType.PLAINTEXT) + set_identity_resolving_key,
)
characteristics.append(self.set_identity_resolving_key_characteristic)
if coordinated_set_size is not None:
self.coordinated_set_size_characteristic = gatt.Characteristic(
uuid=gatt.GATT_COORDINATED_SET_SIZE_CHARACTERISTIC,
properties=gatt.Characteristic.Properties.READ
| gatt.Characteristic.Properties.NOTIFY,
permissions=gatt.Characteristic.Permissions.READABLE,
value=struct.pack('B', coordinated_set_size),
)
characteristics.append(self.coordinated_set_size_characteristic)
if set_member_lock is not None:
self.set_member_lock_characteristic = gatt.Characteristic(
uuid=gatt.GATT_SET_MEMBER_LOCK_CHARACTERISTIC,
properties=gatt.Characteristic.Properties.READ
| gatt.Characteristic.Properties.NOTIFY
| gatt.Characteristic.Properties.WRITE,
permissions=gatt.Characteristic.Permissions.READABLE
| gatt.Characteristic.Permissions.WRITEABLE,
value=struct.pack('B', set_member_lock),
)
characteristics.append(self.set_member_lock_characteristic)
if set_member_rank is not None:
self.set_member_rank_characteristic = gatt.Characteristic(
uuid=gatt.GATT_SET_MEMBER_RANK_CHARACTERISTIC,
properties=gatt.Characteristic.Properties.READ
| gatt.Characteristic.Properties.NOTIFY,
permissions=gatt.Characteristic.Permissions.READABLE,
value=struct.pack('B', set_member_rank),
)
characteristics.append(self.set_member_rank_characteristic)
super().__init__(characteristics)
# -----------------------------------------------------------------------------
# Client
# -----------------------------------------------------------------------------
class CoordinatedSetIdentificationProxy(gatt_client.ProfileServiceProxy):
SERVICE_CLASS = CoordinatedSetIdentificationService
set_identity_resolving_key: gatt_client.CharacteristicProxy
coordinated_set_size: Optional[gatt_client.CharacteristicProxy] = None
set_member_lock: Optional[gatt_client.CharacteristicProxy] = None
set_member_rank: Optional[gatt_client.CharacteristicProxy] = None
def __init__(self, service_proxy: gatt_client.ServiceProxy) -> None:
self.service_proxy = service_proxy
self.set_identity_resolving_key = service_proxy.get_characteristics_by_uuid(
gatt.GATT_SET_IDENTITY_RESOLVING_KEY_CHARACTERISTIC
)[0]
if characteristics := service_proxy.get_characteristics_by_uuid(
gatt.GATT_COORDINATED_SET_SIZE_CHARACTERISTIC
):
self.coordinated_set_size = characteristics[0]
if characteristics := service_proxy.get_characteristics_by_uuid(
gatt.GATT_SET_MEMBER_LOCK_CHARACTERISTIC
):
self.set_member_lock = characteristics[0]
if characteristics := service_proxy.get_characteristics_by_uuid(
gatt.GATT_SET_MEMBER_RANK_CHARACTERISTIC
):
self.set_member_rank = characteristics[0]

View File

@@ -150,7 +150,7 @@ class PacketParser:
try: try:
self.sink.on_packet(bytes(self.packet)) self.sink.on_packet(bytes(self.packet))
except Exception as error: except Exception as error:
logger.warning( logger.exception(
color(f'!!! Exception in on_packet: {error}', 'red') color(f'!!! Exception in on_packet: {error}', 'red')
) )
self.reset() self.reset()

74
tests/csip_test.py Normal file
View File

@@ -0,0 +1,74 @@
# Copyright 2021-2023 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# -----------------------------------------------------------------------------
# Imports
# -----------------------------------------------------------------------------
import asyncio
import os
import pytest
import struct
import logging
from bumble import device
from bumble.profiles import csip
from .test_utils import TwoDevices
# -----------------------------------------------------------------------------
# Logging
# -----------------------------------------------------------------------------
logger = logging.getLogger(__name__)
# -----------------------------------------------------------------------------
@pytest.mark.asyncio
async def test_csis():
SIRK = bytes.fromhex('2f62c8ae41867d1bb619e788a2605faa')
devices = TwoDevices()
devices[0].add_service(
csip.CoordinatedSetIdentificationService(
set_identity_resolving_key=SIRK,
coordinated_set_size=2,
set_member_lock=csip.MemberLock.UNLOCKED,
set_member_rank=0,
)
)
await devices.setup_connection()
peer = device.Peer(devices.connections[1])
csis_client = await peer.discover_service_and_create_proxy(
csip.CoordinatedSetIdentificationProxy
)
assert (
await csis_client.set_identity_resolving_key.read_value()
== bytes([csip.SirkType.PLAINTEXT]) + SIRK
)
assert await csis_client.coordinated_set_size.read_value() == struct.pack('B', 2)
assert await csis_client.set_member_lock.read_value() == struct.pack(
'B', csip.MemberLock.UNLOCKED
)
assert await csis_client.set_member_rank.read_value() == struct.pack('B', 0)
# -----------------------------------------------------------------------------
async def run():
await test_csis()
# -----------------------------------------------------------------------------
if __name__ == '__main__':
logging.basicConfig(level=os.environ.get('BUMBLE_LOGLEVEL', 'INFO').upper())
asyncio.run(run())

View File

@@ -71,3 +71,6 @@ class TwoDevices:
# Check the post conditions # Check the post conditions
assert self.connections[0] is not None assert self.connections[0] is not None
assert self.connections[1] is not None assert self.connections[1] is not None
def __getitem__(self, index: int) -> Device:
return self.devices[index]