Compare commits

...

4 Commits

Author SHA1 Message Date
Charlie Boutier
7237619d3b A2DP example: Codec selection based on file type
Currently support SBC and AAC
2025-05-08 14:24:42 -07:00
Slvr
a88a034ce2 cryptography: bump version to 44.0.3 to fix python parsing (#684)
Bug: 404336381
2025-05-08 08:28:33 -07:00
zxzxwu
6b2cd1147d Merge pull request #682 from zxzxwu/linkkey
Move connection.link_key_type to keystore
2025-05-08 11:23:28 +08:00
Josh Wu
bb8dcaf63e Move connection.link_key_type to keystore 2025-05-06 02:11:25 +08:00
7 changed files with 298 additions and 160 deletions

View File

@@ -33,7 +33,7 @@ jobs:
- name: Install dependencies - name: Install dependencies
run: | run: |
python -m pip install --upgrade pip python -m pip install --upgrade pip
python -m pip install ".[build,test,development]" python -m pip install ".[build,examples,test,development]"
- name: Check - name: Check
run: | run: |
invoke project.pre-commit invoke project.pre-commit

View File

@@ -1589,7 +1589,6 @@ class Connection(utils.CompositeEventEmitter):
encryption_key_size: int encryption_key_size: int
authenticated: bool authenticated: bool
sc: bool sc: bool
link_key_type: Optional[int] # [Classic only]
gatt_client: gatt_client.Client gatt_client: gatt_client.Client
pairing_peer_io_capability: Optional[int] pairing_peer_io_capability: Optional[int]
pairing_peer_authentication_requirements: Optional[int] pairing_peer_authentication_requirements: Optional[int]
@@ -1692,7 +1691,6 @@ class Connection(utils.CompositeEventEmitter):
self.encryption_key_size = 0 self.encryption_key_size = 0
self.authenticated = False self.authenticated = False
self.sc = False self.sc = False
self.link_key_type = None
self.att_mtu = ATT_DEFAULT_MTU self.att_mtu = ATT_DEFAULT_MTU
self.data_length = DEVICE_DEFAULT_DATA_LENGTH self.data_length = DEVICE_DEFAULT_DATA_LENGTH
self.gatt_client = gatt_client.Client(self) # Per-connection client self.gatt_client = gatt_client.Client(self) # Per-connection client
@@ -5075,9 +5073,9 @@ class Device(utils.CompositeEventEmitter):
hci.HCI_AUTHENTICATED_COMBINATION_KEY_GENERATED_FROM_P_192_TYPE, hci.HCI_AUTHENTICATED_COMBINATION_KEY_GENERATED_FROM_P_192_TYPE,
hci.HCI_AUTHENTICATED_COMBINATION_KEY_GENERATED_FROM_P_256_TYPE, hci.HCI_AUTHENTICATED_COMBINATION_KEY_GENERATED_FROM_P_256_TYPE,
) )
pairing_keys = PairingKeys() pairing_keys = PairingKeys(
pairing_keys.link_key = PairingKeys.Key( link_key=PairingKeys.Key(value=link_key, authenticated=authenticated),
value=link_key, authenticated=authenticated link_key_type=key_type,
) )
utils.cancel_on_event( utils.cancel_on_event(
@@ -5087,7 +5085,6 @@ class Device(utils.CompositeEventEmitter):
if connection := self.find_connection_by_bd_addr( if connection := self.find_connection_by_bd_addr(
bd_addr, transport=PhysicalTransport.BR_EDR bd_addr, transport=PhysicalTransport.BR_EDR
): ):
connection.link_key_type = key_type
connection.emit(connection.EVENT_LINK_KEY) connection.emit(connection.EVENT_LINK_KEY)
def add_service(self, service): def add_service(self, service):

View File

@@ -22,14 +22,15 @@
# ----------------------------------------------------------------------------- # -----------------------------------------------------------------------------
from __future__ import annotations from __future__ import annotations
import asyncio import asyncio
import dataclasses
import logging import logging
import os import os
import json import json
from typing import TYPE_CHECKING, Dict, List, Optional, Tuple, Type from typing import TYPE_CHECKING, Dict, List, Optional, Tuple, Type, Any
from typing_extensions import Self from typing_extensions import Self
from bumble.colors import color from bumble.colors import color
from bumble.hci import Address from bumble import hci
if TYPE_CHECKING: if TYPE_CHECKING:
from bumble.device import Device from bumble.device import Device
@@ -42,16 +43,17 @@ logger = logging.getLogger(__name__)
# ----------------------------------------------------------------------------- # -----------------------------------------------------------------------------
@dataclasses.dataclass
class PairingKeys: class PairingKeys:
@dataclasses.dataclass
class Key: class Key:
def __init__(self, value, authenticated=False, ediv=None, rand=None): value: bytes
self.value = value authenticated: bool = False
self.authenticated = authenticated ediv: Optional[int] = None
self.ediv = ediv rand: Optional[bytes] = None
self.rand = rand
@classmethod @classmethod
def from_dict(cls, key_dict): def from_dict(cls, key_dict: dict[str, Any]) -> PairingKeys.Key:
value = bytes.fromhex(key_dict['value']) value = bytes.fromhex(key_dict['value'])
authenticated = key_dict.get('authenticated', False) authenticated = key_dict.get('authenticated', False)
ediv = key_dict.get('ediv') ediv = key_dict.get('ediv')
@@ -61,7 +63,7 @@ class PairingKeys:
return cls(value, authenticated, ediv, rand) return cls(value, authenticated, ediv, rand)
def to_dict(self): def to_dict(self) -> dict[str, Any]:
key_dict = {'value': self.value.hex(), 'authenticated': self.authenticated} key_dict = {'value': self.value.hex(), 'authenticated': self.authenticated}
if self.ediv is not None: if self.ediv is not None:
key_dict['ediv'] = self.ediv key_dict['ediv'] = self.ediv
@@ -70,39 +72,42 @@ class PairingKeys:
return key_dict return key_dict
def __init__(self): address_type: Optional[hci.AddressType] = None
self.address_type = None ltk: Optional[Key] = None
self.ltk = None ltk_central: Optional[Key] = None
self.ltk_central = None ltk_peripheral: Optional[Key] = None
self.ltk_peripheral = None irk: Optional[Key] = None
self.irk = None csrk: Optional[Key] = None
self.csrk = None link_key: Optional[Key] = None # Classic
self.link_key = None # Classic link_key_type: Optional[int] = None # Classic
@staticmethod @classmethod
def key_from_dict(keys_dict, key_name): def key_from_dict(cls, keys_dict: dict[str, Any], key_name: str) -> Optional[Key]:
key_dict = keys_dict.get(key_name) key_dict = keys_dict.get(key_name)
if key_dict is None: if key_dict is None:
return None return None
return PairingKeys.Key.from_dict(key_dict) return PairingKeys.Key.from_dict(key_dict)
@staticmethod @classmethod
def from_dict(keys_dict): def from_dict(cls, keys_dict: dict[str, Any]) -> PairingKeys:
keys = PairingKeys() return PairingKeys(
address_type=(
hci.AddressType(t)
if (t := keys_dict.get('address_type')) is not None
else None
),
ltk=PairingKeys.key_from_dict(keys_dict, 'ltk'),
ltk_central=PairingKeys.key_from_dict(keys_dict, 'ltk_central'),
ltk_peripheral=PairingKeys.key_from_dict(keys_dict, 'ltk_peripheral'),
irk=PairingKeys.key_from_dict(keys_dict, 'irk'),
csrk=PairingKeys.key_from_dict(keys_dict, 'csrk'),
link_key=PairingKeys.key_from_dict(keys_dict, 'link_key'),
link_key_type=keys_dict.get('link_key_type'),
)
keys.address_type = keys_dict.get('address_type') def to_dict(self) -> dict[str, Any]:
keys.ltk = PairingKeys.key_from_dict(keys_dict, 'ltk') keys: dict[str, Any] = {}
keys.ltk_central = PairingKeys.key_from_dict(keys_dict, 'ltk_central')
keys.ltk_peripheral = PairingKeys.key_from_dict(keys_dict, 'ltk_peripheral')
keys.irk = PairingKeys.key_from_dict(keys_dict, 'irk')
keys.csrk = PairingKeys.key_from_dict(keys_dict, 'csrk')
keys.link_key = PairingKeys.key_from_dict(keys_dict, 'link_key')
return keys
def to_dict(self):
keys = {}
if self.address_type is not None: if self.address_type is not None:
keys['address_type'] = self.address_type keys['address_type'] = self.address_type
@@ -125,9 +130,12 @@ class PairingKeys:
if self.link_key is not None: if self.link_key is not None:
keys['link_key'] = self.link_key.to_dict() keys['link_key'] = self.link_key.to_dict()
if self.link_key_type is not None:
keys['link_key_type'] = self.link_key_type
return keys return keys
def print(self, prefix=''): def print(self, prefix: str = '') -> None:
keys_dict = self.to_dict() keys_dict = self.to_dict()
for container_property, value in keys_dict.items(): for container_property, value in keys_dict.items():
if isinstance(value, dict): if isinstance(value, dict):
@@ -156,20 +164,28 @@ class KeyStore:
all_keys = await self.get_all() all_keys = await self.get_all()
await asyncio.gather(*(self.delete(name) for (name, _) in all_keys)) await asyncio.gather(*(self.delete(name) for (name, _) in all_keys))
async def get_resolving_keys(self): async def get_resolving_keys(self) -> list[tuple[bytes, hci.Address]]:
all_keys = await self.get_all() all_keys = await self.get_all()
resolving_keys = [] resolving_keys = []
for name, keys in all_keys: for name, keys in all_keys:
if keys.irk is not None: if keys.irk is not None:
if keys.address_type is None: resolving_keys.append(
address_type = Address.RANDOM_DEVICE_ADDRESS (
else: keys.irk.value,
address_type = keys.address_type hci.Address(
resolving_keys.append((keys.irk.value, Address(name, address_type))) name,
(
keys.address_type
if keys.address_type is not None
else hci.Address.RANDOM_DEVICE_ADDRESS
),
),
)
)
return resolving_keys return resolving_keys
async def print(self, prefix=''): async def print(self, prefix: str = '') -> None:
entries = await self.get_all() entries = await self.get_all()
separator = '' separator = ''
for name, keys in entries: for name, keys in entries:
@@ -177,8 +193,8 @@ class KeyStore:
keys.print(prefix=prefix + ' ') keys.print(prefix=prefix + ' ')
separator = '\n' separator = '\n'
@staticmethod @classmethod
def create_for_device(device: Device) -> KeyStore: def create_for_device(cls, device: Device) -> KeyStore:
if device.config.keystore is None: if device.config.keystore is None:
return MemoryKeyStore() return MemoryKeyStore()
@@ -266,9 +282,9 @@ class JsonKeyStore(KeyStore):
filename = params[0] filename = params[0]
# Use a namespace based on the device address # Use a namespace based on the device address
if device.public_address not in (Address.ANY, Address.ANY_RANDOM): if device.public_address not in (hci.Address.ANY, hci.Address.ANY_RANDOM):
namespace = str(device.public_address) namespace = str(device.public_address)
elif device.random_address != Address.ANY_RANDOM: elif device.random_address != hci.Address.ANY_RANDOM:
namespace = str(device.random_address) namespace = str(device.random_address)
else: else:
namespace = JsonKeyStore.DEFAULT_NAMESPACE namespace = JsonKeyStore.DEFAULT_NAMESPACE

View File

@@ -15,6 +15,7 @@
from __future__ import annotations from __future__ import annotations
import asyncio import asyncio
import contextlib import contextlib
from collections.abc import Awaitable
import grpc import grpc
import logging import logging
@@ -24,6 +25,7 @@ from bumble import hci
from bumble.core import ( from bumble.core import (
PhysicalTransport, PhysicalTransport,
ProtocolError, ProtocolError,
InvalidArgumentError,
) )
import bumble.utils import bumble.utils
from bumble.device import Connection as BumbleConnection, Device from bumble.device import Connection as BumbleConnection, Device
@@ -188,35 +190,6 @@ class PairingDelegate(BasePairingDelegate):
self.service.event_queue.put_nowait(event) self.service.event_queue.put_nowait(event)
BR_LEVEL_REACHED: Dict[SecurityLevel, Callable[[BumbleConnection], bool]] = {
LEVEL0: lambda connection: True,
LEVEL1: lambda connection: connection.encryption == 0 or connection.authenticated,
LEVEL2: lambda connection: connection.encryption != 0 and connection.authenticated,
LEVEL3: lambda connection: connection.encryption != 0
and connection.authenticated
and connection.link_key_type
in (
hci.HCI_AUTHENTICATED_COMBINATION_KEY_GENERATED_FROM_P_192_TYPE,
hci.HCI_AUTHENTICATED_COMBINATION_KEY_GENERATED_FROM_P_256_TYPE,
),
LEVEL4: lambda connection: connection.encryption
== hci.HCI_Encryption_Change_Event.AES_CCM
and connection.authenticated
and connection.link_key_type
== hci.HCI_AUTHENTICATED_COMBINATION_KEY_GENERATED_FROM_P_256_TYPE,
}
LE_LEVEL_REACHED: Dict[LESecurityLevel, Callable[[BumbleConnection], bool]] = {
LE_LEVEL1: lambda connection: True,
LE_LEVEL2: lambda connection: connection.encryption != 0,
LE_LEVEL3: lambda connection: connection.encryption != 0
and connection.authenticated,
LE_LEVEL4: lambda connection: connection.encryption != 0
and connection.authenticated
and connection.sc,
}
class SecurityService(SecurityServicer): class SecurityService(SecurityServicer):
def __init__(self, device: Device, config: Config) -> None: def __init__(self, device: Device, config: Config) -> None:
self.log = utils.BumbleServerLoggerAdapter( self.log = utils.BumbleServerLoggerAdapter(
@@ -248,6 +221,59 @@ class SecurityService(SecurityServicer):
self.device.pairing_config_factory = pairing_config_factory self.device.pairing_config_factory = pairing_config_factory
async def _classic_level_reached(
self, level: SecurityLevel, connection: BumbleConnection
) -> bool:
if level == LEVEL0:
return True
if level == LEVEL1:
return connection.encryption == 0 or connection.authenticated
if level == LEVEL2:
return connection.encryption != 0 and connection.authenticated
link_key_type: Optional[int] = None
if (keystore := connection.device.keystore) and (
keys := await keystore.get(str(connection.peer_address))
):
link_key_type = keys.link_key_type
self.log.debug("link_key_type: %d", link_key_type)
if level == LEVEL3:
return (
connection.encryption != 0
and connection.authenticated
and link_key_type
in (
hci.HCI_AUTHENTICATED_COMBINATION_KEY_GENERATED_FROM_P_192_TYPE,
hci.HCI_AUTHENTICATED_COMBINATION_KEY_GENERATED_FROM_P_256_TYPE,
)
)
if level == LEVEL4:
return (
connection.encryption == hci.HCI_Encryption_Change_Event.AES_CCM
and connection.authenticated
and link_key_type
== hci.HCI_AUTHENTICATED_COMBINATION_KEY_GENERATED_FROM_P_256_TYPE
)
raise InvalidArgumentError(f"Unexpected level {level}")
def _le_level_reached(
self, level: LESecurityLevel, connection: BumbleConnection
) -> bool:
if level == LE_LEVEL1:
return True
if level == LE_LEVEL2:
return connection.encryption != 0
if level == LE_LEVEL3:
return connection.encryption != 0 and connection.authenticated
if level == LE_LEVEL4:
return (
connection.encryption != 0
and connection.authenticated
and connection.sc
)
raise InvalidArgumentError(f"Unexpected level {level}")
@utils.rpc @utils.rpc
async def OnPairing( async def OnPairing(
self, request: AsyncIterator[PairingEventAnswer], context: grpc.ServicerContext self, request: AsyncIterator[PairingEventAnswer], context: grpc.ServicerContext
@@ -290,7 +316,7 @@ class SecurityService(SecurityServicer):
] == oneof ] == oneof
# security level already reached # security level already reached
if self.reached_security_level(connection, level): if await self.reached_security_level(connection, level):
return SecureResponse(success=empty_pb2.Empty()) return SecureResponse(success=empty_pb2.Empty())
# trigger pairing if needed # trigger pairing if needed
@@ -361,7 +387,7 @@ class SecurityService(SecurityServicer):
return SecureResponse(encryption_failure=empty_pb2.Empty()) return SecureResponse(encryption_failure=empty_pb2.Empty())
# security level has been reached ? # security level has been reached ?
if self.reached_security_level(connection, level): if await self.reached_security_level(connection, level):
return SecureResponse(success=empty_pb2.Empty()) return SecureResponse(success=empty_pb2.Empty())
return SecureResponse(not_reached=empty_pb2.Empty()) return SecureResponse(not_reached=empty_pb2.Empty())
@@ -388,13 +414,10 @@ class SecurityService(SecurityServicer):
pair_task: Optional[asyncio.Future[None]] = None pair_task: Optional[asyncio.Future[None]] = None
async def authenticate() -> None: async def authenticate() -> None:
assert connection
if (encryption := connection.encryption) != 0: if (encryption := connection.encryption) != 0:
self.log.debug('Disable encryption...') self.log.debug('Disable encryption...')
try: with contextlib.suppress(Exception):
await connection.encrypt(enable=False) await connection.encrypt(enable=False)
except:
pass
self.log.debug('Disable encryption: done') self.log.debug('Disable encryption: done')
self.log.debug('Authenticate...') self.log.debug('Authenticate...')
@@ -413,15 +436,13 @@ class SecurityService(SecurityServicer):
return wrapper return wrapper
def try_set_success(*_: Any) -> None: async def try_set_success(*_: Any) -> None:
assert connection if await self.reached_security_level(connection, level):
if self.reached_security_level(connection, level):
self.log.debug('Wait for security: done') self.log.debug('Wait for security: done')
wait_for_security.set_result('success') wait_for_security.set_result('success')
def on_encryption_change(*_: Any) -> None: async def on_encryption_change(*_: Any) -> None:
assert connection if await self.reached_security_level(connection, level):
if self.reached_security_level(connection, level):
self.log.debug('Wait for security: done') self.log.debug('Wait for security: done')
wait_for_security.set_result('success') wait_for_security.set_result('success')
elif ( elif (
@@ -436,7 +457,7 @@ class SecurityService(SecurityServicer):
if self.need_pairing(connection, level): if self.need_pairing(connection, level):
pair_task = asyncio.create_task(connection.pair()) pair_task = asyncio.create_task(connection.pair())
listeners: Dict[str, Callable[..., None]] = { listeners: Dict[str, Callable[..., Union[None, Awaitable[None]]]] = {
'disconnection': set_failure('connection_died'), 'disconnection': set_failure('connection_died'),
'pairing_failure': set_failure('pairing_failure'), 'pairing_failure': set_failure('pairing_failure'),
'connection_authentication_failure': set_failure('authentication_failure'), 'connection_authentication_failure': set_failure('authentication_failure'),
@@ -455,7 +476,7 @@ class SecurityService(SecurityServicer):
watcher.on(connection, event, listener) watcher.on(connection, event, listener)
# security level already reached # security level already reached
if self.reached_security_level(connection, level): if await self.reached_security_level(connection, level):
return WaitSecurityResponse(success=empty_pb2.Empty()) return WaitSecurityResponse(success=empty_pb2.Empty())
self.log.debug('Wait for security...') self.log.debug('Wait for security...')
@@ -465,24 +486,20 @@ class SecurityService(SecurityServicer):
# wait for `authenticate` to finish if any # wait for `authenticate` to finish if any
if authenticate_task is not None: if authenticate_task is not None:
self.log.debug('Wait for authentication...') self.log.debug('Wait for authentication...')
try: with contextlib.suppress(Exception):
await authenticate_task # type: ignore await authenticate_task # type: ignore
except:
pass
self.log.debug('Authenticated') self.log.debug('Authenticated')
# wait for `pair` to finish if any # wait for `pair` to finish if any
if pair_task is not None: if pair_task is not None:
self.log.debug('Wait for authentication...') self.log.debug('Wait for authentication...')
try: with contextlib.suppress(Exception):
await pair_task # type: ignore await pair_task # type: ignore
except:
pass
self.log.debug('paired') self.log.debug('paired')
return WaitSecurityResponse(**kwargs) return WaitSecurityResponse(**kwargs)
def reached_security_level( async def reached_security_level(
self, connection: BumbleConnection, level: Union[SecurityLevel, LESecurityLevel] self, connection: BumbleConnection, level: Union[SecurityLevel, LESecurityLevel]
) -> bool: ) -> bool:
self.log.debug( self.log.debug(
@@ -492,15 +509,14 @@ class SecurityService(SecurityServicer):
'encryption': connection.encryption, 'encryption': connection.encryption,
'authenticated': connection.authenticated, 'authenticated': connection.authenticated,
'sc': connection.sc, 'sc': connection.sc,
'link_key_type': connection.link_key_type,
} }
) )
) )
if isinstance(level, LESecurityLevel): if isinstance(level, LESecurityLevel):
return LE_LEVEL_REACHED[level](connection) return self._le_level_reached(level, connection)
return BR_LEVEL_REACHED[level](connection) return await self._classic_level_reached(level, connection)
def need_pairing(self, connection: BumbleConnection, level: int) -> bool: def need_pairing(self, connection: BumbleConnection, level: int) -> bool:
if connection.transport == PhysicalTransport.LE: if connection.transport == PhysicalTransport.LE:

View File

@@ -1380,8 +1380,10 @@ class Session:
ediv=self.ltk_ediv, ediv=self.ltk_ediv,
rand=self.ltk_rand, rand=self.ltk_rand,
) )
if not self.peer_ltk:
logger.error("peer_ltk is None")
peer_ltk_key = PairingKeys.Key( peer_ltk_key = PairingKeys.Key(
value=self.peer_ltk, value=self.peer_ltk or b'',
authenticated=authenticated, authenticated=authenticated,
ediv=self.peer_ediv, ediv=self.peer_ediv,
rand=self.peer_rand, rand=self.peer_rand,

View File

@@ -16,28 +16,43 @@
# Imports # Imports
# ----------------------------------------------------------------------------- # -----------------------------------------------------------------------------
import asyncio import asyncio
import sys
import os
import logging import logging
import os
import sys
from dataclasses import dataclass
from bumble.colors import color import ffmpeg
from bumble.device import Device
from bumble.transport import open_transport_or_link from bumble.a2dp import (
from bumble.core import PhysicalTransport A2DP_MPEG_2_4_AAC_CODEC_TYPE,
A2DP_SBC_CODEC_TYPE,
AacMediaCodecInformation,
AacPacketSource,
SbcMediaCodecInformation,
SbcPacketSource,
make_audio_source_service_sdp_records,
)
from bumble.avdtp import ( from bumble.avdtp import (
find_avdtp_service_with_connection,
AVDTP_AUDIO_MEDIA_TYPE, AVDTP_AUDIO_MEDIA_TYPE,
Listener,
MediaCodecCapabilities, MediaCodecCapabilities,
MediaPacketPump, MediaPacketPump,
Protocol, Protocol,
Listener, find_avdtp_service_with_connection,
)
from bumble.a2dp import (
make_audio_source_service_sdp_records,
A2DP_SBC_CODEC_TYPE,
SbcMediaCodecInformation,
SbcPacketSource,
) )
from bumble.colors import color
from bumble.core import PhysicalTransport
from bumble.device import Device
from bumble.transport import open_transport_or_link
from typing import Dict, Union
@dataclass
class CodecCapabilities:
name: str
sample_rate: str
number_of_channels: str
# ----------------------------------------------------------------------------- # -----------------------------------------------------------------------------
@@ -51,67 +66,147 @@ def sdp_records():
# ----------------------------------------------------------------------------- # -----------------------------------------------------------------------------
def codec_capabilities(): def on_avdtp_connection(
# NOTE: this shouldn't be hardcoded, but should be inferred from the input file read_function, protocol, codec_capabilities: MediaCodecCapabilities
# instead ):
return MediaCodecCapabilities(
media_type=AVDTP_AUDIO_MEDIA_TYPE,
media_codec_type=A2DP_SBC_CODEC_TYPE,
media_codec_information=SbcMediaCodecInformation(
sampling_frequency=SbcMediaCodecInformation.SamplingFrequency.SF_44100,
channel_mode=SbcMediaCodecInformation.ChannelMode.JOINT_STEREO,
block_length=SbcMediaCodecInformation.BlockLength.BL_16,
subbands=SbcMediaCodecInformation.Subbands.S_8,
allocation_method=SbcMediaCodecInformation.AllocationMethod.LOUDNESS,
minimum_bitpool_value=2,
maximum_bitpool_value=53,
),
)
# -----------------------------------------------------------------------------
def on_avdtp_connection(read_function, protocol):
packet_source = SbcPacketSource(read_function, protocol.l2cap_channel.peer_mtu) packet_source = SbcPacketSource(read_function, protocol.l2cap_channel.peer_mtu)
packet_pump = MediaPacketPump(packet_source.packets) packet_pump = MediaPacketPump(packet_source.packets)
protocol.add_source(codec_capabilities(), packet_pump) protocol.add_source(codec_capabilities, packet_pump)
# ----------------------------------------------------------------------------- # -----------------------------------------------------------------------------
async def stream_packets(read_function, protocol): async def stream_packets(
read_function, protocol, codec_capabilities: MediaCodecCapabilities
):
# Discover all endpoints on the remote device # Discover all endpoints on the remote device
endpoints = await protocol.discover_remote_endpoints() endpoints = await protocol.discover_remote_endpoints()
for endpoint in endpoints: for endpoint in endpoints:
print('@@@', endpoint) print('@@@', endpoint)
# Select a sink # Select a sink
assert codec_capabilities.media_codec_type in [
A2DP_SBC_CODEC_TYPE,
A2DP_MPEG_2_4_AAC_CODEC_TYPE,
]
sink = protocol.find_remote_sink_by_codec( sink = protocol.find_remote_sink_by_codec(
AVDTP_AUDIO_MEDIA_TYPE, A2DP_SBC_CODEC_TYPE AVDTP_AUDIO_MEDIA_TYPE, codec_capabilities.media_codec_type
) )
if sink is None: if sink is None:
print(color('!!! no SBC sink found', 'red')) print(color('!!! no Sink found', 'red'))
return return
print(f'### Selected sink: {sink.seid}') print(f'### Selected sink: {sink.seid}')
# Stream the packets # Stream the packets
packet_source = SbcPacketSource(read_function, protocol.l2cap_channel.peer_mtu) packet_sources = {
packet_pump = MediaPacketPump(packet_source.packets) A2DP_SBC_CODEC_TYPE: SbcPacketSource(
source = protocol.add_source(codec_capabilities(), packet_pump) read_function, protocol.l2cap_channel.peer_mtu
),
A2DP_MPEG_2_4_AAC_CODEC_TYPE: AacPacketSource(
read_function, protocol.l2cap_channel.peer_mtu
),
}
packet_source = packet_sources[codec_capabilities.media_codec_type]
packet_pump = MediaPacketPump(packet_source.packets) # type: ignore
source = protocol.add_source(codec_capabilities, packet_pump)
stream = await protocol.create_stream(source, sink) stream = await protocol.create_stream(source, sink)
await stream.start() await stream.start()
await asyncio.sleep(5) await asyncio.sleep(60)
await stream.stop()
await asyncio.sleep(5)
await stream.start()
await asyncio.sleep(5)
await stream.stop() await stream.stop()
await stream.close() await stream.close()
# -----------------------------------------------------------------------------
def fetch_codec_informations(filepath) -> MediaCodecCapabilities:
probe = ffmpeg.probe(filepath)
assert 'streams' in probe
streams = probe['streams']
if not streams or len(streams) > 1:
print(streams)
print(color('!!! file not supported', 'red'))
exit()
audio_stream = streams[0]
media_codec_type = None
media_codec_information: Union[
SbcMediaCodecInformation, AacMediaCodecInformation, None
] = None
assert 'codec_name' in audio_stream
codec_name: str = audio_stream['codec_name']
if codec_name == "sbc":
media_codec_type = A2DP_SBC_CODEC_TYPE
sbc_sampling_frequency: Dict[
str, SbcMediaCodecInformation.SamplingFrequency
] = {
'16000': SbcMediaCodecInformation.SamplingFrequency.SF_16000,
'32000': SbcMediaCodecInformation.SamplingFrequency.SF_32000,
'44100': SbcMediaCodecInformation.SamplingFrequency.SF_44100,
'48000': SbcMediaCodecInformation.SamplingFrequency.SF_48000,
}
sbc_channel_mode: Dict[int, SbcMediaCodecInformation.ChannelMode] = {
1: SbcMediaCodecInformation.ChannelMode.MONO,
2: SbcMediaCodecInformation.ChannelMode.JOINT_STEREO,
}
assert 'sample_rate' in audio_stream
assert 'channels' in audio_stream
media_codec_information = SbcMediaCodecInformation(
sampling_frequency=sbc_sampling_frequency[audio_stream['sample_rate']],
channel_mode=sbc_channel_mode[audio_stream['channels']],
block_length=SbcMediaCodecInformation.BlockLength.BL_16,
subbands=SbcMediaCodecInformation.Subbands.S_8,
allocation_method=SbcMediaCodecInformation.AllocationMethod.LOUDNESS,
minimum_bitpool_value=2,
maximum_bitpool_value=53,
)
elif codec_name == "aac":
media_codec_type = A2DP_MPEG_2_4_AAC_CODEC_TYPE
object_type: Dict[str, AacMediaCodecInformation.ObjectType] = {
'LC': AacMediaCodecInformation.ObjectType.MPEG_2_AAC_LC,
'LTP': AacMediaCodecInformation.ObjectType.MPEG_4_AAC_LTP,
'SSR': AacMediaCodecInformation.ObjectType.MPEG_4_AAC_SCALABLE,
}
aac_sampling_frequency: Dict[
str, AacMediaCodecInformation.SamplingFrequency
] = {
'44100': AacMediaCodecInformation.SamplingFrequency.SF_44100,
'48000': AacMediaCodecInformation.SamplingFrequency.SF_48000,
}
aac_channel_mode: Dict[int, AacMediaCodecInformation.Channels] = {
1: AacMediaCodecInformation.Channels.MONO,
2: AacMediaCodecInformation.Channels.STEREO,
}
assert 'profile' in audio_stream
assert 'sample_rate' in audio_stream
assert 'channels' in audio_stream
media_codec_information = AacMediaCodecInformation(
object_type=object_type[audio_stream['profile']],
sampling_frequency=aac_sampling_frequency[audio_stream['sample_rate']],
channels=aac_channel_mode[audio_stream['channels']],
vbr=1,
bitrate=128000,
)
else:
print(color('!!! codec not supported, only aac & sbc are supported', 'red'))
exit()
assert media_codec_type is not None
assert media_codec_information is not None
return MediaCodecCapabilities(
media_type=AVDTP_AUDIO_MEDIA_TYPE,
media_codec_type=media_codec_type,
media_codec_information=media_codec_information,
)
# ----------------------------------------------------------------------------- # -----------------------------------------------------------------------------
async def main() -> None: async def main() -> None:
if len(sys.argv) < 4: if len(sys.argv) < 4:
print( print(
'Usage: run_a2dp_source.py <device-config> <transport-spec> <sbc-file> ' 'Usage: run_a2dp_source.py <device-config> <transport-spec> <audio-file> '
'[<bluetooth-address>]' '[<bluetooth-address>]'
) )
print( print(
@@ -135,11 +230,13 @@ async def main() -> None:
# Start # Start
await device.power_on() await device.power_on()
with open(sys.argv[3], 'rb') as sbc_file: with open(sys.argv[3], 'rb') as audio_file:
# NOTE: this should be using asyncio file reading, but blocking reads are # NOTE: this should be using asyncio file reading, but blocking reads are
# good enough for testing # good enough for testing
async def read(byte_count): async def read(byte_count):
return sbc_file.read(byte_count) return audio_file.read(byte_count)
codec_capabilities = fetch_codec_informations(sys.argv[3])
if len(sys.argv) > 4: if len(sys.argv) > 4:
# Connect to a peer # Connect to a peer
@@ -170,12 +267,15 @@ async def main() -> None:
protocol = await Protocol.connect(connection, avdtp_version) protocol = await Protocol.connect(connection, avdtp_version)
# Start streaming # Start streaming
await stream_packets(read, protocol) await stream_packets(read, protocol, codec_capabilities)
else: else:
# Create a listener to wait for AVDTP connections # Create a listener to wait for AVDTP connections
listener = Listener.for_device(device=device, version=(1, 2)) listener = Listener.for_device(device=device, version=(1, 2))
listener.on( listener.on(
'connection', lambda protocol: on_avdtp_connection(read, protocol) 'connection',
lambda protocol: on_avdtp_connection(
read, protocol, codec_capabilities
),
) )
# Become connectable and wait for a connection # Become connectable and wait for a connection

View File

@@ -13,11 +13,11 @@ dependencies = [
"aiohttp ~= 3.8; platform_system!='Emscripten'", "aiohttp ~= 3.8; platform_system!='Emscripten'",
"appdirs >= 1.4; platform_system!='Emscripten'", "appdirs >= 1.4; platform_system!='Emscripten'",
"click >= 8.1.3; platform_system!='Emscripten'", "click >= 8.1.3; platform_system!='Emscripten'",
"cryptography >= 39; platform_system!='Emscripten'", "cryptography >= 44.0.3; platform_system!='Emscripten'",
# Pyodide bundles a version of cryptography that is built for wasm, which may not match the # Pyodide bundles a version of cryptography that is built for wasm, which may not match the
# versions available on PyPI. Relax the version requirement since it's better than being # versions available on PyPI. Relax the version requirement since it's better than being
# completely unable to import the package in case of version mismatch. # completely unable to import the package in case of version mismatch.
"cryptography >= 39.0; platform_system=='Emscripten'", "cryptography >= 44.0.3; platform_system=='Emscripten'",
"grpcio >= 1.62.1; platform_system!='Emscripten'", "grpcio >= 1.62.1; platform_system!='Emscripten'",
"humanize >= 4.6.0; platform_system!='Emscripten'", "humanize >= 4.6.0; platform_system!='Emscripten'",
"libusb1 >= 2.0.1; platform_system!='Emscripten'", "libusb1 >= 2.0.1; platform_system!='Emscripten'",
@@ -55,6 +55,9 @@ development = [
"types-invoke >= 1.7.3", "types-invoke >= 1.7.3",
"types-protobuf >= 4.21.0", "types-protobuf >= 4.21.0",
] ]
examples = [
"ffmpeg-python == 0.2.0",
]
avatar = [ avatar = [
"pandora-avatar == 0.0.10", "pandora-avatar == 0.0.10",
"rootcanal == 1.11.1 ; python_version>='3.10'", "rootcanal == 1.11.1 ; python_version>='3.10'",
@@ -184,6 +187,10 @@ ignore_missing_imports = true
module = "construct.*" module = "construct.*"
ignore_missing_imports = true ignore_missing_imports = true
[[tool.mypy.overrides]]
module = "ffmpeg.*"
ignore_missing_imports = true
[[tool.mypy.overrides]] [[tool.mypy.overrides]]
module = "grpc.*" module = "grpc.*"
ignore_missing_imports = true ignore_missing_imports = true