Complete CSIP and CAP

Also add random address generation functions.
This commit is contained in:
Josh Wu
2023-12-14 23:52:04 +08:00
parent a286700239
commit 87c76a4a0e
8 changed files with 385 additions and 27 deletions

View File

@@ -100,6 +100,16 @@ class EccKey:
# ----------------------------------------------------------------------------- # -----------------------------------------------------------------------------
# -----------------------------------------------------------------------------
def generate_prand() -> bytes:
'''Generates random 3 bytes, with the 2 most significant bits of 0b01.
See Bluetooth spec, Vol 6, Part E - Table 1.2.
'''
prand_bytes = secrets.token_bytes(6)
return prand_bytes[:2] + bytes([(prand_bytes[2] & 0b01111111) | 0b01000000])
# ----------------------------------------------------------------------------- # -----------------------------------------------------------------------------
def xor(x: bytes, y: bytes) -> bytes: def xor(x: bytes, y: bytes) -> bytes:
assert len(x) == len(y) assert len(x) == len(y)

View File

@@ -368,9 +368,12 @@ class TemplateService(Service):
UUID: UUID UUID: UUID
def __init__( def __init__(
self, characteristics: List[Characteristic], primary: bool = True self,
characteristics: List[Characteristic],
primary: bool = True,
included_services: List[Service] = [],
) -> None: ) -> None:
super().__init__(self.UUID, characteristics, primary) super().__init__(self.UUID, characteristics, primary, included_services)
# ----------------------------------------------------------------------------- # -----------------------------------------------------------------------------

52
bumble/profiles/cap.py Normal file
View File

@@ -0,0 +1,52 @@
# Copyright 2021-2023 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# -----------------------------------------------------------------------------
# Imports
# -----------------------------------------------------------------------------
from __future__ import annotations
from bumble import gatt
from bumble import gatt_client
from bumble.profiles import csip
# -----------------------------------------------------------------------------
# Server
# -----------------------------------------------------------------------------
class CommonAudioServiceService(gatt.TemplateService):
UUID = gatt.GATT_COMMON_AUDIO_SERVICE
def __init__(
self,
coordinated_set_identification_service: csip.CoordinatedSetIdentificationService,
) -> None:
self.coordinated_set_identification_service = (
coordinated_set_identification_service
)
super().__init__(
characteristics=[],
included_services=[coordinated_set_identification_service],
)
# -----------------------------------------------------------------------------
# Client
# -----------------------------------------------------------------------------
class CommonAudioServiceServiceProxy(gatt_client.ProfileServiceProxy):
SERVICE_CLASS = CommonAudioServiceService
def __init__(self, service_proxy: gatt_client.ServiceProxy) -> None:
self.service_proxy = service_proxy

View File

@@ -21,6 +21,9 @@ import enum
import struct import struct
from typing import Optional from typing import Optional
from bumble import core
from bumble import crypto
from bumble import device
from bumble import gatt from bumble import gatt
from bumble import gatt_client from bumble import gatt_client
@@ -43,9 +46,43 @@ class MemberLock(enum.IntEnum):
# ----------------------------------------------------------------------------- # -----------------------------------------------------------------------------
# Utils # Crypto Toolbox
# ----------------------------------------------------------------------------- # -----------------------------------------------------------------------------
# TODO: Implement RSI Generator def s1(m: bytes) -> bytes:
'''
Coordinated Set Identification Service - 4.3 s1 SALT generation function.
'''
return crypto.aes_cmac(m[::-1], bytes(16))[::-1]
def k1(n: bytes, salt: bytes, p: bytes) -> bytes:
'''
Coordinated Set Identification Service - 4.4 k1 derivation function.
'''
t = crypto.aes_cmac(n[::-1], salt[::-1])
return crypto.aes_cmac(p[::-1], t)[::-1]
def sef(k: bytes, r: bytes) -> bytes:
'''
Coordinated Set Identification Service - 4.5 SIRK encryption function sef.
'''
return crypto.xor(k1(k, s1(b'SIRKenc'[::-1]), b'csis'[::-1]), r)
def sih(k: bytes, r: bytes) -> bytes:
'''
Coordinated Set Identification Service - 4.7 Resolvable Set Identifier hash function sih.
'''
return crypto.e(k, r + bytes(13))[:3]
def generate_rsi(sirk: bytes) -> bytes:
'''
Coordinated Set Identification Service - 4.8 Resolvable Set Identifier generation operation.
'''
prand = crypto.generate_prand()
return sih(sirk, prand) + prand
# ----------------------------------------------------------------------------- # -----------------------------------------------------------------------------
@@ -54,6 +91,7 @@ class MemberLock(enum.IntEnum):
class CoordinatedSetIdentificationService(gatt.TemplateService): class CoordinatedSetIdentificationService(gatt.TemplateService):
UUID = gatt.GATT_COORDINATED_SET_IDENTIFICATION_SERVICE UUID = gatt.GATT_COORDINATED_SET_IDENTIFICATION_SERVICE
set_identity_resolving_key: bytes
set_identity_resolving_key_characteristic: gatt.Characteristic set_identity_resolving_key_characteristic: gatt.Characteristic
coordinated_set_size_characteristic: Optional[gatt.Characteristic] = None coordinated_set_size_characteristic: Optional[gatt.Characteristic] = None
set_member_lock_characteristic: Optional[gatt.Characteristic] = None set_member_lock_characteristic: Optional[gatt.Characteristic] = None
@@ -62,19 +100,21 @@ class CoordinatedSetIdentificationService(gatt.TemplateService):
def __init__( def __init__(
self, self,
set_identity_resolving_key: bytes, set_identity_resolving_key: bytes,
set_identity_resolving_key_type: SirkType,
coordinated_set_size: Optional[int] = None, coordinated_set_size: Optional[int] = None,
set_member_lock: Optional[MemberLock] = None, set_member_lock: Optional[MemberLock] = None,
set_member_rank: Optional[int] = None, set_member_rank: Optional[int] = None,
) -> None: ) -> None:
characteristics = [] characteristics = []
self.set_identity_resolving_key = set_identity_resolving_key
self.set_identity_resolving_key_type = set_identity_resolving_key_type
self.set_identity_resolving_key_characteristic = gatt.Characteristic( self.set_identity_resolving_key_characteristic = gatt.Characteristic(
uuid=gatt.GATT_SET_IDENTITY_RESOLVING_KEY_CHARACTERISTIC, uuid=gatt.GATT_SET_IDENTITY_RESOLVING_KEY_CHARACTERISTIC,
properties=gatt.Characteristic.Properties.READ properties=gatt.Characteristic.Properties.READ
| gatt.Characteristic.Properties.NOTIFY, | gatt.Characteristic.Properties.NOTIFY,
permissions=gatt.Characteristic.Permissions.READABLE, permissions=gatt.Characteristic.Permissions.READABLE,
# TODO: Implement encrypted SIRK reader. value=gatt.CharacteristicValue(read=self.on_sirk_read),
value=struct.pack('B', SirkType.PLAINTEXT) + set_identity_resolving_key,
) )
characteristics.append(self.set_identity_resolving_key_characteristic) characteristics.append(self.set_identity_resolving_key_characteristic)
@@ -112,6 +152,24 @@ class CoordinatedSetIdentificationService(gatt.TemplateService):
super().__init__(characteristics) super().__init__(characteristics)
def on_sirk_read(self, _connection: device.Connection) -> bytes:
if self.set_identity_resolving_key_type == SirkType.PLAINTEXT:
return bytes([SirkType.PLAINTEXT]) + self.set_identity_resolving_key
else:
raise NotImplementedError('TODO: Pending async Characteristic read.')
def get_advertising_data(self) -> bytes:
return bytes(
core.AdvertisingData(
[
(
core.AdvertisingData.RESOLVABLE_SET_IDENTIFIER,
generate_rsi(self.set_identity_resolving_key),
),
]
)
)
# ----------------------------------------------------------------------------- # -----------------------------------------------------------------------------
# Client # Client

View File

@@ -0,0 +1,116 @@
# Copyright 2021-2023 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# -----------------------------------------------------------------------------
# Imports
# -----------------------------------------------------------------------------
import asyncio
import logging
import sys
import os
import secrets
from bumble.core import AdvertisingData
from bumble.device import Device
from bumble.hci import (
Address,
OwnAddressType,
HCI_LE_Set_Extended_Advertising_Parameters_Command,
)
from bumble.profiles.cap import CommonAudioServiceService
from bumble.profiles.csip import CoordinatedSetIdentificationService, SirkType
from bumble.transport import open_transport_or_link
# -----------------------------------------------------------------------------
async def main() -> None:
if len(sys.argv) < 3:
print(
'Usage: run_cig_setup.py <config-file>'
'<transport-spec-for-device-1> <transport-spec-for-device-2>'
)
print(
'example: run_cig_setup.py device1.json'
'tcp-client:127.0.0.1:6402 tcp-client:127.0.0.1:6402'
)
return
print('<<< connecting to HCI...')
hci_transports = await asyncio.gather(
open_transport_or_link(sys.argv[2]), open_transport_or_link(sys.argv[3])
)
print('<<< connected')
devices = [
Device.from_config_file_with_hci(
sys.argv[1], hci_transport.source, hci_transport.sink
)
for hci_transport in hci_transports
]
sirk = secrets.token_bytes(16)
for i, device in enumerate(devices):
device.random_address = Address(secrets.token_bytes(6))
await device.power_on()
csis = CoordinatedSetIdentificationService(
set_identity_resolving_key=sirk,
set_identity_resolving_key_type=SirkType.PLAINTEXT,
coordinated_set_size=2,
)
device.add_service(CommonAudioServiceService(csis))
advertising_data = (
bytes(
AdvertisingData(
[
(
AdvertisingData.COMPLETE_LOCAL_NAME,
bytes(f'Bumble LE Audio-{i}', 'utf-8'),
),
(
AdvertisingData.FLAGS,
bytes(
[
AdvertisingData.LE_GENERAL_DISCOVERABLE_MODE_FLAG
| AdvertisingData.BR_EDR_HOST_FLAG
| AdvertisingData.BR_EDR_CONTROLLER_FLAG
]
),
),
(
AdvertisingData.INCOMPLETE_LIST_OF_16_BIT_SERVICE_CLASS_UUIDS,
bytes(CoordinatedSetIdentificationService.UUID),
),
]
)
)
+ csis.get_advertising_data()
)
await device.start_extended_advertising(
advertising_properties=(
HCI_LE_Set_Extended_Advertising_Parameters_Command.AdvertisingProperties.CONNECTABLE_ADVERTISING
),
own_address_type=OwnAddressType.RANDOM,
advertising_data=advertising_data,
)
await asyncio.gather(
*[hci_transport.source.terminated for hci_transport in hci_transports]
)
# -----------------------------------------------------------------------------
logging.basicConfig(level=os.environ.get('BUMBLE_LOGLEVEL', 'DEBUG').upper())
asyncio.run(main())

View File

@@ -20,6 +20,7 @@ import logging
import sys import sys
import os import os
import struct import struct
import secrets
from bumble.core import AdvertisingData from bumble.core import AdvertisingData
from bumble.device import Device, CisLink from bumble.device import Device, CisLink
from bumble.hci import ( from bumble.hci import (
@@ -39,6 +40,8 @@ from bumble.profiles.bap import (
PublishedAudioCapabilitiesService, PublishedAudioCapabilitiesService,
AudioStreamControlService, AudioStreamControlService,
) )
from bumble.profiles.cap import CommonAudioServiceService
from bumble.profiles.csip import CoordinatedSetIdentificationService, SirkType
from bumble.transport import open_transport_or_link from bumble.transport import open_transport_or_link
@@ -60,6 +63,11 @@ async def main() -> None:
await device.power_on() await device.power_on()
csis = CoordinatedSetIdentificationService(
set_identity_resolving_key=secrets.token_bytes(16),
set_identity_resolving_key_type=SirkType.PLAINTEXT,
)
device.add_service(CommonAudioServiceService(csis))
device.add_service( device.add_service(
PublishedAudioCapabilitiesService( PublishedAudioCapabilitiesService(
supported_source_context=ContextType.PROHIBITED, supported_source_context=ContextType.PROHIBITED,
@@ -108,29 +116,32 @@ async def main() -> None:
device.add_service(AudioStreamControlService(device, sink_ase_id=[1, 2])) device.add_service(AudioStreamControlService(device, sink_ase_id=[1, 2]))
advertising_data = bytes( advertising_data = (
AdvertisingData( bytes(
[ AdvertisingData(
( [
AdvertisingData.COMPLETE_LOCAL_NAME, (
bytes('Bumble LE Audio', 'utf-8'), AdvertisingData.COMPLETE_LOCAL_NAME,
), bytes('Bumble LE Audio', 'utf-8'),
(
AdvertisingData.FLAGS,
bytes(
[
AdvertisingData.LE_GENERAL_DISCOVERABLE_MODE_FLAG
| AdvertisingData.BR_EDR_HOST_FLAG
| AdvertisingData.BR_EDR_CONTROLLER_FLAG
]
), ),
), (
( AdvertisingData.FLAGS,
AdvertisingData.INCOMPLETE_LIST_OF_16_BIT_SERVICE_CLASS_UUIDS, bytes(
bytes(PublishedAudioCapabilitiesService.UUID), [
), AdvertisingData.LE_GENERAL_DISCOVERABLE_MODE_FLAG
] | AdvertisingData.BR_EDR_HOST_FLAG
| AdvertisingData.BR_EDR_CONTROLLER_FLAG
]
),
),
(
AdvertisingData.INCOMPLETE_LIST_OF_16_BIT_SERVICE_CLASS_UUIDS,
bytes(PublishedAudioCapabilitiesService.UUID),
),
]
)
) )
+ csis.get_advertising_data()
) )
subprocess = await asyncio.create_subprocess_shell( subprocess = await asyncio.create_subprocess_shell(
f'dlc3 | ffplay pipe:0', f'dlc3 | ffplay pipe:0',

71
tests/cap_test.py Normal file
View File

@@ -0,0 +1,71 @@
# Copyright 2021-2023 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# -----------------------------------------------------------------------------
# Imports
# -----------------------------------------------------------------------------
import asyncio
import os
import pytest
import logging
from bumble import device
from bumble import gatt
from bumble.profiles import cap
from bumble.profiles import csip
from .test_utils import TwoDevices
# -----------------------------------------------------------------------------
# Logging
# -----------------------------------------------------------------------------
logger = logging.getLogger(__name__)
# -----------------------------------------------------------------------------
@pytest.mark.asyncio
async def test_cas():
SIRK = bytes.fromhex('2f62c8ae41867d1bb619e788a2605faa')
devices = TwoDevices()
devices[0].add_service(
cap.CommonAudioServiceService(
csip.CoordinatedSetIdentificationService(
set_identity_resolving_key=SIRK,
set_identity_resolving_key_type=csip.SirkType.PLAINTEXT,
)
)
)
await devices.setup_connection()
peer = device.Peer(devices.connections[1])
cas_client = await peer.discover_service_and_create_proxy(
cap.CommonAudioServiceServiceProxy
)
included_services = await peer.discover_included_services(cas_client.service_proxy)
assert any(
service.uuid == gatt.GATT_COORDINATED_SET_IDENTIFICATION_SERVICE
for service in included_services
)
# -----------------------------------------------------------------------------
async def run():
await test_cas()
# -----------------------------------------------------------------------------
if __name__ == '__main__':
logging.basicConfig(level=os.environ.get('BUMBLE_LOGLEVEL', 'INFO').upper())
asyncio.run(run())

View File

@@ -31,6 +31,41 @@ from .test_utils import TwoDevices
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
# -----------------------------------------------------------------------------
def test_s1():
assert (
csip.s1(b'SIRKenc'[::-1])
== bytes.fromhex('6901983f 18149e82 3c7d133a 7d774572')[::-1]
)
# -----------------------------------------------------------------------------
def test_k1():
K = bytes.fromhex('676e1b9b d448696f 061ec622 3ce5ced9')[::-1]
SALT = csip.s1(b'SIRKenc'[::-1])
P = b'csis'[::-1]
assert (
csip.k1(K, SALT, P)
== bytes.fromhex('5277453c c094d982 b0e8ee53 2f2d1f8b')[::-1]
)
# -----------------------------------------------------------------------------
def test_sih():
SIRK = bytes.fromhex('457d7d09 21a1fd22 cecd8c86 dd72cccd')[::-1]
PRAND = bytes.fromhex('69f563')[::-1]
assert csip.sih(SIRK, PRAND) == bytes.fromhex('1948da')[::-1]
# -----------------------------------------------------------------------------
def test_sef():
SIRK = bytes.fromhex('457d7d09 21a1fd22 cecd8c86 dd72cccd')[::-1]
K = bytes.fromhex('676e1b9b d448696f 061ec622 3ce5ced9')[::-1]
assert (
csip.sef(K, SIRK) == bytes.fromhex('170a3835 e13524a0 7e2562d5 f25fd346')[::-1]
)
# ----------------------------------------------------------------------------- # -----------------------------------------------------------------------------
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_csis(): async def test_csis():
@@ -40,6 +75,7 @@ async def test_csis():
devices[0].add_service( devices[0].add_service(
csip.CoordinatedSetIdentificationService( csip.CoordinatedSetIdentificationService(
set_identity_resolving_key=SIRK, set_identity_resolving_key=SIRK,
set_identity_resolving_key_type=csip.SirkType.PLAINTEXT,
coordinated_set_size=2, coordinated_set_size=2,
set_member_lock=csip.MemberLock.UNLOCKED, set_member_lock=csip.MemberLock.UNLOCKED,
set_member_rank=0, set_member_rank=0,
@@ -65,6 +101,7 @@ async def test_csis():
# ----------------------------------------------------------------------------- # -----------------------------------------------------------------------------
async def run(): async def run():
test_sih()
await test_csis() await test_csis()