Compare commits

...

11 Commits

Author SHA1 Message Date
Gilles Boccon-Gibod
320164d476 Merge pull request #355 from google/gbg/fix-gatt-unsubscribe
fix #354 (gatt unsubscribe)
2023-11-29 22:28:57 -08:00
zxzxwu
a9c4c5833d Merge pull request #350 from zxzxwu/csip
Add Coordinated Set Identification Service(CSIS)
2023-11-30 12:15:56 +08:00
Gilles Boccon-Gibod
58c9c4f590 fix #354 2023-11-29 19:19:40 -08:00
zxzxwu
24524d88cb Merge pull request #342 from zxzxwu/typing
Typing helper
2023-11-30 00:21:44 +08:00
zxzxwu
b8849ab311 Merge pull request #349 from zxzxwu/stack
Log track back in on_packet
2023-11-30 00:20:20 +08:00
Josh Wu
f3cd8f8ed0 Typing helper 2023-11-29 21:24:27 +08:00
zxzxwu
2b26de3f3a Merge pull request #348 from zxzxwu/gattc
Typing GATT Client and Device Peer
2023-11-29 15:09:40 +08:00
Josh Wu
0149c4c212 Log track back in on_packet
Many errors are raised in on_packet() callbacks, but currently it only
provides a very brief error message.
2023-11-29 15:01:15 +08:00
Gilles Boccon-Gibod
f2ed898784 Merge pull request #352 from google/gbg/more-gatt-uuids
add a few uuids
2023-11-28 22:44:40 -08:00
Josh Wu
464a476f9f Add CSIP 2023-11-29 14:09:31 +08:00
Josh Wu
04d5bf3afc Typing GATT Client and Device Peer 2023-11-28 21:57:57 +08:00
11 changed files with 506 additions and 86 deletions

View File

@@ -21,6 +21,7 @@
"cccds",
"cmac",
"CONNECTIONLESS",
"csip",
"csrcs",
"datagram",
"DATALINK",
@@ -45,6 +46,7 @@
"NONCONN",
"OXIMETER",
"popleft",
"PRAND",
"protobuf",
"psms",
"pyee",
@@ -56,6 +58,7 @@
"SEID",
"seids",
"SERV",
"SIRK",
"ssrc",
"strerror",
"subband",

View File

@@ -23,6 +23,7 @@ import asyncio
import logging
from contextlib import asynccontextmanager, AsyncExitStack
from dataclasses import dataclass
from collections.abc import Iterable
from typing import (
Any,
Callable,
@@ -32,6 +33,7 @@ from typing import (
Optional,
Tuple,
Type,
TypeVar,
Set,
Union,
cast,
@@ -440,8 +442,11 @@ class LePhyOptions:
# -----------------------------------------------------------------------------
_PROXY_CLASS = TypeVar('_PROXY_CLASS', bound=gatt_client.ProfileServiceProxy)
class Peer:
def __init__(self, connection):
def __init__(self, connection: Connection) -> None:
self.connection = connection
# Create a GATT client for the connection
@@ -449,77 +454,113 @@ class Peer:
connection.gatt_client = self.gatt_client
@property
def services(self):
def services(self) -> List[gatt_client.ServiceProxy]:
return self.gatt_client.services
async def request_mtu(self, mtu):
async def request_mtu(self, mtu: int) -> int:
mtu = await self.gatt_client.request_mtu(mtu)
self.connection.emit('connection_att_mtu_update')
return mtu
async def discover_service(self, uuid):
async def discover_service(
self, uuid: Union[core.UUID, str]
) -> List[gatt_client.ServiceProxy]:
return await self.gatt_client.discover_service(uuid)
async def discover_services(self, uuids=()):
async def discover_services(
self, uuids: Iterable[core.UUID] = ()
) -> List[gatt_client.ServiceProxy]:
return await self.gatt_client.discover_services(uuids)
async def discover_included_services(self, service):
async def discover_included_services(
self, service: gatt_client.ServiceProxy
) -> List[gatt_client.ServiceProxy]:
return await self.gatt_client.discover_included_services(service)
async def discover_characteristics(self, uuids=(), service=None):
async def discover_characteristics(
self,
uuids: Iterable[Union[core.UUID, str]] = (),
service: Optional[gatt_client.ServiceProxy] = None,
) -> List[gatt_client.CharacteristicProxy]:
return await self.gatt_client.discover_characteristics(
uuids=uuids, service=service
)
async def discover_descriptors(
self, characteristic=None, start_handle=None, end_handle=None
self,
characteristic: Optional[gatt_client.CharacteristicProxy] = None,
start_handle: Optional[int] = None,
end_handle: Optional[int] = None,
):
return await self.gatt_client.discover_descriptors(
characteristic, start_handle, end_handle
)
async def discover_attributes(self):
async def discover_attributes(self) -> List[gatt_client.AttributeProxy]:
return await self.gatt_client.discover_attributes()
async def subscribe(self, characteristic, subscriber=None, prefer_notify=True):
async def subscribe(
self,
characteristic: gatt_client.CharacteristicProxy,
subscriber: Optional[Callable[[bytes], Any]] = None,
prefer_notify: bool = True,
) -> None:
return await self.gatt_client.subscribe(
characteristic, subscriber, prefer_notify
)
async def unsubscribe(self, characteristic, subscriber=None):
async def unsubscribe(
self,
characteristic: gatt_client.CharacteristicProxy,
subscriber: Optional[Callable[[bytes], Any]] = None,
) -> None:
return await self.gatt_client.unsubscribe(characteristic, subscriber)
async def read_value(self, attribute):
async def read_value(
self, attribute: Union[int, gatt_client.AttributeProxy]
) -> bytes:
return await self.gatt_client.read_value(attribute)
async def write_value(self, attribute, value, with_response=False):
async def write_value(
self,
attribute: Union[int, gatt_client.AttributeProxy],
value: bytes,
with_response: bool = False,
) -> None:
return await self.gatt_client.write_value(attribute, value, with_response)
async def read_characteristics_by_uuid(self, uuid, service=None):
async def read_characteristics_by_uuid(
self, uuid: core.UUID, service: Optional[gatt_client.ServiceProxy] = None
) -> List[bytes]:
return await self.gatt_client.read_characteristics_by_uuid(uuid, service)
def get_services_by_uuid(self, uuid):
def get_services_by_uuid(self, uuid: core.UUID) -> List[gatt_client.ServiceProxy]:
return self.gatt_client.get_services_by_uuid(uuid)
def get_characteristics_by_uuid(self, uuid, service=None):
def get_characteristics_by_uuid(
self, uuid: core.UUID, service: Optional[gatt_client.ServiceProxy] = None
) -> List[gatt_client.CharacteristicProxy]:
return self.gatt_client.get_characteristics_by_uuid(uuid, service)
def create_service_proxy(self, proxy_class):
return proxy_class.from_client(self.gatt_client)
def create_service_proxy(self, proxy_class: Type[_PROXY_CLASS]) -> _PROXY_CLASS:
return cast(_PROXY_CLASS, proxy_class.from_client(self.gatt_client))
async def discover_service_and_create_proxy(self, proxy_class):
async def discover_service_and_create_proxy(
self, proxy_class: Type[_PROXY_CLASS]
) -> Optional[_PROXY_CLASS]:
# Discover the first matching service and its characteristics
services = await self.discover_service(proxy_class.SERVICE_CLASS.UUID)
if services:
service = services[0]
await service.discover_characteristics()
return self.create_service_proxy(proxy_class)
return None
async def sustain(self, timeout=None):
async def sustain(self, timeout: Optional[float] = None) -> None:
await self.connection.sustain(timeout)
# [Classic only]
async def request_name(self):
async def request_name(self) -> str:
return await self.connection.request_remote_name()
async def __aenter__(self):
@@ -532,7 +573,7 @@ class Peer:
async def __aexit__(self, exc_type, exc_value, traceback):
pass
def __str__(self):
def __str__(self) -> str:
return f'{self.connection.peer_address} as {self.connection.role_name}'
@@ -732,7 +773,7 @@ class Connection(CompositeEventEmitter):
async def switch_role(self, role: int) -> None:
return await self.device.switch_role(self, role)
async def sustain(self, timeout=None):
async def sustain(self, timeout: Optional[float] = None) -> None:
"""Idles the current task waiting for a disconnect or timeout"""
abort = asyncio.get_running_loop().create_future()

View File

@@ -38,6 +38,7 @@ from typing import (
Any,
Iterable,
Type,
Set,
TYPE_CHECKING,
)
@@ -128,7 +129,7 @@ class ServiceProxy(AttributeProxy):
included_services: List[ServiceProxy]
@staticmethod
def from_client(service_class, client, service_uuid):
def from_client(service_class, client: Client, service_uuid: UUID):
# The service and its characteristics are considered to have already been
# discovered
services = client.get_services_by_uuid(service_uuid)
@@ -206,11 +207,11 @@ class CharacteristicProxy(AttributeProxy):
return await self.client.subscribe(self, subscriber, prefer_notify)
async def unsubscribe(self, subscriber=None):
async def unsubscribe(self, subscriber=None, force=False):
if subscriber in self.subscribers:
subscriber = self.subscribers.pop(subscriber)
return await self.client.unsubscribe(self, subscriber)
return await self.client.unsubscribe(self, subscriber, force)
def __str__(self) -> str:
return (
@@ -246,8 +247,12 @@ class ProfileServiceProxy:
class Client:
services: List[ServiceProxy]
cached_values: Dict[int, Tuple[datetime, bytes]]
notification_subscribers: Dict[int, Callable[[bytes], Any]]
indication_subscribers: Dict[int, Callable[[bytes], Any]]
notification_subscribers: Dict[
int, Set[Union[CharacteristicProxy, Callable[[bytes], Any]]]
]
indication_subscribers: Dict[
int, Set[Union[CharacteristicProxy, Callable[[bytes], Any]]]
]
pending_response: Optional[asyncio.futures.Future[ATT_PDU]]
pending_request: Optional[ATT_PDU]
@@ -257,10 +262,8 @@ class Client:
self.request_semaphore = asyncio.Semaphore(1)
self.pending_request = None
self.pending_response = None
self.notification_subscribers = (
{}
) # Notification subscribers, by attribute handle
self.indication_subscribers = {} # Indication subscribers, by attribute handle
self.notification_subscribers = {} # Subscriber set, by attribute handle
self.indication_subscribers = {} # Subscriber set, by attribute handle
self.services = []
self.cached_values = {}
@@ -682,8 +685,8 @@ class Client:
async def discover_descriptors(
self,
characteristic: Optional[CharacteristicProxy] = None,
start_handle=None,
end_handle=None,
start_handle: Optional[int] = None,
end_handle: Optional[int] = None,
) -> List[DescriptorProxy]:
'''
See Vol 3, Part G - 4.7.1 Discover All Characteristic Descriptors
@@ -789,7 +792,12 @@ class Client:
return attributes
async def subscribe(self, characteristic, subscriber=None, prefer_notify=True):
async def subscribe(
self,
characteristic: CharacteristicProxy,
subscriber: Optional[Callable[[bytes], Any]] = None,
prefer_notify: bool = True,
) -> None:
# If we haven't already discovered the descriptors for this characteristic,
# do it now
if not characteristic.descriptors_discovered:
@@ -826,6 +834,7 @@ class Client:
subscriber_set = subscribers.setdefault(characteristic.handle, set())
if subscriber is not None:
subscriber_set.add(subscriber)
# Add the characteristic as a subscriber, which will result in the
# characteristic emitting an 'update' event when a notification or indication
# is received
@@ -833,7 +842,18 @@ class Client:
await self.write_value(cccd, struct.pack('<H', bits), with_response=True)
async def unsubscribe(self, characteristic, subscriber=None):
async def unsubscribe(
self,
characteristic: CharacteristicProxy,
subscriber: Optional[Callable[[bytes], Any]] = None,
force: bool = False,
) -> None:
'''
Unsubscribe from a characteristic.
If `force` is True, this will write zeros to the CCCD when there are no
subscribers left, even if there were already no registered subscribers.
'''
# If we haven't already discovered the descriptors for this characteristic,
# do it now
if not characteristic.descriptors_discovered:
@@ -847,31 +867,45 @@ class Client:
logger.warning('unsubscribing from characteristic with no CCCD descriptor')
return
# Check if the characteristic has subscribers
if not (
characteristic.handle in self.notification_subscribers
or characteristic.handle in self.indication_subscribers
):
if not force:
return
# Remove the subscriber(s)
if subscriber is not None:
# Remove matching subscriber from subscriber sets
for subscriber_set in (
self.notification_subscribers,
self.indication_subscribers,
):
subscribers = subscriber_set.get(characteristic.handle, [])
if subscriber in subscribers:
if (
subscribers := subscriber_set.get(characteristic.handle)
) and subscriber in subscribers:
subscribers.remove(subscriber)
# Cleanup if we removed the last one
if not subscribers:
del subscriber_set[characteristic.handle]
else:
# Remove all subscribers for this attribute from the sets!
# Remove all subscribers for this attribute from the sets
self.notification_subscribers.pop(characteristic.handle, None)
self.indication_subscribers.pop(characteristic.handle, None)
if not self.notification_subscribers and not self.indication_subscribers:
# Update the CCCD
if not (
characteristic.handle in self.notification_subscribers
or characteristic.handle in self.indication_subscribers
):
# No more subscribers left
await self.write_value(cccd, b'\x00\x00', with_response=True)
async def read_value(
self, attribute: Union[int, AttributeProxy], no_long_read: bool = False
) -> Any:
) -> bytes:
'''
See Vol 3, Part G - 4.8.1 Read Characteristic Value
@@ -1067,7 +1101,7 @@ class Client:
def on_att_handle_value_notification(self, notification):
# Call all subscribers
subscribers = self.notification_subscribers.get(
notification.attribute_handle, []
notification.attribute_handle, set()
)
if not subscribers:
logger.warning('!!! received notification with no subscriber')
@@ -1081,7 +1115,9 @@ class Client:
def on_att_handle_value_indication(self, indication):
# Call all subscribers
subscribers = self.indication_subscribers.get(indication.attribute_handle, [])
subscribers = self.indication_subscribers.get(
indication.attribute_handle, set()
)
if not subscribers:
logger.warning('!!! received indication with no subscriber')

View File

@@ -5296,6 +5296,10 @@ class HCI_Disconnection_Complete_Event(HCI_Event):
See Bluetooth spec @ 7.7.5 Disconnection Complete Event
'''
status: int
connection_handle: int
reason: int
# -----------------------------------------------------------------------------
@HCI_Event.event([('status', STATUS_SPEC), ('connection_handle', 2)])

View File

@@ -15,30 +15,39 @@
# -----------------------------------------------------------------------------
# Imports
# -----------------------------------------------------------------------------
from __future__ import annotations
from collections.abc import Callable, MutableMapping
from typing import cast, Any
import logging
from .colors import color
from .att import ATT_CID, ATT_PDU
from .smp import SMP_CID, SMP_Command
from .core import name_or_number
from .l2cap import (
from bumble import avdtp
from bumble.colors import color
from bumble.att import ATT_CID, ATT_PDU
from bumble.smp import SMP_CID, SMP_Command
from bumble.core import name_or_number
from bumble.l2cap import (
L2CAP_PDU,
L2CAP_CONNECTION_REQUEST,
L2CAP_CONNECTION_RESPONSE,
L2CAP_SIGNALING_CID,
L2CAP_LE_SIGNALING_CID,
L2CAP_Control_Frame,
L2CAP_Connection_Request,
L2CAP_Connection_Response,
)
from .hci import (
from bumble.hci import (
HCI_EVENT_PACKET,
HCI_ACL_DATA_PACKET,
HCI_DISCONNECTION_COMPLETE_EVENT,
HCI_AclDataPacketAssembler,
HCI_Packet,
HCI_Event,
HCI_AclDataPacket,
HCI_Disconnection_Complete_Event,
)
from .rfcomm import RFCOMM_Frame, RFCOMM_PSM
from .sdp import SDP_PDU, SDP_PSM
from .avdtp import MessageAssembler as AVDTP_MessageAssembler, AVDTP_PSM
from bumble.rfcomm import RFCOMM_Frame, RFCOMM_PSM
from bumble.sdp import SDP_PDU, SDP_PSM
# -----------------------------------------------------------------------------
# Logging
@@ -50,23 +59,25 @@ logger = logging.getLogger(__name__)
PSM_NAMES = {
RFCOMM_PSM: 'RFCOMM',
SDP_PSM: 'SDP',
AVDTP_PSM: 'AVDTP'
# TODO: add more PSM values
avdtp.AVDTP_PSM: 'AVDTP',
}
# -----------------------------------------------------------------------------
class PacketTracer:
class AclStream:
def __init__(self, analyzer):
psms: MutableMapping[int, int]
peer: PacketTracer.AclStream
avdtp_assemblers: MutableMapping[int, avdtp.MessageAssembler]
def __init__(self, analyzer: PacketTracer.Analyzer) -> None:
self.analyzer = analyzer
self.packet_assembler = HCI_AclDataPacketAssembler(self.on_acl_pdu)
self.avdtp_assemblers = {} # AVDTP assemblers, by source_cid
self.psms = {} # PSM, by source_cid
self.peer = None # ACL stream in the other direction
# pylint: disable=too-many-nested-blocks
def on_acl_pdu(self, pdu):
def on_acl_pdu(self, pdu: bytes) -> None:
l2cap_pdu = L2CAP_PDU.from_bytes(pdu)
if l2cap_pdu.cid == ATT_CID:
@@ -81,26 +92,30 @@ class PacketTracer:
# Check if this signals a new channel
if control_frame.code == L2CAP_CONNECTION_REQUEST:
self.psms[control_frame.source_cid] = control_frame.psm
connection_request = cast(L2CAP_Connection_Request, control_frame)
self.psms[connection_request.source_cid] = connection_request.psm
elif control_frame.code == L2CAP_CONNECTION_RESPONSE:
connection_response = cast(L2CAP_Connection_Response, control_frame)
if (
control_frame.result
connection_response.result
== L2CAP_Connection_Response.CONNECTION_SUCCESSFUL
):
if self.peer:
if psm := self.peer.psms.get(control_frame.source_cid):
if psm := self.peer.psms.get(
connection_response.source_cid
):
# Found a pending connection
self.psms[control_frame.destination_cid] = psm
self.psms[connection_response.destination_cid] = psm
# For AVDTP connections, create a packet assembler for
# each direction
if psm == AVDTP_PSM:
if psm == avdtp.AVDTP_PSM:
self.avdtp_assemblers[
control_frame.source_cid
] = AVDTP_MessageAssembler(self.on_avdtp_message)
connection_response.source_cid
] = avdtp.MessageAssembler(self.on_avdtp_message)
self.peer.avdtp_assemblers[
control_frame.destination_cid
] = AVDTP_MessageAssembler(
connection_response.destination_cid
] = avdtp.MessageAssembler(
self.peer.on_avdtp_message
)
@@ -113,7 +128,7 @@ class PacketTracer:
elif psm == RFCOMM_PSM:
rfcomm_frame = RFCOMM_Frame.from_bytes(l2cap_pdu.payload)
self.analyzer.emit(rfcomm_frame)
elif psm == AVDTP_PSM:
elif psm == avdtp.AVDTP_PSM:
self.analyzer.emit(
f'{color("L2CAP", "green")} [CID={l2cap_pdu.cid}, '
f'PSM=AVDTP]: {l2cap_pdu.payload.hex()}'
@@ -130,22 +145,26 @@ class PacketTracer:
else:
self.analyzer.emit(l2cap_pdu)
def on_avdtp_message(self, transaction_label, message):
def on_avdtp_message(
self, transaction_label: int, message: avdtp.Message
) -> None:
self.analyzer.emit(
f'{color("AVDTP", "green")} [{transaction_label}] {message}'
)
def feed_packet(self, packet):
def feed_packet(self, packet: HCI_AclDataPacket) -> None:
self.packet_assembler.feed_packet(packet)
class Analyzer:
def __init__(self, label, emit_message):
acl_streams: MutableMapping[int, PacketTracer.AclStream]
peer: PacketTracer.Analyzer
def __init__(self, label: str, emit_message: Callable[..., None]) -> None:
self.label = label
self.emit_message = emit_message
self.acl_streams = {} # ACL streams, by connection handle
self.peer = None # Analyzer in the other direction
def start_acl_stream(self, connection_handle):
def start_acl_stream(self, connection_handle: int) -> PacketTracer.AclStream:
logger.info(
f'[{self.label}] +++ Creating ACL stream for connection '
f'0x{connection_handle:04X}'
@@ -160,7 +179,7 @@ class PacketTracer:
return stream
def end_acl_stream(self, connection_handle):
def end_acl_stream(self, connection_handle: int) -> None:
if connection_handle in self.acl_streams:
logger.info(
f'[{self.label}] --- Removing ACL stream for connection '
@@ -171,23 +190,29 @@ class PacketTracer:
# Let the other forwarder know so it can cleanup its stream as well
self.peer.end_acl_stream(connection_handle)
def on_packet(self, packet):
def on_packet(self, packet: HCI_Packet) -> None:
self.emit(packet)
if packet.hci_packet_type == HCI_ACL_DATA_PACKET:
acl_packet = cast(HCI_AclDataPacket, packet)
# Look for an existing stream for this handle, create one if it is the
# first ACL packet for that connection handle
if (stream := self.acl_streams.get(packet.connection_handle)) is None:
stream = self.start_acl_stream(packet.connection_handle)
stream.feed_packet(packet)
if (
stream := self.acl_streams.get(acl_packet.connection_handle)
) is None:
stream = self.start_acl_stream(acl_packet.connection_handle)
stream.feed_packet(acl_packet)
elif packet.hci_packet_type == HCI_EVENT_PACKET:
if packet.event_code == HCI_DISCONNECTION_COMPLETE_EVENT:
self.end_acl_stream(packet.connection_handle)
event_packet = cast(HCI_Event, packet)
if event_packet.event_code == HCI_DISCONNECTION_COMPLETE_EVENT:
self.end_acl_stream(
cast(HCI_Disconnection_Complete_Event, packet).connection_handle
)
def emit(self, message):
def emit(self, message: Any) -> None:
self.emit_message(f'[{self.label}] {message}')
def trace(self, packet, direction=0):
def trace(self, packet: HCI_Packet, direction: int = 0) -> None:
if direction == 0:
self.host_to_controller_analyzer.on_packet(packet)
else:
@@ -195,10 +220,10 @@ class PacketTracer:
def __init__(
self,
host_to_controller_label=color('HOST->CONTROLLER', 'blue'),
controller_to_host_label=color('CONTROLLER->HOST', 'cyan'),
emit_message=logger.info,
):
host_to_controller_label: str = color('HOST->CONTROLLER', 'blue'),
controller_to_host_label: str = color('CONTROLLER->HOST', 'cyan'),
emit_message: Callable[..., None] = logger.info,
) -> None:
self.host_to_controller_analyzer = PacketTracer.Analyzer(
host_to_controller_label, emit_message
)

View File

@@ -391,6 +391,9 @@ class L2CAP_Connection_Request(L2CAP_Control_Frame):
See Bluetooth spec @ Vol 3, Part A - 4.2 CONNECTION REQUEST
'''
psm: int
source_cid: int
@staticmethod
def parse_psm(data: bytes, offset: int = 0) -> Tuple[int, int]:
psm_length = 2
@@ -432,6 +435,11 @@ class L2CAP_Connection_Response(L2CAP_Control_Frame):
See Bluetooth spec @ Vol 3, Part A - 4.3 CONNECTION RESPONSE
'''
source_cid: int
destination_cid: int
status: int
result: int
CONNECTION_SUCCESSFUL = 0x0000
CONNECTION_PENDING = 0x0001
CONNECTION_REFUSED_PSM_NOT_SUPPORTED = 0x0002

147
bumble/profiles/csip.py Normal file
View File

@@ -0,0 +1,147 @@
# Copyright 2021-2023 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# -----------------------------------------------------------------------------
# Imports
# -----------------------------------------------------------------------------
from __future__ import annotations
import enum
import struct
from typing import Optional
from bumble import gatt
from bumble import gatt_client
# -----------------------------------------------------------------------------
# Constants
# -----------------------------------------------------------------------------
class SirkType(enum.IntEnum):
'''Coordinated Set Identification Service - 5.1 Set Identity Resolving Key.'''
ENCRYPTED = 0x00
PLAINTEXT = 0x01
class MemberLock(enum.IntEnum):
'''Coordinated Set Identification Service - 5.3 Set Member Lock.'''
UNLOCKED = 0x01
LOCKED = 0x02
# -----------------------------------------------------------------------------
# Utils
# -----------------------------------------------------------------------------
# TODO: Implement RSI Generator
# -----------------------------------------------------------------------------
# Server
# -----------------------------------------------------------------------------
class CoordinatedSetIdentificationService(gatt.TemplateService):
UUID = gatt.GATT_COORDINATED_SET_IDENTIFICATION_SERVICE
set_identity_resolving_key_characteristic: gatt.Characteristic
coordinated_set_size_characteristic: Optional[gatt.Characteristic] = None
set_member_lock_characteristic: Optional[gatt.Characteristic] = None
set_member_rank_characteristic: Optional[gatt.Characteristic] = None
def __init__(
self,
set_identity_resolving_key: bytes,
coordinated_set_size: Optional[int] = None,
set_member_lock: Optional[MemberLock] = None,
set_member_rank: Optional[int] = None,
) -> None:
characteristics = []
self.set_identity_resolving_key_characteristic = gatt.Characteristic(
uuid=gatt.GATT_SET_IDENTITY_RESOLVING_KEY_CHARACTERISTIC,
properties=gatt.Characteristic.Properties.READ
| gatt.Characteristic.Properties.NOTIFY,
permissions=gatt.Characteristic.Permissions.READABLE,
# TODO: Implement encrypted SIRK reader.
value=struct.pack('B', SirkType.PLAINTEXT) + set_identity_resolving_key,
)
characteristics.append(self.set_identity_resolving_key_characteristic)
if coordinated_set_size is not None:
self.coordinated_set_size_characteristic = gatt.Characteristic(
uuid=gatt.GATT_COORDINATED_SET_SIZE_CHARACTERISTIC,
properties=gatt.Characteristic.Properties.READ
| gatt.Characteristic.Properties.NOTIFY,
permissions=gatt.Characteristic.Permissions.READABLE,
value=struct.pack('B', coordinated_set_size),
)
characteristics.append(self.coordinated_set_size_characteristic)
if set_member_lock is not None:
self.set_member_lock_characteristic = gatt.Characteristic(
uuid=gatt.GATT_SET_MEMBER_LOCK_CHARACTERISTIC,
properties=gatt.Characteristic.Properties.READ
| gatt.Characteristic.Properties.NOTIFY
| gatt.Characteristic.Properties.WRITE,
permissions=gatt.Characteristic.Permissions.READABLE
| gatt.Characteristic.Permissions.WRITEABLE,
value=struct.pack('B', set_member_lock),
)
characteristics.append(self.set_member_lock_characteristic)
if set_member_rank is not None:
self.set_member_rank_characteristic = gatt.Characteristic(
uuid=gatt.GATT_SET_MEMBER_RANK_CHARACTERISTIC,
properties=gatt.Characteristic.Properties.READ
| gatt.Characteristic.Properties.NOTIFY,
permissions=gatt.Characteristic.Permissions.READABLE,
value=struct.pack('B', set_member_rank),
)
characteristics.append(self.set_member_rank_characteristic)
super().__init__(characteristics)
# -----------------------------------------------------------------------------
# Client
# -----------------------------------------------------------------------------
class CoordinatedSetIdentificationProxy(gatt_client.ProfileServiceProxy):
SERVICE_CLASS = CoordinatedSetIdentificationService
set_identity_resolving_key: gatt_client.CharacteristicProxy
coordinated_set_size: Optional[gatt_client.CharacteristicProxy] = None
set_member_lock: Optional[gatt_client.CharacteristicProxy] = None
set_member_rank: Optional[gatt_client.CharacteristicProxy] = None
def __init__(self, service_proxy: gatt_client.ServiceProxy) -> None:
self.service_proxy = service_proxy
self.set_identity_resolving_key = service_proxy.get_characteristics_by_uuid(
gatt.GATT_SET_IDENTITY_RESOLVING_KEY_CHARACTERISTIC
)[0]
if characteristics := service_proxy.get_characteristics_by_uuid(
gatt.GATT_COORDINATED_SET_SIZE_CHARACTERISTIC
):
self.coordinated_set_size = characteristics[0]
if characteristics := service_proxy.get_characteristics_by_uuid(
gatt.GATT_SET_MEMBER_LOCK_CHARACTERISTIC
):
self.set_member_lock = characteristics[0]
if characteristics := service_proxy.get_characteristics_by_uuid(
gatt.GATT_SET_MEMBER_RANK_CHARACTERISTIC
):
self.set_member_rank = characteristics[0]

View File

@@ -150,7 +150,7 @@ class PacketParser:
try:
self.sink.on_packet(bytes(self.packet))
except Exception as error:
logger.warning(
logger.exception(
color(f'!!! Exception in on_packet: {error}', 'red')
)
self.reset()

74
tests/csip_test.py Normal file
View File

@@ -0,0 +1,74 @@
# Copyright 2021-2023 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# -----------------------------------------------------------------------------
# Imports
# -----------------------------------------------------------------------------
import asyncio
import os
import pytest
import struct
import logging
from bumble import device
from bumble.profiles import csip
from .test_utils import TwoDevices
# -----------------------------------------------------------------------------
# Logging
# -----------------------------------------------------------------------------
logger = logging.getLogger(__name__)
# -----------------------------------------------------------------------------
@pytest.mark.asyncio
async def test_csis():
SIRK = bytes.fromhex('2f62c8ae41867d1bb619e788a2605faa')
devices = TwoDevices()
devices[0].add_service(
csip.CoordinatedSetIdentificationService(
set_identity_resolving_key=SIRK,
coordinated_set_size=2,
set_member_lock=csip.MemberLock.UNLOCKED,
set_member_rank=0,
)
)
await devices.setup_connection()
peer = device.Peer(devices.connections[1])
csis_client = await peer.discover_service_and_create_proxy(
csip.CoordinatedSetIdentificationProxy
)
assert (
await csis_client.set_identity_resolving_key.read_value()
== bytes([csip.SirkType.PLAINTEXT]) + SIRK
)
assert await csis_client.coordinated_set_size.read_value() == struct.pack('B', 2)
assert await csis_client.set_member_lock.read_value() == struct.pack(
'B', csip.MemberLock.UNLOCKED
)
assert await csis_client.set_member_rank.read_value() == struct.pack('B', 0)
# -----------------------------------------------------------------------------
async def run():
await test_csis()
# -----------------------------------------------------------------------------
if __name__ == '__main__':
logging.basicConfig(level=os.environ.get('BUMBLE_LOGLEVEL', 'INFO').upper())
asyncio.run(run())

View File

@@ -20,6 +20,7 @@ import logging
import os
import struct
import pytest
from unittest.mock import Mock, ANY
from bumble.controller import Controller
from bumble.gatt_client import CharacteristicProxy
@@ -763,6 +764,83 @@ async def test_subscribe_notify():
assert not c3._called_3
# -----------------------------------------------------------------------------
@pytest.mark.asyncio
async def test_unsubscribe():
[client, server] = LinkedDevices().devices[:2]
characteristic1 = Characteristic(
'FDB159DB-036C-49E3-B3DB-6325AC750806',
Characteristic.Properties.READ | Characteristic.Properties.NOTIFY,
Characteristic.READABLE,
bytes([1, 2, 3]),
)
characteristic2 = Characteristic(
'3234C4F4-3F34-4616-8935-45A50EE05DEB',
Characteristic.Properties.READ | Characteristic.Properties.NOTIFY,
Characteristic.READABLE,
bytes([1, 2, 3]),
)
service1 = Service(
'3A657F47-D34F-46B3-B1EC-698E29B6B829',
[characteristic1, characteristic2],
)
server.add_services([service1])
mock1 = Mock()
characteristic1.on('subscription', mock1)
mock2 = Mock()
characteristic2.on('subscription', mock2)
await client.power_on()
await server.power_on()
connection = await client.connect(server.random_address)
peer = Peer(connection)
await peer.discover_services()
await peer.discover_characteristics()
c = peer.get_characteristics_by_uuid(characteristic1.uuid)
assert len(c) == 1
c1 = c[0]
c = peer.get_characteristics_by_uuid(characteristic2.uuid)
assert len(c) == 1
c2 = c[0]
await c1.subscribe()
await async_barrier()
mock1.assert_called_once_with(ANY, True, False)
await c2.subscribe()
await async_barrier()
mock2.assert_called_once_with(ANY, True, False)
mock1.reset_mock()
await c1.unsubscribe()
await async_barrier()
mock1.assert_called_once_with(ANY, False, False)
mock2.reset_mock()
await c2.unsubscribe()
await async_barrier()
mock2.assert_called_once_with(ANY, False, False)
mock1.reset_mock()
await c1.unsubscribe()
await async_barrier()
mock1.assert_not_called()
mock2.reset_mock()
await c2.unsubscribe()
await async_barrier()
mock2.assert_not_called()
mock1.reset_mock()
await c1.unsubscribe(force=True)
await async_barrier()
mock1.assert_called_once_with(ANY, False, False)
# -----------------------------------------------------------------------------
@pytest.mark.asyncio
async def test_mtu_exchange():
@@ -886,6 +964,7 @@ async def async_main():
await test_read_write()
await test_read_write2()
await test_subscribe_notify()
await test_unsubscribe()
await test_characteristic_encoding()
await test_mtu_exchange()

View File

@@ -71,3 +71,6 @@ class TwoDevices:
# Check the post conditions
assert self.connections[0] is not None
assert self.connections[1] is not None
def __getitem__(self, index: int) -> Device:
return self.devices[index]