Compare commits

..

1 Commits

Author SHA1 Message Date
Charlie Boutier 3514d0fca3 cryptography: bump version to 44.0.3 to fix python parsing
Bug: 404336381
2025-05-07 14:49:06 -07:00
7 changed files with 160 additions and 298 deletions
+1 -1
View File
@@ -33,7 +33,7 @@ jobs:
- name: Install dependencies - name: Install dependencies
run: | run: |
python -m pip install --upgrade pip python -m pip install --upgrade pip
python -m pip install ".[build,examples,test,development]" python -m pip install ".[build,test,development]"
- name: Check - name: Check
run: | run: |
invoke project.pre-commit invoke project.pre-commit
+6 -3
View File
@@ -1589,6 +1589,7 @@ class Connection(utils.CompositeEventEmitter):
encryption_key_size: int encryption_key_size: int
authenticated: bool authenticated: bool
sc: bool sc: bool
link_key_type: Optional[int] # [Classic only]
gatt_client: gatt_client.Client gatt_client: gatt_client.Client
pairing_peer_io_capability: Optional[int] pairing_peer_io_capability: Optional[int]
pairing_peer_authentication_requirements: Optional[int] pairing_peer_authentication_requirements: Optional[int]
@@ -1691,6 +1692,7 @@ class Connection(utils.CompositeEventEmitter):
self.encryption_key_size = 0 self.encryption_key_size = 0
self.authenticated = False self.authenticated = False
self.sc = False self.sc = False
self.link_key_type = None
self.att_mtu = ATT_DEFAULT_MTU self.att_mtu = ATT_DEFAULT_MTU
self.data_length = DEVICE_DEFAULT_DATA_LENGTH self.data_length = DEVICE_DEFAULT_DATA_LENGTH
self.gatt_client = gatt_client.Client(self) # Per-connection client self.gatt_client = gatt_client.Client(self) # Per-connection client
@@ -5073,9 +5075,9 @@ class Device(utils.CompositeEventEmitter):
hci.HCI_AUTHENTICATED_COMBINATION_KEY_GENERATED_FROM_P_192_TYPE, hci.HCI_AUTHENTICATED_COMBINATION_KEY_GENERATED_FROM_P_192_TYPE,
hci.HCI_AUTHENTICATED_COMBINATION_KEY_GENERATED_FROM_P_256_TYPE, hci.HCI_AUTHENTICATED_COMBINATION_KEY_GENERATED_FROM_P_256_TYPE,
) )
pairing_keys = PairingKeys( pairing_keys = PairingKeys()
link_key=PairingKeys.Key(value=link_key, authenticated=authenticated), pairing_keys.link_key = PairingKeys.Key(
link_key_type=key_type, value=link_key, authenticated=authenticated
) )
utils.cancel_on_event( utils.cancel_on_event(
@@ -5085,6 +5087,7 @@ class Device(utils.CompositeEventEmitter):
if connection := self.find_connection_by_bd_addr( if connection := self.find_connection_by_bd_addr(
bd_addr, transport=PhysicalTransport.BR_EDR bd_addr, transport=PhysicalTransport.BR_EDR
): ):
connection.link_key_type = key_type
connection.emit(connection.EVENT_LINK_KEY) connection.emit(connection.EVENT_LINK_KEY)
def add_service(self, service): def add_service(self, service):
+46 -62
View File
@@ -22,15 +22,14 @@
# ----------------------------------------------------------------------------- # -----------------------------------------------------------------------------
from __future__ import annotations from __future__ import annotations
import asyncio import asyncio
import dataclasses
import logging import logging
import os import os
import json import json
from typing import TYPE_CHECKING, Dict, List, Optional, Tuple, Type, Any from typing import TYPE_CHECKING, Dict, List, Optional, Tuple, Type
from typing_extensions import Self from typing_extensions import Self
from bumble.colors import color from bumble.colors import color
from bumble import hci from bumble.hci import Address
if TYPE_CHECKING: if TYPE_CHECKING:
from bumble.device import Device from bumble.device import Device
@@ -43,17 +42,16 @@ logger = logging.getLogger(__name__)
# ----------------------------------------------------------------------------- # -----------------------------------------------------------------------------
@dataclasses.dataclass
class PairingKeys: class PairingKeys:
@dataclasses.dataclass
class Key: class Key:
value: bytes def __init__(self, value, authenticated=False, ediv=None, rand=None):
authenticated: bool = False self.value = value
ediv: Optional[int] = None self.authenticated = authenticated
rand: Optional[bytes] = None self.ediv = ediv
self.rand = rand
@classmethod @classmethod
def from_dict(cls, key_dict: dict[str, Any]) -> PairingKeys.Key: def from_dict(cls, key_dict):
value = bytes.fromhex(key_dict['value']) value = bytes.fromhex(key_dict['value'])
authenticated = key_dict.get('authenticated', False) authenticated = key_dict.get('authenticated', False)
ediv = key_dict.get('ediv') ediv = key_dict.get('ediv')
@@ -63,7 +61,7 @@ class PairingKeys:
return cls(value, authenticated, ediv, rand) return cls(value, authenticated, ediv, rand)
def to_dict(self) -> dict[str, Any]: def to_dict(self):
key_dict = {'value': self.value.hex(), 'authenticated': self.authenticated} key_dict = {'value': self.value.hex(), 'authenticated': self.authenticated}
if self.ediv is not None: if self.ediv is not None:
key_dict['ediv'] = self.ediv key_dict['ediv'] = self.ediv
@@ -72,42 +70,39 @@ class PairingKeys:
return key_dict return key_dict
address_type: Optional[hci.AddressType] = None def __init__(self):
ltk: Optional[Key] = None self.address_type = None
ltk_central: Optional[Key] = None self.ltk = None
ltk_peripheral: Optional[Key] = None self.ltk_central = None
irk: Optional[Key] = None self.ltk_peripheral = None
csrk: Optional[Key] = None self.irk = None
link_key: Optional[Key] = None # Classic self.csrk = None
link_key_type: Optional[int] = None # Classic self.link_key = None # Classic
@classmethod @staticmethod
def key_from_dict(cls, keys_dict: dict[str, Any], key_name: str) -> Optional[Key]: def key_from_dict(keys_dict, key_name):
key_dict = keys_dict.get(key_name) key_dict = keys_dict.get(key_name)
if key_dict is None: if key_dict is None:
return None return None
return PairingKeys.Key.from_dict(key_dict) return PairingKeys.Key.from_dict(key_dict)
@classmethod @staticmethod
def from_dict(cls, keys_dict: dict[str, Any]) -> PairingKeys: def from_dict(keys_dict):
return PairingKeys( keys = PairingKeys()
address_type=(
hci.AddressType(t)
if (t := keys_dict.get('address_type')) is not None
else None
),
ltk=PairingKeys.key_from_dict(keys_dict, 'ltk'),
ltk_central=PairingKeys.key_from_dict(keys_dict, 'ltk_central'),
ltk_peripheral=PairingKeys.key_from_dict(keys_dict, 'ltk_peripheral'),
irk=PairingKeys.key_from_dict(keys_dict, 'irk'),
csrk=PairingKeys.key_from_dict(keys_dict, 'csrk'),
link_key=PairingKeys.key_from_dict(keys_dict, 'link_key'),
link_key_type=keys_dict.get('link_key_type'),
)
def to_dict(self) -> dict[str, Any]: keys.address_type = keys_dict.get('address_type')
keys: dict[str, Any] = {} keys.ltk = PairingKeys.key_from_dict(keys_dict, 'ltk')
keys.ltk_central = PairingKeys.key_from_dict(keys_dict, 'ltk_central')
keys.ltk_peripheral = PairingKeys.key_from_dict(keys_dict, 'ltk_peripheral')
keys.irk = PairingKeys.key_from_dict(keys_dict, 'irk')
keys.csrk = PairingKeys.key_from_dict(keys_dict, 'csrk')
keys.link_key = PairingKeys.key_from_dict(keys_dict, 'link_key')
return keys
def to_dict(self):
keys = {}
if self.address_type is not None: if self.address_type is not None:
keys['address_type'] = self.address_type keys['address_type'] = self.address_type
@@ -130,12 +125,9 @@ class PairingKeys:
if self.link_key is not None: if self.link_key is not None:
keys['link_key'] = self.link_key.to_dict() keys['link_key'] = self.link_key.to_dict()
if self.link_key_type is not None:
keys['link_key_type'] = self.link_key_type
return keys return keys
def print(self, prefix: str = '') -> None: def print(self, prefix=''):
keys_dict = self.to_dict() keys_dict = self.to_dict()
for container_property, value in keys_dict.items(): for container_property, value in keys_dict.items():
if isinstance(value, dict): if isinstance(value, dict):
@@ -164,28 +156,20 @@ class KeyStore:
all_keys = await self.get_all() all_keys = await self.get_all()
await asyncio.gather(*(self.delete(name) for (name, _) in all_keys)) await asyncio.gather(*(self.delete(name) for (name, _) in all_keys))
async def get_resolving_keys(self) -> list[tuple[bytes, hci.Address]]: async def get_resolving_keys(self):
all_keys = await self.get_all() all_keys = await self.get_all()
resolving_keys = [] resolving_keys = []
for name, keys in all_keys: for name, keys in all_keys:
if keys.irk is not None: if keys.irk is not None:
resolving_keys.append( if keys.address_type is None:
( address_type = Address.RANDOM_DEVICE_ADDRESS
keys.irk.value, else:
hci.Address( address_type = keys.address_type
name, resolving_keys.append((keys.irk.value, Address(name, address_type)))
(
keys.address_type
if keys.address_type is not None
else hci.Address.RANDOM_DEVICE_ADDRESS
),
),
)
)
return resolving_keys return resolving_keys
async def print(self, prefix: str = '') -> None: async def print(self, prefix=''):
entries = await self.get_all() entries = await self.get_all()
separator = '' separator = ''
for name, keys in entries: for name, keys in entries:
@@ -193,8 +177,8 @@ class KeyStore:
keys.print(prefix=prefix + ' ') keys.print(prefix=prefix + ' ')
separator = '\n' separator = '\n'
@classmethod @staticmethod
def create_for_device(cls, device: Device) -> KeyStore: def create_for_device(device: Device) -> KeyStore:
if device.config.keystore is None: if device.config.keystore is None:
return MemoryKeyStore() return MemoryKeyStore()
@@ -282,9 +266,9 @@ class JsonKeyStore(KeyStore):
filename = params[0] filename = params[0]
# Use a namespace based on the device address # Use a namespace based on the device address
if device.public_address not in (hci.Address.ANY, hci.Address.ANY_RANDOM): if device.public_address not in (Address.ANY, Address.ANY_RANDOM):
namespace = str(device.public_address) namespace = str(device.public_address)
elif device.random_address != hci.Address.ANY_RANDOM: elif device.random_address != Address.ANY_RANDOM:
namespace = str(device.random_address) namespace = str(device.random_address)
else: else:
namespace = JsonKeyStore.DEFAULT_NAMESPACE namespace = JsonKeyStore.DEFAULT_NAMESPACE
+53 -69
View File
@@ -15,7 +15,6 @@
from __future__ import annotations from __future__ import annotations
import asyncio import asyncio
import contextlib import contextlib
from collections.abc import Awaitable
import grpc import grpc
import logging import logging
@@ -25,7 +24,6 @@ from bumble import hci
from bumble.core import ( from bumble.core import (
PhysicalTransport, PhysicalTransport,
ProtocolError, ProtocolError,
InvalidArgumentError,
) )
import bumble.utils import bumble.utils
from bumble.device import Connection as BumbleConnection, Device from bumble.device import Connection as BumbleConnection, Device
@@ -190,6 +188,35 @@ class PairingDelegate(BasePairingDelegate):
self.service.event_queue.put_nowait(event) self.service.event_queue.put_nowait(event)
BR_LEVEL_REACHED: Dict[SecurityLevel, Callable[[BumbleConnection], bool]] = {
LEVEL0: lambda connection: True,
LEVEL1: lambda connection: connection.encryption == 0 or connection.authenticated,
LEVEL2: lambda connection: connection.encryption != 0 and connection.authenticated,
LEVEL3: lambda connection: connection.encryption != 0
and connection.authenticated
and connection.link_key_type
in (
hci.HCI_AUTHENTICATED_COMBINATION_KEY_GENERATED_FROM_P_192_TYPE,
hci.HCI_AUTHENTICATED_COMBINATION_KEY_GENERATED_FROM_P_256_TYPE,
),
LEVEL4: lambda connection: connection.encryption
== hci.HCI_Encryption_Change_Event.AES_CCM
and connection.authenticated
and connection.link_key_type
== hci.HCI_AUTHENTICATED_COMBINATION_KEY_GENERATED_FROM_P_256_TYPE,
}
LE_LEVEL_REACHED: Dict[LESecurityLevel, Callable[[BumbleConnection], bool]] = {
LE_LEVEL1: lambda connection: True,
LE_LEVEL2: lambda connection: connection.encryption != 0,
LE_LEVEL3: lambda connection: connection.encryption != 0
and connection.authenticated,
LE_LEVEL4: lambda connection: connection.encryption != 0
and connection.authenticated
and connection.sc,
}
class SecurityService(SecurityServicer): class SecurityService(SecurityServicer):
def __init__(self, device: Device, config: Config) -> None: def __init__(self, device: Device, config: Config) -> None:
self.log = utils.BumbleServerLoggerAdapter( self.log = utils.BumbleServerLoggerAdapter(
@@ -221,59 +248,6 @@ class SecurityService(SecurityServicer):
self.device.pairing_config_factory = pairing_config_factory self.device.pairing_config_factory = pairing_config_factory
async def _classic_level_reached(
self, level: SecurityLevel, connection: BumbleConnection
) -> bool:
if level == LEVEL0:
return True
if level == LEVEL1:
return connection.encryption == 0 or connection.authenticated
if level == LEVEL2:
return connection.encryption != 0 and connection.authenticated
link_key_type: Optional[int] = None
if (keystore := connection.device.keystore) and (
keys := await keystore.get(str(connection.peer_address))
):
link_key_type = keys.link_key_type
self.log.debug("link_key_type: %d", link_key_type)
if level == LEVEL3:
return (
connection.encryption != 0
and connection.authenticated
and link_key_type
in (
hci.HCI_AUTHENTICATED_COMBINATION_KEY_GENERATED_FROM_P_192_TYPE,
hci.HCI_AUTHENTICATED_COMBINATION_KEY_GENERATED_FROM_P_256_TYPE,
)
)
if level == LEVEL4:
return (
connection.encryption == hci.HCI_Encryption_Change_Event.AES_CCM
and connection.authenticated
and link_key_type
== hci.HCI_AUTHENTICATED_COMBINATION_KEY_GENERATED_FROM_P_256_TYPE
)
raise InvalidArgumentError(f"Unexpected level {level}")
def _le_level_reached(
self, level: LESecurityLevel, connection: BumbleConnection
) -> bool:
if level == LE_LEVEL1:
return True
if level == LE_LEVEL2:
return connection.encryption != 0
if level == LE_LEVEL3:
return connection.encryption != 0 and connection.authenticated
if level == LE_LEVEL4:
return (
connection.encryption != 0
and connection.authenticated
and connection.sc
)
raise InvalidArgumentError(f"Unexpected level {level}")
@utils.rpc @utils.rpc
async def OnPairing( async def OnPairing(
self, request: AsyncIterator[PairingEventAnswer], context: grpc.ServicerContext self, request: AsyncIterator[PairingEventAnswer], context: grpc.ServicerContext
@@ -316,7 +290,7 @@ class SecurityService(SecurityServicer):
] == oneof ] == oneof
# security level already reached # security level already reached
if await self.reached_security_level(connection, level): if self.reached_security_level(connection, level):
return SecureResponse(success=empty_pb2.Empty()) return SecureResponse(success=empty_pb2.Empty())
# trigger pairing if needed # trigger pairing if needed
@@ -387,7 +361,7 @@ class SecurityService(SecurityServicer):
return SecureResponse(encryption_failure=empty_pb2.Empty()) return SecureResponse(encryption_failure=empty_pb2.Empty())
# security level has been reached ? # security level has been reached ?
if await self.reached_security_level(connection, level): if self.reached_security_level(connection, level):
return SecureResponse(success=empty_pb2.Empty()) return SecureResponse(success=empty_pb2.Empty())
return SecureResponse(not_reached=empty_pb2.Empty()) return SecureResponse(not_reached=empty_pb2.Empty())
@@ -414,10 +388,13 @@ class SecurityService(SecurityServicer):
pair_task: Optional[asyncio.Future[None]] = None pair_task: Optional[asyncio.Future[None]] = None
async def authenticate() -> None: async def authenticate() -> None:
assert connection
if (encryption := connection.encryption) != 0: if (encryption := connection.encryption) != 0:
self.log.debug('Disable encryption...') self.log.debug('Disable encryption...')
with contextlib.suppress(Exception): try:
await connection.encrypt(enable=False) await connection.encrypt(enable=False)
except:
pass
self.log.debug('Disable encryption: done') self.log.debug('Disable encryption: done')
self.log.debug('Authenticate...') self.log.debug('Authenticate...')
@@ -436,13 +413,15 @@ class SecurityService(SecurityServicer):
return wrapper return wrapper
async def try_set_success(*_: Any) -> None: def try_set_success(*_: Any) -> None:
if await self.reached_security_level(connection, level): assert connection
if self.reached_security_level(connection, level):
self.log.debug('Wait for security: done') self.log.debug('Wait for security: done')
wait_for_security.set_result('success') wait_for_security.set_result('success')
async def on_encryption_change(*_: Any) -> None: def on_encryption_change(*_: Any) -> None:
if await self.reached_security_level(connection, level): assert connection
if self.reached_security_level(connection, level):
self.log.debug('Wait for security: done') self.log.debug('Wait for security: done')
wait_for_security.set_result('success') wait_for_security.set_result('success')
elif ( elif (
@@ -457,7 +436,7 @@ class SecurityService(SecurityServicer):
if self.need_pairing(connection, level): if self.need_pairing(connection, level):
pair_task = asyncio.create_task(connection.pair()) pair_task = asyncio.create_task(connection.pair())
listeners: Dict[str, Callable[..., Union[None, Awaitable[None]]]] = { listeners: Dict[str, Callable[..., None]] = {
'disconnection': set_failure('connection_died'), 'disconnection': set_failure('connection_died'),
'pairing_failure': set_failure('pairing_failure'), 'pairing_failure': set_failure('pairing_failure'),
'connection_authentication_failure': set_failure('authentication_failure'), 'connection_authentication_failure': set_failure('authentication_failure'),
@@ -476,7 +455,7 @@ class SecurityService(SecurityServicer):
watcher.on(connection, event, listener) watcher.on(connection, event, listener)
# security level already reached # security level already reached
if await self.reached_security_level(connection, level): if self.reached_security_level(connection, level):
return WaitSecurityResponse(success=empty_pb2.Empty()) return WaitSecurityResponse(success=empty_pb2.Empty())
self.log.debug('Wait for security...') self.log.debug('Wait for security...')
@@ -486,20 +465,24 @@ class SecurityService(SecurityServicer):
# wait for `authenticate` to finish if any # wait for `authenticate` to finish if any
if authenticate_task is not None: if authenticate_task is not None:
self.log.debug('Wait for authentication...') self.log.debug('Wait for authentication...')
with contextlib.suppress(Exception): try:
await authenticate_task # type: ignore await authenticate_task # type: ignore
except:
pass
self.log.debug('Authenticated') self.log.debug('Authenticated')
# wait for `pair` to finish if any # wait for `pair` to finish if any
if pair_task is not None: if pair_task is not None:
self.log.debug('Wait for authentication...') self.log.debug('Wait for authentication...')
with contextlib.suppress(Exception): try:
await pair_task # type: ignore await pair_task # type: ignore
except:
pass
self.log.debug('paired') self.log.debug('paired')
return WaitSecurityResponse(**kwargs) return WaitSecurityResponse(**kwargs)
async def reached_security_level( def reached_security_level(
self, connection: BumbleConnection, level: Union[SecurityLevel, LESecurityLevel] self, connection: BumbleConnection, level: Union[SecurityLevel, LESecurityLevel]
) -> bool: ) -> bool:
self.log.debug( self.log.debug(
@@ -509,14 +492,15 @@ class SecurityService(SecurityServicer):
'encryption': connection.encryption, 'encryption': connection.encryption,
'authenticated': connection.authenticated, 'authenticated': connection.authenticated,
'sc': connection.sc, 'sc': connection.sc,
'link_key_type': connection.link_key_type,
} }
) )
) )
if isinstance(level, LESecurityLevel): if isinstance(level, LESecurityLevel):
return self._le_level_reached(level, connection) return LE_LEVEL_REACHED[level](connection)
return await self._classic_level_reached(level, connection) return BR_LEVEL_REACHED[level](connection)
def need_pairing(self, connection: BumbleConnection, level: int) -> bool: def need_pairing(self, connection: BumbleConnection, level: int) -> bool:
if connection.transport == PhysicalTransport.LE: if connection.transport == PhysicalTransport.LE:
+1 -3
View File
@@ -1380,10 +1380,8 @@ class Session:
ediv=self.ltk_ediv, ediv=self.ltk_ediv,
rand=self.ltk_rand, rand=self.ltk_rand,
) )
if not self.peer_ltk:
logger.error("peer_ltk is None")
peer_ltk_key = PairingKeys.Key( peer_ltk_key = PairingKeys.Key(
value=self.peer_ltk or b'', value=self.peer_ltk,
authenticated=authenticated, authenticated=authenticated,
ediv=self.peer_ediv, ediv=self.peer_ediv,
rand=self.peer_rand, rand=self.peer_rand,
+53 -153
View File
@@ -16,43 +16,28 @@
# Imports # Imports
# ----------------------------------------------------------------------------- # -----------------------------------------------------------------------------
import asyncio import asyncio
import logging
import os
import sys import sys
from dataclasses import dataclass import os
import logging
import ffmpeg from bumble.colors import color
from bumble.device import Device
from bumble.a2dp import ( from bumble.transport import open_transport_or_link
A2DP_MPEG_2_4_AAC_CODEC_TYPE, from bumble.core import PhysicalTransport
A2DP_SBC_CODEC_TYPE,
AacMediaCodecInformation,
AacPacketSource,
SbcMediaCodecInformation,
SbcPacketSource,
make_audio_source_service_sdp_records,
)
from bumble.avdtp import ( from bumble.avdtp import (
find_avdtp_service_with_connection,
AVDTP_AUDIO_MEDIA_TYPE, AVDTP_AUDIO_MEDIA_TYPE,
Listener,
MediaCodecCapabilities, MediaCodecCapabilities,
MediaPacketPump, MediaPacketPump,
Protocol, Protocol,
find_avdtp_service_with_connection, Listener,
)
from bumble.a2dp import (
make_audio_source_service_sdp_records,
A2DP_SBC_CODEC_TYPE,
SbcMediaCodecInformation,
SbcPacketSource,
) )
from bumble.colors import color
from bumble.core import PhysicalTransport
from bumble.device import Device
from bumble.transport import open_transport_or_link
from typing import Dict, Union
@dataclass
class CodecCapabilities:
name: str
sample_rate: str
number_of_channels: str
# ----------------------------------------------------------------------------- # -----------------------------------------------------------------------------
@@ -66,147 +51,67 @@ def sdp_records():
# ----------------------------------------------------------------------------- # -----------------------------------------------------------------------------
def on_avdtp_connection( def codec_capabilities():
read_function, protocol, codec_capabilities: MediaCodecCapabilities # NOTE: this shouldn't be hardcoded, but should be inferred from the input file
): # instead
packet_source = SbcPacketSource(read_function, protocol.l2cap_channel.peer_mtu) return MediaCodecCapabilities(
packet_pump = MediaPacketPump(packet_source.packets) media_type=AVDTP_AUDIO_MEDIA_TYPE,
protocol.add_source(codec_capabilities, packet_pump) media_codec_type=A2DP_SBC_CODEC_TYPE,
media_codec_information=SbcMediaCodecInformation(
sampling_frequency=SbcMediaCodecInformation.SamplingFrequency.SF_44100,
channel_mode=SbcMediaCodecInformation.ChannelMode.JOINT_STEREO,
block_length=SbcMediaCodecInformation.BlockLength.BL_16,
subbands=SbcMediaCodecInformation.Subbands.S_8,
allocation_method=SbcMediaCodecInformation.AllocationMethod.LOUDNESS,
minimum_bitpool_value=2,
maximum_bitpool_value=53,
),
)
# ----------------------------------------------------------------------------- # -----------------------------------------------------------------------------
async def stream_packets( def on_avdtp_connection(read_function, protocol):
read_function, protocol, codec_capabilities: MediaCodecCapabilities packet_source = SbcPacketSource(read_function, protocol.l2cap_channel.peer_mtu)
): packet_pump = MediaPacketPump(packet_source.packets)
protocol.add_source(codec_capabilities(), packet_pump)
# -----------------------------------------------------------------------------
async def stream_packets(read_function, protocol):
# Discover all endpoints on the remote device # Discover all endpoints on the remote device
endpoints = await protocol.discover_remote_endpoints() endpoints = await protocol.discover_remote_endpoints()
for endpoint in endpoints: for endpoint in endpoints:
print('@@@', endpoint) print('@@@', endpoint)
# Select a sink # Select a sink
assert codec_capabilities.media_codec_type in [
A2DP_SBC_CODEC_TYPE,
A2DP_MPEG_2_4_AAC_CODEC_TYPE,
]
sink = protocol.find_remote_sink_by_codec( sink = protocol.find_remote_sink_by_codec(
AVDTP_AUDIO_MEDIA_TYPE, codec_capabilities.media_codec_type AVDTP_AUDIO_MEDIA_TYPE, A2DP_SBC_CODEC_TYPE
) )
if sink is None: if sink is None:
print(color('!!! no Sink found', 'red')) print(color('!!! no SBC sink found', 'red'))
return return
print(f'### Selected sink: {sink.seid}') print(f'### Selected sink: {sink.seid}')
# Stream the packets # Stream the packets
packet_sources = { packet_source = SbcPacketSource(read_function, protocol.l2cap_channel.peer_mtu)
A2DP_SBC_CODEC_TYPE: SbcPacketSource( packet_pump = MediaPacketPump(packet_source.packets)
read_function, protocol.l2cap_channel.peer_mtu source = protocol.add_source(codec_capabilities(), packet_pump)
),
A2DP_MPEG_2_4_AAC_CODEC_TYPE: AacPacketSource(
read_function, protocol.l2cap_channel.peer_mtu
),
}
packet_source = packet_sources[codec_capabilities.media_codec_type]
packet_pump = MediaPacketPump(packet_source.packets) # type: ignore
source = protocol.add_source(codec_capabilities, packet_pump)
stream = await protocol.create_stream(source, sink) stream = await protocol.create_stream(source, sink)
await stream.start() await stream.start()
await asyncio.sleep(60) await asyncio.sleep(5)
await stream.stop()
await asyncio.sleep(5)
await stream.start()
await asyncio.sleep(5)
await stream.stop() await stream.stop()
await stream.close() await stream.close()
# -----------------------------------------------------------------------------
def fetch_codec_informations(filepath) -> MediaCodecCapabilities:
probe = ffmpeg.probe(filepath)
assert 'streams' in probe
streams = probe['streams']
if not streams or len(streams) > 1:
print(streams)
print(color('!!! file not supported', 'red'))
exit()
audio_stream = streams[0]
media_codec_type = None
media_codec_information: Union[
SbcMediaCodecInformation, AacMediaCodecInformation, None
] = None
assert 'codec_name' in audio_stream
codec_name: str = audio_stream['codec_name']
if codec_name == "sbc":
media_codec_type = A2DP_SBC_CODEC_TYPE
sbc_sampling_frequency: Dict[
str, SbcMediaCodecInformation.SamplingFrequency
] = {
'16000': SbcMediaCodecInformation.SamplingFrequency.SF_16000,
'32000': SbcMediaCodecInformation.SamplingFrequency.SF_32000,
'44100': SbcMediaCodecInformation.SamplingFrequency.SF_44100,
'48000': SbcMediaCodecInformation.SamplingFrequency.SF_48000,
}
sbc_channel_mode: Dict[int, SbcMediaCodecInformation.ChannelMode] = {
1: SbcMediaCodecInformation.ChannelMode.MONO,
2: SbcMediaCodecInformation.ChannelMode.JOINT_STEREO,
}
assert 'sample_rate' in audio_stream
assert 'channels' in audio_stream
media_codec_information = SbcMediaCodecInformation(
sampling_frequency=sbc_sampling_frequency[audio_stream['sample_rate']],
channel_mode=sbc_channel_mode[audio_stream['channels']],
block_length=SbcMediaCodecInformation.BlockLength.BL_16,
subbands=SbcMediaCodecInformation.Subbands.S_8,
allocation_method=SbcMediaCodecInformation.AllocationMethod.LOUDNESS,
minimum_bitpool_value=2,
maximum_bitpool_value=53,
)
elif codec_name == "aac":
media_codec_type = A2DP_MPEG_2_4_AAC_CODEC_TYPE
object_type: Dict[str, AacMediaCodecInformation.ObjectType] = {
'LC': AacMediaCodecInformation.ObjectType.MPEG_2_AAC_LC,
'LTP': AacMediaCodecInformation.ObjectType.MPEG_4_AAC_LTP,
'SSR': AacMediaCodecInformation.ObjectType.MPEG_4_AAC_SCALABLE,
}
aac_sampling_frequency: Dict[
str, AacMediaCodecInformation.SamplingFrequency
] = {
'44100': AacMediaCodecInformation.SamplingFrequency.SF_44100,
'48000': AacMediaCodecInformation.SamplingFrequency.SF_48000,
}
aac_channel_mode: Dict[int, AacMediaCodecInformation.Channels] = {
1: AacMediaCodecInformation.Channels.MONO,
2: AacMediaCodecInformation.Channels.STEREO,
}
assert 'profile' in audio_stream
assert 'sample_rate' in audio_stream
assert 'channels' in audio_stream
media_codec_information = AacMediaCodecInformation(
object_type=object_type[audio_stream['profile']],
sampling_frequency=aac_sampling_frequency[audio_stream['sample_rate']],
channels=aac_channel_mode[audio_stream['channels']],
vbr=1,
bitrate=128000,
)
else:
print(color('!!! codec not supported, only aac & sbc are supported', 'red'))
exit()
assert media_codec_type is not None
assert media_codec_information is not None
return MediaCodecCapabilities(
media_type=AVDTP_AUDIO_MEDIA_TYPE,
media_codec_type=media_codec_type,
media_codec_information=media_codec_information,
)
# ----------------------------------------------------------------------------- # -----------------------------------------------------------------------------
async def main() -> None: async def main() -> None:
if len(sys.argv) < 4: if len(sys.argv) < 4:
print( print(
'Usage: run_a2dp_source.py <device-config> <transport-spec> <audio-file> ' 'Usage: run_a2dp_source.py <device-config> <transport-spec> <sbc-file> '
'[<bluetooth-address>]' '[<bluetooth-address>]'
) )
print( print(
@@ -230,13 +135,11 @@ async def main() -> None:
# Start # Start
await device.power_on() await device.power_on()
with open(sys.argv[3], 'rb') as audio_file: with open(sys.argv[3], 'rb') as sbc_file:
# NOTE: this should be using asyncio file reading, but blocking reads are # NOTE: this should be using asyncio file reading, but blocking reads are
# good enough for testing # good enough for testing
async def read(byte_count): async def read(byte_count):
return audio_file.read(byte_count) return sbc_file.read(byte_count)
codec_capabilities = fetch_codec_informations(sys.argv[3])
if len(sys.argv) > 4: if len(sys.argv) > 4:
# Connect to a peer # Connect to a peer
@@ -267,15 +170,12 @@ async def main() -> None:
protocol = await Protocol.connect(connection, avdtp_version) protocol = await Protocol.connect(connection, avdtp_version)
# Start streaming # Start streaming
await stream_packets(read, protocol, codec_capabilities) await stream_packets(read, protocol)
else: else:
# Create a listener to wait for AVDTP connections # Create a listener to wait for AVDTP connections
listener = Listener.for_device(device=device, version=(1, 2)) listener = Listener.for_device(device=device, version=(1, 2))
listener.on( listener.on(
'connection', 'connection', lambda protocol: on_avdtp_connection(read, protocol)
lambda protocol: on_avdtp_connection(
read, protocol, codec_capabilities
),
) )
# Become connectable and wait for a connection # Become connectable and wait for a connection
-7
View File
@@ -55,9 +55,6 @@ development = [
"types-invoke >= 1.7.3", "types-invoke >= 1.7.3",
"types-protobuf >= 4.21.0", "types-protobuf >= 4.21.0",
] ]
examples = [
"ffmpeg-python == 0.2.0",
]
avatar = [ avatar = [
"pandora-avatar == 0.0.10", "pandora-avatar == 0.0.10",
"rootcanal == 1.11.1 ; python_version>='3.10'", "rootcanal == 1.11.1 ; python_version>='3.10'",
@@ -187,10 +184,6 @@ ignore_missing_imports = true
module = "construct.*" module = "construct.*"
ignore_missing_imports = true ignore_missing_imports = true
[[tool.mypy.overrides]]
module = "ffmpeg.*"
ignore_missing_imports = true
[[tool.mypy.overrides]] [[tool.mypy.overrides]]
module = "grpc.*" module = "grpc.*"
ignore_missing_imports = true ignore_missing_imports = true