forked from auracaster/bumble_mirror
Compare commits
72 Commits
| Author | SHA1 | Date | |
|---|---|---|---|
| 8f4721758f | |||
| 8864af4acd | |||
| 8980fb8cc7 | |||
| 6810865670 | |||
| 3e9e06a02c | |||
| ccd12f6591 | |||
| f9a7843f7e | |||
| 210c334db7 | |||
| f297cdfcce | |||
| 5b536d00ab | |||
| b4af46ebd5 | |||
| c08da3193e | |||
| fd4d68e5c0 | |||
| 5d83deffa4 | |||
| 2878cca478 | |||
| 53934716db | |||
| d885d45824 | |||
| b90d0f8710 | |||
| 8ccfc90fe6 | |||
| 92aa7e9e2a | |||
| afc6d19e04 | |||
| c05f073b33 | |||
| 2b4c2a22f4 | |||
| 47fe93a148 | |||
| 6139ca8045 | |||
| 87c76a4a0e | |||
| f7b66db873 | |||
| 0b314bd7f7 | |||
| 9da2e32ad7 | |||
| 93c0875740 | |||
| a286700239 | |||
| 98ed772e8a | |||
| f0b55a4f97 | |||
| b74503d345 | |||
| f911163e49 | |||
| b083cc99ad | |||
| 62a8ced447 | |||
| 81a6b1e097 | |||
| dd090c9e6b | |||
| 11faa48422 | |||
| 55596176c2 | |||
| 4d6822d312 | |||
| 985c365e6d | |||
| af57762227 | |||
| 3575f9030e | |||
| 698d947d85 | |||
| ff6528d2bf | |||
| 72ac75a98d | |||
| 5e3ecb74e4 | |||
| c59be293c8 | |||
| 6d22ed80ec | |||
| ffb3eca68b | |||
| 403a13e4c6 | |||
| ad0f035df5 | |||
| 07f71fc895 | |||
| f47b9178ad | |||
| 4f399249bd | |||
| 9324237828 | |||
| d1033c018a | |||
| 0f29052ade | |||
| 0578e84586 | |||
| 6ab41c466f | |||
| 98a1093ebf | |||
| caf04373f3 | |||
| d4e8526766 | |||
| 515b83a8c7 | |||
| dc18595c8a | |||
| 488bcfe9c6 | |||
| d6cefdff8e | |||
| dc410b14c4 | |||
| 4c49ef9403 | |||
| ba85dcbda5 |
+394
-131
File diff suppressed because it is too large
Load Diff
@@ -0,0 +1,63 @@
|
||||
# Copyright 2023 Google LLC
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# https://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import click
|
||||
from bumble.colors import color
|
||||
from bumble.hci import Address
|
||||
from bumble.helpers import generate_irk, verify_rpa_with_irk
|
||||
|
||||
|
||||
@click.group()
|
||||
def cli():
|
||||
'''
|
||||
This is a tool for generating IRK, RPA,
|
||||
and verifying IRK/RPA pairs
|
||||
'''
|
||||
|
||||
|
||||
@click.command()
|
||||
def gen_irk() -> None:
|
||||
print(generate_irk().hex())
|
||||
|
||||
|
||||
@click.command()
|
||||
@click.argument("irk", type=str)
|
||||
def gen_rpa(irk: str) -> None:
|
||||
irk_bytes = bytes.fromhex(irk)
|
||||
rpa = Address.generate_private_address(irk_bytes)
|
||||
print(rpa.to_string(with_type_qualifier=False))
|
||||
|
||||
|
||||
@click.command()
|
||||
@click.argument("irk", type=str)
|
||||
@click.argument("rpa", type=str)
|
||||
def verify_rpa(irk: str, rpa: str) -> None:
|
||||
address = Address(rpa)
|
||||
irk_bytes = bytes.fromhex(irk)
|
||||
if verify_rpa_with_irk(address, irk_bytes):
|
||||
print(color("Verified", "green"))
|
||||
else:
|
||||
print(color("Not Verified", "red"))
|
||||
|
||||
|
||||
def main():
|
||||
cli.add_command(gen_irk)
|
||||
cli.add_command(gen_rpa)
|
||||
cli.add_command(verify_rpa)
|
||||
cli()
|
||||
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
if __name__ == '__main__':
|
||||
main()
|
||||
+34
-2
@@ -32,10 +32,14 @@ from bumble.hci import (
|
||||
HCI_Command,
|
||||
HCI_Command_Complete_Event,
|
||||
HCI_Command_Status_Event,
|
||||
HCI_READ_BUFFER_SIZE_COMMAND,
|
||||
HCI_Read_Buffer_Size_Command,
|
||||
HCI_READ_BD_ADDR_COMMAND,
|
||||
HCI_Read_BD_ADDR_Command,
|
||||
HCI_READ_LOCAL_NAME_COMMAND,
|
||||
HCI_Read_Local_Name_Command,
|
||||
HCI_LE_READ_BUFFER_SIZE_COMMAND,
|
||||
HCI_LE_Read_Buffer_Size_Command,
|
||||
HCI_LE_READ_MAXIMUM_DATA_LENGTH_COMMAND,
|
||||
HCI_LE_Read_Maximum_Data_Length_Command,
|
||||
HCI_LE_READ_NUMBER_OF_SUPPORTED_ADVERTISING_SETS_COMMAND,
|
||||
@@ -59,7 +63,7 @@ def command_succeeded(response):
|
||||
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
async def get_classic_info(host):
|
||||
async def get_classic_info(host: Host) -> None:
|
||||
if host.supports_command(HCI_READ_BD_ADDR_COMMAND):
|
||||
response = await host.send_command(HCI_Read_BD_ADDR_Command())
|
||||
if command_succeeded(response):
|
||||
@@ -80,7 +84,7 @@ async def get_classic_info(host):
|
||||
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
async def get_le_info(host):
|
||||
async def get_le_info(host: Host) -> None:
|
||||
print()
|
||||
|
||||
if host.supports_command(HCI_LE_READ_NUMBER_OF_SUPPORTED_ADVERTISING_SETS_COMMAND):
|
||||
@@ -136,6 +140,31 @@ async def get_le_info(host):
|
||||
print(' ', name_or_number(HCI_LE_SUPPORTED_FEATURES_NAMES, feature))
|
||||
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
async def get_acl_flow_control_info(host: Host) -> None:
|
||||
print()
|
||||
|
||||
if host.supports_command(HCI_READ_BUFFER_SIZE_COMMAND):
|
||||
response = await host.send_command(
|
||||
HCI_Read_Buffer_Size_Command(), check_result=True
|
||||
)
|
||||
print(
|
||||
color('ACL Flow Control:', 'yellow'),
|
||||
f'{response.return_parameters.hc_total_num_acl_data_packets} '
|
||||
f'packets of size {response.return_parameters.hc_acl_data_packet_length}',
|
||||
)
|
||||
|
||||
if host.supports_command(HCI_LE_READ_BUFFER_SIZE_COMMAND):
|
||||
response = await host.send_command(
|
||||
HCI_LE_Read_Buffer_Size_Command(), check_result=True
|
||||
)
|
||||
print(
|
||||
color('LE ACL Flow Control:', 'yellow'),
|
||||
f'{response.return_parameters.hc_total_num_le_acl_data_packets} '
|
||||
f'packets of size {response.return_parameters.hc_le_acl_data_packet_length}',
|
||||
)
|
||||
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
async def async_main(transport):
|
||||
print('<<< connecting to HCI...')
|
||||
@@ -168,6 +197,9 @@ async def async_main(transport):
|
||||
# Get the LE info
|
||||
await get_le_info(host)
|
||||
|
||||
# Print the ACL flow control info
|
||||
await get_acl_flow_control_info(host)
|
||||
|
||||
# Print the list of commands supported by the controller
|
||||
print()
|
||||
print(color('Supported Commands:', 'yellow'))
|
||||
|
||||
+32
-24
@@ -49,14 +49,16 @@ class ServerBridge:
|
||||
self.tcp_port = tcp_port
|
||||
|
||||
async def start(self, device: Device) -> None:
|
||||
# Listen for incoming L2CAP CoC connections
|
||||
# Listen for incoming L2CAP channel connections
|
||||
device.create_l2cap_server(
|
||||
spec=l2cap.LeCreditBasedChannelSpec(
|
||||
psm=self.psm, mtu=self.mtu, mps=self.mps, max_credits=self.max_credits
|
||||
),
|
||||
handler=self.on_coc,
|
||||
handler=self.on_channel,
|
||||
)
|
||||
print(
|
||||
color(f'### Listening for channel connection on PSM {self.psm}', 'yellow')
|
||||
)
|
||||
print(color(f'### Listening for CoC connection on PSM {self.psm}', 'yellow'))
|
||||
|
||||
def on_ble_connection(connection):
|
||||
def on_ble_disconnection(reason):
|
||||
@@ -73,7 +75,7 @@ class ServerBridge:
|
||||
await device.start_advertising(auto_restart=True)
|
||||
|
||||
# Called when a new L2CAP connection is established
|
||||
def on_coc(self, l2cap_channel):
|
||||
def on_channel(self, l2cap_channel):
|
||||
print(color('*** L2CAP channel:', 'cyan'), l2cap_channel)
|
||||
|
||||
class Pipe:
|
||||
@@ -83,7 +85,7 @@ class ServerBridge:
|
||||
self.l2cap_channel = l2cap_channel
|
||||
|
||||
l2cap_channel.on('close', self.on_l2cap_close)
|
||||
l2cap_channel.sink = self.on_coc_sdu
|
||||
l2cap_channel.sink = self.on_channel_sdu
|
||||
|
||||
async def connect_to_tcp(self):
|
||||
# Connect to the TCP server
|
||||
@@ -128,7 +130,7 @@ class ServerBridge:
|
||||
if self.tcp_transport is not None:
|
||||
self.tcp_transport.close()
|
||||
|
||||
def on_coc_sdu(self, sdu):
|
||||
def on_channel_sdu(self, sdu):
|
||||
print(color(f'<<< [L2CAP SDU]: {len(sdu)} bytes', 'cyan'))
|
||||
if self.tcp_transport is None:
|
||||
print(color('!!! TCP socket not open, dropping', 'red'))
|
||||
@@ -183,7 +185,7 @@ class ClientBridge:
|
||||
peer_name = writer.get_extra_info('peer_name')
|
||||
print(color(f'<<< TCP connection from {peer_name}', 'magenta'))
|
||||
|
||||
def on_coc_sdu(sdu):
|
||||
def on_channel_sdu(sdu):
|
||||
print(color(f'<<< [L2CAP SDU]: {len(sdu)} bytes', 'cyan'))
|
||||
l2cap_to_tcp_pipe.write(sdu)
|
||||
|
||||
@@ -209,7 +211,7 @@ class ClientBridge:
|
||||
writer.close()
|
||||
return
|
||||
|
||||
l2cap_channel.sink = on_coc_sdu
|
||||
l2cap_channel.sink = on_channel_sdu
|
||||
l2cap_channel.on('close', on_l2cap_close)
|
||||
|
||||
# Start a flow control pipe from L2CAP to TCP
|
||||
@@ -274,23 +276,29 @@ async def run(device_config, hci_transport, bridge):
|
||||
@click.pass_context
|
||||
@click.option('--device-config', help='Device configuration file', required=True)
|
||||
@click.option('--hci-transport', help='HCI transport', required=True)
|
||||
@click.option('--psm', help='PSM for L2CAP CoC', type=int, default=1234)
|
||||
@click.option('--psm', help='PSM for L2CAP', type=int, default=1234)
|
||||
@click.option(
|
||||
'--l2cap-coc-max-credits',
|
||||
help='Maximum L2CAP CoC Credits',
|
||||
'--l2cap-max-credits',
|
||||
help='Maximum L2CAP Credits',
|
||||
type=click.IntRange(1, 65535),
|
||||
default=128,
|
||||
)
|
||||
@click.option(
|
||||
'--l2cap-coc-mtu',
|
||||
help='L2CAP CoC MTU',
|
||||
type=click.IntRange(23, 65535),
|
||||
default=1022,
|
||||
'--l2cap-mtu',
|
||||
help='L2CAP MTU',
|
||||
type=click.IntRange(
|
||||
l2cap.L2CAP_LE_CREDIT_BASED_CONNECTION_MIN_MTU,
|
||||
l2cap.L2CAP_LE_CREDIT_BASED_CONNECTION_MAX_MTU,
|
||||
),
|
||||
default=1024,
|
||||
)
|
||||
@click.option(
|
||||
'--l2cap-coc-mps',
|
||||
help='L2CAP CoC MPS',
|
||||
type=click.IntRange(23, 65533),
|
||||
'--l2cap-mps',
|
||||
help='L2CAP MPS',
|
||||
type=click.IntRange(
|
||||
l2cap.L2CAP_LE_CREDIT_BASED_CONNECTION_MIN_MPS,
|
||||
l2cap.L2CAP_LE_CREDIT_BASED_CONNECTION_MAX_MPS,
|
||||
),
|
||||
default=1024,
|
||||
)
|
||||
def cli(
|
||||
@@ -298,17 +306,17 @@ def cli(
|
||||
device_config,
|
||||
hci_transport,
|
||||
psm,
|
||||
l2cap_coc_max_credits,
|
||||
l2cap_coc_mtu,
|
||||
l2cap_coc_mps,
|
||||
l2cap_max_credits,
|
||||
l2cap_mtu,
|
||||
l2cap_mps,
|
||||
):
|
||||
context.ensure_object(dict)
|
||||
context.obj['device_config'] = device_config
|
||||
context.obj['hci_transport'] = hci_transport
|
||||
context.obj['psm'] = psm
|
||||
context.obj['max_credits'] = l2cap_coc_max_credits
|
||||
context.obj['mtu'] = l2cap_coc_mtu
|
||||
context.obj['mps'] = l2cap_coc_mps
|
||||
context.obj['max_credits'] = l2cap_max_credits
|
||||
context.obj['mtu'] = l2cap_mtu
|
||||
context.obj['mps'] = l2cap_mps
|
||||
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
|
||||
+6
-8
@@ -52,11 +52,13 @@ from bumble.att import (
|
||||
class Waiter:
|
||||
instance = None
|
||||
|
||||
def __init__(self):
|
||||
def __init__(self, linger=False):
|
||||
self.done = asyncio.get_running_loop().create_future()
|
||||
self.linger = linger
|
||||
|
||||
def terminate(self):
|
||||
self.done.set_result(None)
|
||||
if not self.linger:
|
||||
self.done.set_result(None)
|
||||
|
||||
async def wait_until_terminated(self):
|
||||
return await self.done
|
||||
@@ -302,7 +304,7 @@ async def pair(
|
||||
hci_transport,
|
||||
address_or_name,
|
||||
):
|
||||
Waiter.instance = Waiter()
|
||||
Waiter.instance = Waiter(linger=linger)
|
||||
|
||||
print('<<< connecting to HCI...')
|
||||
async with await open_transport_or_link(hci_transport) as (hci_source, hci_sink):
|
||||
@@ -396,7 +398,6 @@ async def pair(
|
||||
address_or_name,
|
||||
transport=BT_LE_TRANSPORT if mode == 'le' else BT_BR_EDR_TRANSPORT,
|
||||
)
|
||||
pairing_failure = False
|
||||
|
||||
if not request:
|
||||
try:
|
||||
@@ -405,11 +406,8 @@ async def pair(
|
||||
else:
|
||||
await connection.authenticate()
|
||||
except ProtocolError as error:
|
||||
pairing_failure = True
|
||||
print(color(f'Pairing failed: {error}', 'red'))
|
||||
|
||||
if not linger or pairing_failure:
|
||||
return
|
||||
else:
|
||||
if mode == 'le':
|
||||
# Advertise so that peers can find us and connect
|
||||
@@ -459,7 +457,7 @@ class LogHandler(logging.Handler):
|
||||
help='Enable CTKD',
|
||||
show_default=True,
|
||||
)
|
||||
@click.option('--linger', default=True, is_flag=True, help='Linger after pairing')
|
||||
@click.option('--linger', default=False, is_flag=True, help='Linger after pairing')
|
||||
@click.option(
|
||||
'--io',
|
||||
type=click.Choice(
|
||||
|
||||
+28
-1
@@ -134,12 +134,14 @@ class Controller:
|
||||
'0000000060000000'
|
||||
) # BR/EDR Not Supported, LE Supported (Controller)
|
||||
self.manufacturer_name = 0xFFFF
|
||||
self.hc_data_packet_length = 27
|
||||
self.hc_total_num_data_packets = 64
|
||||
self.hc_le_data_packet_length = 27
|
||||
self.hc_total_num_le_data_packets = 64
|
||||
self.event_mask = 0
|
||||
self.event_mask_page_2 = 0
|
||||
self.supported_commands = bytes.fromhex(
|
||||
'2000800000c000000000e40000002822000000000000040000f7ffff7f000000'
|
||||
'2000800000c000000000e4000000a822000000000000040000f7ffff7f000000'
|
||||
'30f0f9ff01008004000000000000000000000000000000000000000000000000'
|
||||
)
|
||||
self.le_event_mask = 0
|
||||
@@ -914,6 +916,19 @@ class Controller:
|
||||
'''
|
||||
return bytes([HCI_SUCCESS]) + self.lmp_features
|
||||
|
||||
def on_hci_read_buffer_size_command(self, _command):
|
||||
'''
|
||||
See Bluetooth spec Vol 4, Part E - 7.4.5 Read Buffer Size Command
|
||||
'''
|
||||
return struct.pack(
|
||||
'<BHBHH',
|
||||
HCI_SUCCESS,
|
||||
self.hc_data_packet_length,
|
||||
0,
|
||||
self.hc_total_num_data_packets,
|
||||
0,
|
||||
)
|
||||
|
||||
def on_hci_read_bd_addr_command(self, _command):
|
||||
'''
|
||||
See Bluetooth spec Vol 4, Part E - 7.4.6 Read BD_ADDR Command
|
||||
@@ -1263,3 +1278,15 @@ class Controller:
|
||||
See Bluetooth spec Vol 4, Part E - 7.8.74 LE Read Transmit Power Command
|
||||
'''
|
||||
return struct.pack('<BBB', HCI_SUCCESS, 0, 0)
|
||||
|
||||
def on_hci_le_setup_iso_data_path_command(self, command):
|
||||
'''
|
||||
See Bluetooth spec Vol 4, Part E - 7.8.109 LE Setup ISO Data Path Command
|
||||
'''
|
||||
return struct.pack('<BH', HCI_SUCCESS, command.connection_handle)
|
||||
|
||||
def on_hci_le_remove_iso_data_path_command(self, command):
|
||||
'''
|
||||
See Bluetooth spec Vol 4, Part E - 7.8.110 LE Remove ISO Data Path Command
|
||||
'''
|
||||
return struct.pack('<BH', HCI_SUCCESS, command.connection_handle)
|
||||
|
||||
@@ -100,6 +100,16 @@ class EccKey:
|
||||
# -----------------------------------------------------------------------------
|
||||
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
def generate_prand() -> bytes:
|
||||
'''Generates random 3 bytes, with the 2 most significant bits of 0b01.
|
||||
|
||||
See Bluetooth spec, Vol 6, Part E - Table 1.2.
|
||||
'''
|
||||
prand_bytes = secrets.token_bytes(6)
|
||||
return prand_bytes[:2] + bytes([(prand_bytes[2] & 0b01111111) | 0b01000000])
|
||||
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
def xor(x: bytes, y: bytes) -> bytes:
|
||||
assert len(x) == len(y)
|
||||
|
||||
+180
-46
@@ -437,6 +437,38 @@ class AdvertisingType(IntEnum):
|
||||
)
|
||||
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
@dataclass
|
||||
class LegacyAdvertiser:
|
||||
device: Device
|
||||
advertising_type: AdvertisingType
|
||||
own_address_type: OwnAddressType
|
||||
auto_restart: bool
|
||||
advertising_data: Optional[bytes]
|
||||
scan_response_data: Optional[bytes]
|
||||
|
||||
async def stop(self) -> None:
|
||||
await self.device.stop_legacy_advertising()
|
||||
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
@dataclass
|
||||
class ExtendedAdvertiser(CompositeEventEmitter):
|
||||
device: Device
|
||||
handle: int
|
||||
advertising_properties: HCI_LE_Set_Extended_Advertising_Parameters_Command.AdvertisingProperties
|
||||
own_address_type: OwnAddressType
|
||||
auto_restart: bool
|
||||
advertising_data: Optional[bytes]
|
||||
scan_response_data: Optional[bytes]
|
||||
|
||||
def __post_init__(self) -> None:
|
||||
super().__init__()
|
||||
|
||||
async def stop(self) -> None:
|
||||
await self.device.stop_extended_advertising(self.handle)
|
||||
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
class LePhyOptions:
|
||||
# Coded PHY preference
|
||||
@@ -658,6 +690,9 @@ class Connection(CompositeEventEmitter):
|
||||
gatt_client: gatt_client.Client
|
||||
pairing_peer_io_capability: Optional[int]
|
||||
pairing_peer_authentication_requirements: Optional[int]
|
||||
advertiser_after_disconnection: Union[
|
||||
LegacyAdvertiser, ExtendedAdvertiser, None
|
||||
] = None
|
||||
|
||||
@composite_listener
|
||||
class Listener:
|
||||
@@ -1063,7 +1098,8 @@ class Device(CompositeEventEmitter):
|
||||
]
|
||||
advertisement_accumulators: Dict[Address, AdvertisementDataAccumulator]
|
||||
config: DeviceConfiguration
|
||||
extended_advertising_handles: Set[int]
|
||||
legacy_advertiser: Optional[LegacyAdvertiser]
|
||||
extended_advertisers: Dict[int, ExtendedAdvertiser]
|
||||
sco_links: Dict[int, ScoLink]
|
||||
cis_links: Dict[int, CisLink]
|
||||
_pending_cis: Dict[int, Tuple[int, int]]
|
||||
@@ -1141,10 +1177,7 @@ class Device(CompositeEventEmitter):
|
||||
|
||||
self._host = None
|
||||
self.powered_on = False
|
||||
self.advertising = False
|
||||
self.advertising_type = None
|
||||
self.auto_restart_inquiry = True
|
||||
self.auto_restart_advertising = False
|
||||
self.command_timeout = 10 # seconds
|
||||
self.gatt_server = gatt_server.Server(self)
|
||||
self.sdp_server = sdp.Server(self)
|
||||
@@ -1168,10 +1201,10 @@ class Device(CompositeEventEmitter):
|
||||
self.classic_pending_accepts = {
|
||||
Address.ANY: []
|
||||
} # Futures, by BD address OR [Futures] for Address.ANY
|
||||
self.extended_advertising_handles = set()
|
||||
self.legacy_advertiser = None
|
||||
self.extended_advertisers = {}
|
||||
|
||||
# Own address type cache
|
||||
self.advertising_own_address_type = None
|
||||
self.connect_own_address_type = None
|
||||
|
||||
# Use the initial config or a default
|
||||
@@ -1579,6 +1612,7 @@ class Device(CompositeEventEmitter):
|
||||
|
||||
return self.host.supports_le_feature(feature_map[phy])
|
||||
|
||||
@deprecated("Please use start_legacy_advertising.")
|
||||
async def start_advertising(
|
||||
self,
|
||||
advertising_type: AdvertisingType = AdvertisingType.UNDIRECTED_CONNECTABLE_SCANNABLE,
|
||||
@@ -1586,15 +1620,49 @@ class Device(CompositeEventEmitter):
|
||||
own_address_type: int = OwnAddressType.RANDOM,
|
||||
auto_restart: bool = False,
|
||||
) -> None:
|
||||
await self.start_legacy_advertising(
|
||||
advertising_type=advertising_type,
|
||||
target=target,
|
||||
own_address_type=OwnAddressType(own_address_type),
|
||||
auto_restart=auto_restart,
|
||||
)
|
||||
|
||||
async def start_legacy_advertising(
|
||||
self,
|
||||
advertising_type: AdvertisingType = AdvertisingType.UNDIRECTED_CONNECTABLE_SCANNABLE,
|
||||
target: Optional[Address] = None,
|
||||
own_address_type: OwnAddressType = OwnAddressType.RANDOM,
|
||||
auto_restart: bool = False,
|
||||
advertising_data: Optional[bytes] = None,
|
||||
scan_response_data: Optional[bytes] = None,
|
||||
) -> LegacyAdvertiser:
|
||||
"""Starts an legacy advertisement.
|
||||
|
||||
Args:
|
||||
advertising_type: Advertising type passed to HCI_LE_Set_Advertising_Parameters_Command.
|
||||
target: Directed advertising target. Directed type should be set in advertising_type arg.
|
||||
own_address_type: own address type to use in the advertising.
|
||||
auto_restart: whether the advertisement will be restarted after disconnection.
|
||||
scan_response_data: raw scan response.
|
||||
advertising_data: raw advertising data.
|
||||
|
||||
Returns:
|
||||
LegacyAdvertiser object containing the metadata of advertisement.
|
||||
"""
|
||||
if self.extended_advertisers:
|
||||
logger.warning(
|
||||
'Trying to start Legacy and Extended Advertising at the same time!'
|
||||
)
|
||||
|
||||
# If we're advertising, stop first
|
||||
if self.advertising:
|
||||
if self.legacy_advertiser:
|
||||
await self.stop_advertising()
|
||||
|
||||
# Set/update the advertising data if the advertising type allows it
|
||||
if advertising_type.has_data:
|
||||
await self.send_command(
|
||||
HCI_LE_Set_Advertising_Data_Command(
|
||||
advertising_data=self.advertising_data
|
||||
advertising_data=advertising_data or self.advertising_data or b''
|
||||
),
|
||||
check_result=True,
|
||||
)
|
||||
@@ -1603,7 +1671,9 @@ class Device(CompositeEventEmitter):
|
||||
if advertising_type.is_scannable:
|
||||
await self.send_command(
|
||||
HCI_LE_Set_Scan_Response_Data_Command(
|
||||
scan_response_data=self.scan_response_data
|
||||
scan_response_data=scan_response_data
|
||||
or self.scan_response_data
|
||||
or b''
|
||||
),
|
||||
check_result=True,
|
||||
)
|
||||
@@ -1640,45 +1710,57 @@ class Device(CompositeEventEmitter):
|
||||
check_result=True,
|
||||
)
|
||||
|
||||
self.advertising_type = advertising_type
|
||||
self.advertising_own_address_type = own_address_type
|
||||
self.advertising = True
|
||||
self.auto_restart_advertising = auto_restart
|
||||
self.legacy_advertiser = LegacyAdvertiser(
|
||||
device=self,
|
||||
advertising_type=advertising_type,
|
||||
own_address_type=own_address_type,
|
||||
auto_restart=auto_restart,
|
||||
advertising_data=advertising_data,
|
||||
scan_response_data=scan_response_data,
|
||||
)
|
||||
return self.legacy_advertiser
|
||||
|
||||
@deprecated("Please use stop_legacy_advertising.")
|
||||
async def stop_advertising(self) -> None:
|
||||
await self.stop_legacy_advertising()
|
||||
|
||||
async def stop_legacy_advertising(self) -> None:
|
||||
# Disable advertising
|
||||
if self.advertising:
|
||||
if self.legacy_advertiser:
|
||||
await self.send_command(
|
||||
HCI_LE_Set_Advertising_Enable_Command(advertising_enable=0),
|
||||
check_result=True,
|
||||
)
|
||||
|
||||
self.advertising_type = None
|
||||
self.advertising_own_address_type = None
|
||||
self.advertising = False
|
||||
self.auto_restart_advertising = False
|
||||
self.legacy_advertiser = None
|
||||
|
||||
@experimental('Extended Advertising is still experimental - Might be changed soon.')
|
||||
async def start_extended_advertising(
|
||||
self,
|
||||
advertising_properties: HCI_LE_Set_Extended_Advertising_Parameters_Command.AdvertisingProperties = HCI_LE_Set_Extended_Advertising_Parameters_Command.AdvertisingProperties.CONNECTABLE_ADVERTISING,
|
||||
target: Address = Address.ANY,
|
||||
own_address_type: int = OwnAddressType.RANDOM,
|
||||
scan_response: Optional[bytes] = None,
|
||||
own_address_type: OwnAddressType = OwnAddressType.RANDOM,
|
||||
auto_restart: bool = True,
|
||||
advertising_data: Optional[bytes] = None,
|
||||
) -> int:
|
||||
scan_response_data: Optional[bytes] = None,
|
||||
) -> ExtendedAdvertiser:
|
||||
"""Starts an extended advertising set.
|
||||
|
||||
Args:
|
||||
advertising_properties: Properties to pass in HCI_LE_Set_Extended_Advertising_Parameters_Command
|
||||
target: Directed advertising target. Directed property should be set in advertising_properties arg.
|
||||
own_address_type: own address type to use in the advertising.
|
||||
scan_response: raw scan response. When a non-none value is set, HCI_LE_Set_Extended_Scan_Response_Data_Command will be sent.
|
||||
auto_restart: whether the advertisement will be restarted after disconnection.
|
||||
advertising_data: raw advertising data. When a non-none value is set, HCI_LE_Set_Advertising_Set_Random_Address_Command will be sent.
|
||||
scan_response_data: raw scan response. When a non-none value is set, HCI_LE_Set_Extended_Scan_Response_Data_Command will be sent.
|
||||
|
||||
Returns:
|
||||
Handle of the new advertising set.
|
||||
ExtendedAdvertiser object containing the metadata of advertisement.
|
||||
"""
|
||||
if self.legacy_advertiser:
|
||||
logger.warning(
|
||||
'Trying to start Legacy and Extended Advertising at the same time!'
|
||||
)
|
||||
|
||||
adv_handle = -1
|
||||
# Find a free handle
|
||||
@@ -1686,7 +1768,7 @@ class Device(CompositeEventEmitter):
|
||||
DEVICE_MIN_EXTENDED_ADVERTISING_SET_HANDLE,
|
||||
DEVICE_MAX_EXTENDED_ADVERTISING_SET_HANDLE + 1,
|
||||
):
|
||||
if i not in self.extended_advertising_handles:
|
||||
if i not in self.extended_advertisers:
|
||||
adv_handle = i
|
||||
break
|
||||
|
||||
@@ -1733,13 +1815,13 @@ class Device(CompositeEventEmitter):
|
||||
)
|
||||
|
||||
# Set the scan response if present
|
||||
if scan_response is not None:
|
||||
if scan_response_data is not None:
|
||||
await self.send_command(
|
||||
HCI_LE_Set_Extended_Scan_Response_Data_Command(
|
||||
advertising_handle=adv_handle,
|
||||
operation=HCI_LE_Set_Extended_Advertising_Data_Command.Operation.COMPLETE_DATA,
|
||||
fragment_preference=0x01, # Should not fragment
|
||||
scan_response_data=scan_response,
|
||||
scan_response_data=scan_response_data,
|
||||
),
|
||||
check_result=True,
|
||||
)
|
||||
@@ -1774,8 +1856,16 @@ class Device(CompositeEventEmitter):
|
||||
)
|
||||
raise error
|
||||
|
||||
self.extended_advertising_handles.add(adv_handle)
|
||||
return adv_handle
|
||||
advertiser = self.extended_advertisers[adv_handle] = ExtendedAdvertiser(
|
||||
device=self,
|
||||
handle=adv_handle,
|
||||
advertising_properties=advertising_properties,
|
||||
own_address_type=own_address_type,
|
||||
auto_restart=auto_restart,
|
||||
advertising_data=advertising_data,
|
||||
scan_response_data=scan_response_data,
|
||||
)
|
||||
return advertiser
|
||||
|
||||
@experimental('Extended Advertising is still experimental - Might be changed soon.')
|
||||
async def stop_extended_advertising(self, adv_handle: int) -> None:
|
||||
@@ -1799,11 +1889,11 @@ class Device(CompositeEventEmitter):
|
||||
HCI_LE_Remove_Advertising_Set_Command(advertising_handle=adv_handle),
|
||||
check_result=True,
|
||||
)
|
||||
self.extended_advertising_handles.remove(adv_handle)
|
||||
del self.extended_advertisers[adv_handle]
|
||||
|
||||
@property
|
||||
def is_advertising(self):
|
||||
return self.advertising
|
||||
return self.legacy_advertiser or self.extended_advertisers
|
||||
|
||||
async def start_scanning(
|
||||
self,
|
||||
@@ -3144,13 +3234,18 @@ class Device(CompositeEventEmitter):
|
||||
# Guess which own address type is used for this connection.
|
||||
# This logic is somewhat correct but may need to be improved
|
||||
# when multiple advertising are run simultaneously.
|
||||
advertiser = None
|
||||
if self.connect_own_address_type is not None:
|
||||
own_address_type = self.connect_own_address_type
|
||||
elif self.legacy_advertiser:
|
||||
own_address_type = self.legacy_advertiser.own_address_type
|
||||
# Store advertiser for restarting - it's only required for legacy, since
|
||||
# extended advertisement produces HCI_Advertising_Set_Terminated.
|
||||
if self.legacy_advertiser.auto_restart:
|
||||
advertiser = self.legacy_advertiser
|
||||
else:
|
||||
own_address_type = self.advertising_own_address_type
|
||||
|
||||
# We are no longer advertising
|
||||
self.advertising = False
|
||||
# For extended advertisement, determining own address type later.
|
||||
own_address_type = OwnAddressType.RANDOM
|
||||
|
||||
if own_address_type in (
|
||||
OwnAddressType.PUBLIC,
|
||||
@@ -3172,6 +3267,7 @@ class Device(CompositeEventEmitter):
|
||||
connection_parameters,
|
||||
ConnectionPHY(HCI_LE_1M_PHY, HCI_LE_1M_PHY),
|
||||
)
|
||||
connection.advertiser_after_disconnection = advertiser
|
||||
self.connections[connection_handle] = connection
|
||||
|
||||
# If supported, read which PHY we're connected with before
|
||||
@@ -3203,10 +3299,10 @@ class Device(CompositeEventEmitter):
|
||||
# For directed advertising, this means a timeout
|
||||
if (
|
||||
transport == BT_LE_TRANSPORT
|
||||
and self.advertising
|
||||
and self.advertising_type.is_directed
|
||||
and self.legacy_advertiser
|
||||
and self.legacy_advertiser.advertising_type.is_directed
|
||||
):
|
||||
self.advertising = False
|
||||
self.legacy_advertiser = None
|
||||
|
||||
# Notify listeners
|
||||
error = core.ConnectionError(
|
||||
@@ -3268,16 +3364,30 @@ class Device(CompositeEventEmitter):
|
||||
self.gatt_server.on_disconnection(connection)
|
||||
|
||||
# Restart advertising if auto-restart is enabled
|
||||
if self.auto_restart_advertising:
|
||||
if advertiser := connection.advertiser_after_disconnection:
|
||||
logger.debug('restarting advertising')
|
||||
self.abort_on(
|
||||
'flush',
|
||||
self.start_advertising(
|
||||
advertising_type=self.advertising_type, # type: ignore[arg-type]
|
||||
own_address_type=self.advertising_own_address_type, # type: ignore[arg-type]
|
||||
auto_restart=True,
|
||||
),
|
||||
)
|
||||
if isinstance(advertiser, LegacyAdvertiser):
|
||||
self.abort_on(
|
||||
'flush',
|
||||
self.start_legacy_advertising(
|
||||
advertising_type=advertiser.advertising_type,
|
||||
own_address_type=advertiser.own_address_type,
|
||||
advertising_data=advertiser.advertising_data,
|
||||
scan_response_data=advertiser.scan_response_data,
|
||||
auto_restart=True,
|
||||
),
|
||||
)
|
||||
elif isinstance(advertiser, ExtendedAdvertiser):
|
||||
self.abort_on(
|
||||
'flush',
|
||||
self.start_extended_advertising(
|
||||
advertising_properties=advertiser.advertising_properties,
|
||||
own_address_type=advertiser.own_address_type,
|
||||
advertising_data=advertiser.advertising_data,
|
||||
scan_response_data=advertiser.scan_response_data,
|
||||
auto_restart=True,
|
||||
),
|
||||
)
|
||||
elif sco_link := self.sco_links.pop(connection_handle, None):
|
||||
sco_link.emit('disconnection', reason)
|
||||
elif cis_link := self.cis_links.pop(connection_handle, None):
|
||||
@@ -3600,6 +3710,30 @@ class Device(CompositeEventEmitter):
|
||||
if sco_link := self.sco_links.get(sco_handle, None):
|
||||
sco_link.emit('pdu', packet)
|
||||
|
||||
# [LE only]
|
||||
@host_event_handler
|
||||
@experimental('Only for testing')
|
||||
def on_advertising_set_termination(
|
||||
self,
|
||||
status: int,
|
||||
advertising_handle: int,
|
||||
connection_handle: int,
|
||||
) -> None:
|
||||
if status == HCI_SUCCESS:
|
||||
connection = self.lookup_connection(connection_handle)
|
||||
if advertiser := self.extended_advertisers.pop(advertising_handle, None):
|
||||
if connection:
|
||||
if advertiser.auto_restart:
|
||||
connection.advertiser_after_disconnection = advertiser
|
||||
if advertiser.own_address_type in (
|
||||
OwnAddressType.PUBLIC,
|
||||
OwnAddressType.RESOLVABLE_OR_PUBLIC,
|
||||
):
|
||||
connection.self_address = self.public_address
|
||||
else:
|
||||
connection.self_address = self.random_address
|
||||
advertiser.emit('termination', status)
|
||||
|
||||
# [LE only]
|
||||
@host_event_handler
|
||||
@with_connection_from_handle
|
||||
|
||||
+28
-32
@@ -19,12 +19,17 @@ like loading firmware after a cold start.
|
||||
# -----------------------------------------------------------------------------
|
||||
# Imports
|
||||
# -----------------------------------------------------------------------------
|
||||
import abc
|
||||
from __future__ import annotations
|
||||
import logging
|
||||
import pathlib
|
||||
import platform
|
||||
from . import rtk
|
||||
from typing import Dict, Iterable, Optional, Type, TYPE_CHECKING
|
||||
|
||||
from . import rtk
|
||||
from .common import Driver
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from bumble.host import Host
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
# Logging
|
||||
@@ -32,40 +37,31 @@ from . import rtk
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
# Classes
|
||||
# -----------------------------------------------------------------------------
|
||||
class Driver(abc.ABC):
|
||||
"""Base class for drivers."""
|
||||
|
||||
@staticmethod
|
||||
async def for_host(_host):
|
||||
"""Return a driver instance for a host.
|
||||
|
||||
Args:
|
||||
host: Host object for which a driver should be created.
|
||||
|
||||
Returns:
|
||||
A Driver instance if a driver should be instantiated for this host, or
|
||||
None if no driver instance of this class is needed.
|
||||
"""
|
||||
return None
|
||||
|
||||
@abc.abstractmethod
|
||||
async def init_controller(self):
|
||||
"""Initialize the controller."""
|
||||
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
# Functions
|
||||
# -----------------------------------------------------------------------------
|
||||
async def get_driver_for_host(host):
|
||||
"""Probe all known diver classes until one returns a valid instance for a host,
|
||||
or none is found.
|
||||
async def get_driver_for_host(host: Host) -> Optional[Driver]:
|
||||
"""Probe diver classes until one returns a valid instance for a host, or none is
|
||||
found.
|
||||
If a "driver" HCI metadata entry is present, only that driver class will be probed.
|
||||
"""
|
||||
if driver := await rtk.Driver.for_host(host):
|
||||
logger.debug("Instantiated RTK driver")
|
||||
return driver
|
||||
driver_classes: Dict[str, Type[Driver]] = {"rtk": rtk.Driver}
|
||||
probe_list: Iterable[str]
|
||||
if driver_name := host.hci_metadata.get("driver"):
|
||||
# Only probe a single driver
|
||||
probe_list = [driver_name]
|
||||
else:
|
||||
# Probe all drivers
|
||||
probe_list = driver_classes.keys()
|
||||
|
||||
for driver_name in probe_list:
|
||||
if driver_class := driver_classes.get(driver_name):
|
||||
logger.debug(f"Probing driver class: {driver_name}")
|
||||
if driver := await driver_class.for_host(host):
|
||||
logger.debug(f"Instantiated {driver_name} driver")
|
||||
return driver
|
||||
else:
|
||||
logger.debug(f"Skipping unknown driver class: {driver_name}")
|
||||
|
||||
return None
|
||||
|
||||
|
||||
@@ -0,0 +1,45 @@
|
||||
# Copyright 2021-2023 Google LLC
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# https://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
"""
|
||||
Common types for drivers.
|
||||
"""
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
# Imports
|
||||
# -----------------------------------------------------------------------------
|
||||
import abc
|
||||
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
# Classes
|
||||
# -----------------------------------------------------------------------------
|
||||
class Driver(abc.ABC):
|
||||
"""Base class for drivers."""
|
||||
|
||||
@staticmethod
|
||||
async def for_host(_host):
|
||||
"""Return a driver instance for a host.
|
||||
|
||||
Args:
|
||||
host: Host object for which a driver should be created.
|
||||
|
||||
Returns:
|
||||
A Driver instance if a driver should be instantiated for this host, or
|
||||
None if no driver instance of this class is needed.
|
||||
"""
|
||||
return None
|
||||
|
||||
@abc.abstractmethod
|
||||
async def init_controller(self):
|
||||
"""Initialize the controller."""
|
||||
+11
-4
@@ -41,7 +41,7 @@ from bumble.hci import (
|
||||
HCI_Reset_Command,
|
||||
HCI_Read_Local_Version_Information_Command,
|
||||
)
|
||||
|
||||
from bumble.drivers import common
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
# Logging
|
||||
@@ -285,7 +285,7 @@ class Firmware:
|
||||
)
|
||||
|
||||
|
||||
class Driver:
|
||||
class Driver(common.Driver):
|
||||
@dataclass
|
||||
class DriverInfo:
|
||||
rom: int
|
||||
@@ -470,8 +470,12 @@ class Driver:
|
||||
logger.debug("USB metadata not found")
|
||||
return False
|
||||
|
||||
vendor_id = host.hci_metadata.get("vendor_id", None)
|
||||
product_id = host.hci_metadata.get("product_id", None)
|
||||
if host.hci_metadata.get('driver') == 'rtk':
|
||||
# Forced driver
|
||||
return True
|
||||
|
||||
vendor_id = host.hci_metadata.get("vendor_id")
|
||||
product_id = host.hci_metadata.get("product_id")
|
||||
if vendor_id is None or product_id is None:
|
||||
logger.debug("USB metadata not sufficient")
|
||||
return False
|
||||
@@ -486,6 +490,9 @@ class Driver:
|
||||
|
||||
@classmethod
|
||||
async def driver_info_for_host(cls, host):
|
||||
await host.send_command(HCI_Reset_Command(), check_result=True)
|
||||
host.ready = True # Needed to let the host know the controller is ready.
|
||||
|
||||
response = await host.send_command(
|
||||
HCI_Read_Local_Version_Information_Command(), check_result=True
|
||||
)
|
||||
|
||||
+5
-2
@@ -368,9 +368,12 @@ class TemplateService(Service):
|
||||
UUID: UUID
|
||||
|
||||
def __init__(
|
||||
self, characteristics: List[Characteristic], primary: bool = True
|
||||
self,
|
||||
characteristics: List[Characteristic],
|
||||
primary: bool = True,
|
||||
included_services: List[Service] = [],
|
||||
) -> None:
|
||||
super().__init__(self.UUID, characteristics, primary)
|
||||
super().__init__(self.UUID, characteristics, primary, included_services)
|
||||
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
|
||||
@@ -961,7 +961,7 @@ class Server(EventEmitter):
|
||||
try:
|
||||
attribute.write_value(connection, request.attribute_value)
|
||||
except Exception as error:
|
||||
logger.warning(f'!!! ignoring exception: {error}')
|
||||
logger.exception(f'!!! ignoring exception: {error}')
|
||||
|
||||
def on_att_handle_value_confirmation(self, connection, _confirmation):
|
||||
'''
|
||||
|
||||
+80
-19
@@ -21,9 +21,11 @@ import dataclasses
|
||||
import enum
|
||||
import functools
|
||||
import logging
|
||||
import secrets
|
||||
import struct
|
||||
from typing import Any, Dict, Callable, Optional, Type, Union, List
|
||||
|
||||
from bumble import crypto
|
||||
from .colors import color
|
||||
from .core import (
|
||||
BT_BR_EDR_TRANSPORT,
|
||||
@@ -728,6 +730,19 @@ HCI_LE_PHY_TYPE_TO_BIT = {
|
||||
HCI_LE_CODED_PHY: HCI_LE_CODED_PHY_BIT
|
||||
}
|
||||
|
||||
|
||||
class Phy(enum.IntEnum):
|
||||
LE_1M = 0x01
|
||||
LE_2M = 0x02
|
||||
LE_CODED = 0x03
|
||||
|
||||
|
||||
class PhyBit(enum.IntFlag):
|
||||
LE_1M = 0b00000001
|
||||
LE_2M = 0b00000010
|
||||
LE_CODED = 0b00000100
|
||||
|
||||
|
||||
# Connection Parameters
|
||||
HCI_CONNECTION_INTERVAL_MS_PER_UNIT = 1.25
|
||||
HCI_CONNECTION_LATENCY_MS_PER_UNIT = 1.25
|
||||
@@ -1868,6 +1883,43 @@ class Address:
|
||||
address_type = data[offset - 1]
|
||||
return Address.parse_address_with_type(data, offset, address_type)
|
||||
|
||||
@classmethod
|
||||
def generate_static_address(cls) -> Address:
|
||||
'''Generates Random Static Address, with the 2 most significant bits of 0b11.
|
||||
|
||||
See Bluetooth spec, Vol 6, Part B - Table 1.2.
|
||||
'''
|
||||
address_bytes = secrets.token_bytes(6)
|
||||
address_bytes = address_bytes[:5] + bytes([address_bytes[5] | 0b11000000])
|
||||
return Address(
|
||||
address=address_bytes, address_type=Address.RANDOM_DEVICE_ADDRESS
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def generate_private_address(cls, irk: bytes = b'') -> Address:
|
||||
'''Generates Random Private MAC Address.
|
||||
|
||||
If IRK is present, a Resolvable Private Address, with the 2 most significant
|
||||
bits of 0b01 will be generated. Otherwise, a Non-resolvable Private Address,
|
||||
with the 2 most significant bits of 0b00 will be generated.
|
||||
|
||||
See Bluetooth spec, Vol 6, Part B - Table 1.2.
|
||||
|
||||
Args:
|
||||
irk: Local Identity Resolving Key(IRK), in little-endian. If not set, a
|
||||
non-resolvable address will be generated.
|
||||
'''
|
||||
if irk:
|
||||
prand = crypto.generate_prand()
|
||||
address_bytes = crypto.ah(irk, prand) + prand
|
||||
else:
|
||||
address_bytes = secrets.token_bytes(6)
|
||||
address_bytes = address_bytes[:5] + bytes([address_bytes[5] & 0b00111111])
|
||||
|
||||
return Address(
|
||||
address=address_bytes, address_type=Address.RANDOM_DEVICE_ADDRESS
|
||||
)
|
||||
|
||||
def __init__(
|
||||
self, address: Union[bytes, str], address_type: int = RANDOM_DEVICE_ADDRESS
|
||||
):
|
||||
@@ -1963,25 +2015,15 @@ Address.ANY_RANDOM = Address(b"\x00\x00\x00\x00\x00\x00", Address.RANDOM_DEVICE_
|
||||
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
class OwnAddressType:
|
||||
class OwnAddressType(enum.IntEnum):
|
||||
PUBLIC = 0
|
||||
RANDOM = 1
|
||||
RESOLVABLE_OR_PUBLIC = 2
|
||||
RESOLVABLE_OR_RANDOM = 3
|
||||
|
||||
TYPE_NAMES = {
|
||||
PUBLIC: 'PUBLIC',
|
||||
RANDOM: 'RANDOM',
|
||||
RESOLVABLE_OR_PUBLIC: 'RESOLVABLE_OR_PUBLIC',
|
||||
RESOLVABLE_OR_RANDOM: 'RESOLVABLE_OR_RANDOM',
|
||||
}
|
||||
|
||||
@staticmethod
|
||||
def type_name(type_id):
|
||||
return name_or_number(OwnAddressType.TYPE_NAMES, type_id)
|
||||
|
||||
# pylint: disable-next=unnecessary-lambda
|
||||
TYPE_SPEC = {'size': 1, 'mapper': lambda x: OwnAddressType.type_name(x)}
|
||||
@classmethod
|
||||
def type_spec(cls):
|
||||
return {'size': 1, 'mapper': lambda x: OwnAddressType(x).name}
|
||||
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
@@ -3374,7 +3416,7 @@ class HCI_LE_Set_Random_Address_Command(HCI_Command):
|
||||
),
|
||||
},
|
||||
),
|
||||
('own_address_type', OwnAddressType.TYPE_SPEC),
|
||||
('own_address_type', OwnAddressType.type_spec()),
|
||||
('peer_address_type', Address.ADDRESS_TYPE_SPEC),
|
||||
('peer_address', Address.parse_address_preceded_by_type),
|
||||
('advertising_channel_map', 1),
|
||||
@@ -3467,7 +3509,7 @@ class HCI_LE_Set_Advertising_Enable_Command(HCI_Command):
|
||||
('le_scan_type', 1),
|
||||
('le_scan_interval', 2),
|
||||
('le_scan_window', 2),
|
||||
('own_address_type', OwnAddressType.TYPE_SPEC),
|
||||
('own_address_type', OwnAddressType.type_spec()),
|
||||
('scanning_filter_policy', 1),
|
||||
]
|
||||
)
|
||||
@@ -3506,7 +3548,7 @@ class HCI_LE_Set_Scan_Enable_Command(HCI_Command):
|
||||
('initiator_filter_policy', 1),
|
||||
('peer_address_type', Address.ADDRESS_TYPE_SPEC),
|
||||
('peer_address', Address.parse_address_preceded_by_type),
|
||||
('own_address_type', OwnAddressType.TYPE_SPEC),
|
||||
('own_address_type', OwnAddressType.type_spec()),
|
||||
('connection_interval_min', 2),
|
||||
('connection_interval_max', 2),
|
||||
('max_latency', 2),
|
||||
@@ -3913,7 +3955,7 @@ class HCI_LE_Set_Advertising_Set_Random_Address_Command(HCI_Command):
|
||||
),
|
||||
},
|
||||
),
|
||||
('own_address_type', OwnAddressType.TYPE_SPEC),
|
||||
('own_address_type', OwnAddressType.type_spec()),
|
||||
('peer_address_type', Address.ADDRESS_TYPE_SPEC),
|
||||
('peer_address', Address.parse_address_preceded_by_type),
|
||||
('advertising_filter_policy', 1),
|
||||
@@ -4309,7 +4351,7 @@ class HCI_LE_Extended_Create_Connection_Command(HCI_Command):
|
||||
('initiator_filter_policy:', self.initiator_filter_policy),
|
||||
(
|
||||
'own_address_type: ',
|
||||
OwnAddressType.type_name(self.own_address_type),
|
||||
OwnAddressType(self.own_address_type).name,
|
||||
),
|
||||
(
|
||||
'peer_address_type: ',
|
||||
@@ -4551,6 +4593,10 @@ class HCI_LE_Setup_ISO_Data_Path_Command(HCI_Command):
|
||||
See Bluetooth spec @ 7.8.109 LE Setup ISO Data Path command
|
||||
'''
|
||||
|
||||
class Direction(enum.IntEnum):
|
||||
HOST_TO_CONTROLLER = 0x00
|
||||
CONTROLLER_TO_HOST = 0x01
|
||||
|
||||
connection_handle: int
|
||||
data_path_direction: int
|
||||
data_path_id: int
|
||||
@@ -5190,6 +5236,21 @@ HCI_LE_Meta_Event.subevent_classes[
|
||||
] = HCI_LE_Extended_Advertising_Report_Event
|
||||
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
@HCI_LE_Meta_Event.event(
|
||||
[
|
||||
('status', 1),
|
||||
('advertising_handle', 1),
|
||||
('connection_handle', 2),
|
||||
('number_completed_extended_advertising_events', 1),
|
||||
]
|
||||
)
|
||||
class HCI_LE_Advertising_Set_Terminated_Event(HCI_LE_Meta_Event):
|
||||
'''
|
||||
See Bluetooth spec @ 7.7.65.18 LE Advertising Set Terminated Event
|
||||
'''
|
||||
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
@HCI_LE_Meta_Event.event([('connection_handle', 2), ('channel_selection_algorithm', 1)])
|
||||
class HCI_LE_Channel_Selection_Algorithm_Event(HCI_LE_Meta_Event):
|
||||
|
||||
@@ -37,6 +37,7 @@ from bumble.l2cap import (
|
||||
L2CAP_Connection_Response,
|
||||
)
|
||||
from bumble.hci import (
|
||||
Address,
|
||||
HCI_EVENT_PACKET,
|
||||
HCI_ACL_DATA_PACKET,
|
||||
HCI_DISCONNECTION_COMPLETE_EVENT,
|
||||
@@ -48,6 +49,7 @@ from bumble.hci import (
|
||||
)
|
||||
from bumble.rfcomm import RFCOMM_Frame, RFCOMM_PSM
|
||||
from bumble.sdp import SDP_PDU, SDP_PSM
|
||||
from bumble import crypto
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
# Logging
|
||||
@@ -232,3 +234,15 @@ class PacketTracer:
|
||||
)
|
||||
self.host_to_controller_analyzer.peer = self.controller_to_host_analyzer
|
||||
self.controller_to_host_analyzer.peer = self.host_to_controller_analyzer
|
||||
|
||||
|
||||
def generate_irk() -> bytes:
|
||||
return crypto.r()
|
||||
|
||||
|
||||
def verify_rpa_with_irk(rpa: Address, irk: bytes) -> bool:
|
||||
rpa_bytes = bytes(rpa)
|
||||
prand_given = rpa_bytes[3:]
|
||||
hash_given = rpa_bytes[:3]
|
||||
hash_local = crypto.ah(irk, prand_given)
|
||||
return hash_local[:3] == hash_given
|
||||
|
||||
+314
-93
@@ -19,16 +19,17 @@ from __future__ import annotations
|
||||
from dataclasses import dataclass
|
||||
import logging
|
||||
import enum
|
||||
import struct
|
||||
|
||||
from abc import ABC, abstractmethod
|
||||
from pyee import EventEmitter
|
||||
from typing import Optional, TYPE_CHECKING
|
||||
from typing import Optional, Callable, TYPE_CHECKING
|
||||
from typing_extensions import override
|
||||
|
||||
from bumble import l2cap
|
||||
from bumble import l2cap, device
|
||||
from bumble.colors import color
|
||||
from bumble.core import InvalidStateError, ProtocolError
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from bumble.device import Device, Connection
|
||||
from .hci import Address
|
||||
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
@@ -60,6 +61,7 @@ class Message:
|
||||
NOT_READY = 0x01
|
||||
ERR_INVALID_REPORT_ID = 0x02
|
||||
ERR_UNSUPPORTED_REQUEST = 0x03
|
||||
ERR_INVALID_PARAMETER = 0x04
|
||||
ERR_UNKNOWN = 0x0E
|
||||
ERR_FATAL = 0x0F
|
||||
|
||||
@@ -101,13 +103,14 @@ class GetReportMessage(Message):
|
||||
def __bytes__(self) -> bytes:
|
||||
packet_bytes = bytearray()
|
||||
packet_bytes.append(self.report_id)
|
||||
packet_bytes.extend(
|
||||
[(self.buffer_size & 0xFF), ((self.buffer_size >> 8) & 0xFF)]
|
||||
)
|
||||
if self.report_type == Message.ReportType.OTHER_REPORT:
|
||||
if self.buffer_size == 0:
|
||||
return self.header(self.report_type) + packet_bytes
|
||||
else:
|
||||
return self.header(0x08 | self.report_type) + packet_bytes
|
||||
return (
|
||||
self.header(0x08 | self.report_type)
|
||||
+ packet_bytes
|
||||
+ struct.pack("<H", self.buffer_size)
|
||||
)
|
||||
|
||||
|
||||
@dataclass
|
||||
@@ -120,6 +123,16 @@ class SetReportMessage(Message):
|
||||
return self.header(self.report_type) + self.data
|
||||
|
||||
|
||||
@dataclass
|
||||
class SendControlData(Message):
|
||||
report_type: int
|
||||
data: bytes
|
||||
message_type = Message.MessageType.DATA
|
||||
|
||||
def __bytes__(self) -> bytes:
|
||||
return self.header(self.report_type) + self.data
|
||||
|
||||
|
||||
@dataclass
|
||||
class GetProtocolMessage(Message):
|
||||
message_type = Message.MessageType.GET_PROTOCOL
|
||||
@@ -161,31 +174,47 @@ class VirtualCableUnplug(Message):
|
||||
return self.header(Message.ControlCommand.VIRTUAL_CABLE_UNPLUG)
|
||||
|
||||
|
||||
# Device sends input report, host sends output report.
|
||||
@dataclass
|
||||
class SendData(Message):
|
||||
data: bytes
|
||||
report_type: int
|
||||
message_type = Message.MessageType.DATA
|
||||
|
||||
def __bytes__(self) -> bytes:
|
||||
return self.header(Message.ReportType.OUTPUT_REPORT) + self.data
|
||||
return self.header(self.report_type) + self.data
|
||||
|
||||
|
||||
@dataclass
|
||||
class SendHandshakeMessage(Message):
|
||||
result_code: int
|
||||
message_type = Message.MessageType.HANDSHAKE
|
||||
|
||||
def __bytes__(self) -> bytes:
|
||||
return self.header(self.result_code)
|
||||
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
class Host(EventEmitter):
|
||||
l2cap_ctrl_channel: Optional[l2cap.ClassicChannel]
|
||||
l2cap_intr_channel: Optional[l2cap.ClassicChannel]
|
||||
class HID(ABC, EventEmitter):
|
||||
l2cap_ctrl_channel: Optional[l2cap.ClassicChannel] = None
|
||||
l2cap_intr_channel: Optional[l2cap.ClassicChannel] = None
|
||||
connection: Optional[device.Connection] = None
|
||||
|
||||
def __init__(self, device: Device, connection: Connection) -> None:
|
||||
class Role(enum.IntEnum):
|
||||
HOST = 0x00
|
||||
DEVICE = 0x01
|
||||
|
||||
def __init__(self, device: device.Device, role: Role) -> None:
|
||||
super().__init__()
|
||||
self.remote_device_bd_address: Optional[Address] = None
|
||||
self.device = device
|
||||
self.connection = connection
|
||||
|
||||
self.l2cap_ctrl_channel = None
|
||||
self.l2cap_intr_channel = None
|
||||
self.role = role
|
||||
|
||||
# Register ourselves with the L2CAP channel manager
|
||||
device.register_l2cap_server(HID_CONTROL_PSM, self.on_connection)
|
||||
device.register_l2cap_server(HID_INTERRUPT_PSM, self.on_connection)
|
||||
device.register_l2cap_server(HID_CONTROL_PSM, self.on_l2cap_connection)
|
||||
device.register_l2cap_server(HID_INTERRUPT_PSM, self.on_l2cap_connection)
|
||||
|
||||
device.on('connection', self.on_device_connection)
|
||||
|
||||
async def connect_control_channel(self) -> None:
|
||||
# Create a new L2CAP connection - control channel
|
||||
@@ -229,9 +258,18 @@ class Host(EventEmitter):
|
||||
self.l2cap_ctrl_channel = None
|
||||
await channel.disconnect()
|
||||
|
||||
def on_connection(self, l2cap_channel: l2cap.ClassicChannel) -> None:
|
||||
def on_device_connection(self, connection: device.Connection) -> None:
|
||||
self.connection = connection
|
||||
self.remote_device_bd_address = connection.peer_address
|
||||
connection.on('disconnection', self.on_device_disconnection)
|
||||
|
||||
def on_device_disconnection(self, reason: int) -> None:
|
||||
self.connection = None
|
||||
|
||||
def on_l2cap_connection(self, l2cap_channel: l2cap.ClassicChannel) -> None:
|
||||
logger.debug(f'+++ New L2CAP connection: {l2cap_channel}')
|
||||
l2cap_channel.on('open', lambda: self.on_l2cap_channel_open(l2cap_channel))
|
||||
l2cap_channel.on('close', lambda: self.on_l2cap_channel_close(l2cap_channel))
|
||||
|
||||
def on_l2cap_channel_open(self, l2cap_channel: l2cap.ClassicChannel) -> None:
|
||||
if l2cap_channel.psm == HID_CONTROL_PSM:
|
||||
@@ -242,63 +280,20 @@ class Host(EventEmitter):
|
||||
self.l2cap_intr_channel.sink = self.on_intr_pdu
|
||||
logger.debug(f'$$$ L2CAP channel open: {l2cap_channel}')
|
||||
|
||||
def on_ctrl_pdu(self, pdu: bytes) -> None:
|
||||
logger.debug(f'<<< HID CONTROL PDU: {pdu.hex()}')
|
||||
# Here we will receive all kinds of packets, parse and then call respective callbacks
|
||||
message_type = pdu[0] >> 4
|
||||
param = pdu[0] & 0x0F
|
||||
|
||||
if message_type == Message.MessageType.HANDSHAKE:
|
||||
logger.debug(f'<<< HID HANDSHAKE: {Message.Handshake(param).name}')
|
||||
self.emit('handshake', Message.Handshake(param))
|
||||
elif message_type == Message.MessageType.DATA:
|
||||
logger.debug('<<< HID CONTROL DATA')
|
||||
self.emit('data', pdu)
|
||||
elif message_type == Message.MessageType.CONTROL:
|
||||
if param == Message.ControlCommand.SUSPEND:
|
||||
logger.debug('<<< HID SUSPEND')
|
||||
self.emit('suspend', pdu)
|
||||
elif param == Message.ControlCommand.EXIT_SUSPEND:
|
||||
logger.debug('<<< HID EXIT SUSPEND')
|
||||
self.emit('exit_suspend', pdu)
|
||||
elif param == Message.ControlCommand.VIRTUAL_CABLE_UNPLUG:
|
||||
logger.debug('<<< HID VIRTUAL CABLE UNPLUG')
|
||||
self.emit('virtual_cable_unplug')
|
||||
else:
|
||||
logger.debug('<<< HID CONTROL OPERATION UNSUPPORTED')
|
||||
def on_l2cap_channel_close(self, l2cap_channel: l2cap.ClassicChannel) -> None:
|
||||
if l2cap_channel.psm == HID_CONTROL_PSM:
|
||||
self.l2cap_ctrl_channel = None
|
||||
else:
|
||||
logger.debug('<<< HID CONTROL DATA')
|
||||
self.emit('data', pdu)
|
||||
self.l2cap_intr_channel = None
|
||||
logger.debug(f'$$$ L2CAP channel close: {l2cap_channel}')
|
||||
|
||||
@abstractmethod
|
||||
def on_ctrl_pdu(self, pdu: bytes) -> None:
|
||||
pass
|
||||
|
||||
def on_intr_pdu(self, pdu: bytes) -> None:
|
||||
logger.debug(f'<<< HID INTERRUPT PDU: {pdu.hex()}')
|
||||
self.emit("data", pdu)
|
||||
|
||||
def get_report(self, report_type: int, report_id: int, buffer_size: int) -> None:
|
||||
msg = GetReportMessage(
|
||||
report_type=report_type, report_id=report_id, buffer_size=buffer_size
|
||||
)
|
||||
hid_message = bytes(msg)
|
||||
logger.debug(f'>>> HID CONTROL GET REPORT, PDU: {hid_message.hex()}')
|
||||
self.send_pdu_on_ctrl(hid_message)
|
||||
|
||||
def set_report(self, report_type: int, data: bytes):
|
||||
msg = SetReportMessage(report_type=report_type, data=data)
|
||||
hid_message = bytes(msg)
|
||||
logger.debug(f'>>> HID CONTROL SET REPORT, PDU:{hid_message.hex()}')
|
||||
self.send_pdu_on_ctrl(hid_message)
|
||||
|
||||
def get_protocol(self):
|
||||
msg = GetProtocolMessage()
|
||||
hid_message = bytes(msg)
|
||||
logger.debug(f'>>> HID CONTROL GET PROTOCOL, PDU: {hid_message.hex()}')
|
||||
self.send_pdu_on_ctrl(hid_message)
|
||||
|
||||
def set_protocol(self, protocol_mode: int):
|
||||
msg = SetProtocolMessage(protocol_mode=protocol_mode)
|
||||
hid_message = bytes(msg)
|
||||
logger.debug(f'>>> HID CONTROL SET PROTOCOL, PDU: {hid_message.hex()}')
|
||||
self.send_pdu_on_ctrl(hid_message)
|
||||
self.emit("interrupt_data", pdu)
|
||||
|
||||
def send_pdu_on_ctrl(self, msg: bytes) -> None:
|
||||
assert self.l2cap_ctrl_channel
|
||||
@@ -308,26 +303,252 @@ class Host(EventEmitter):
|
||||
assert self.l2cap_intr_channel
|
||||
self.l2cap_intr_channel.send_pdu(msg)
|
||||
|
||||
def send_data(self, data):
|
||||
msg = SendData(data)
|
||||
def send_data(self, data: bytes) -> None:
|
||||
if self.role == HID.Role.HOST:
|
||||
report_type = Message.ReportType.OUTPUT_REPORT
|
||||
else:
|
||||
report_type = Message.ReportType.INPUT_REPORT
|
||||
msg = SendData(data, report_type)
|
||||
hid_message = bytes(msg)
|
||||
logger.debug(f'>>> HID INTERRUPT SEND DATA, PDU: {hid_message.hex()}')
|
||||
self.send_pdu_on_intr(hid_message)
|
||||
if self.l2cap_intr_channel is not None:
|
||||
logger.debug(f'>>> HID INTERRUPT SEND DATA, PDU: {hid_message.hex()}')
|
||||
self.send_pdu_on_intr(hid_message)
|
||||
|
||||
def suspend(self):
|
||||
msg = Suspend()
|
||||
hid_message = bytes(msg)
|
||||
logger.debug(f'>>> HID CONTROL SUSPEND, PDU:{hid_message.hex()}')
|
||||
self.send_pdu_on_ctrl(msg)
|
||||
|
||||
def exit_suspend(self):
|
||||
msg = ExitSuspend()
|
||||
hid_message = bytes(msg)
|
||||
logger.debug(f'>>> HID CONTROL EXIT SUSPEND, PDU:{hid_message.hex()}')
|
||||
self.send_pdu_on_ctrl(msg)
|
||||
|
||||
def virtual_cable_unplug(self):
|
||||
def virtual_cable_unplug(self) -> None:
|
||||
msg = VirtualCableUnplug()
|
||||
hid_message = bytes(msg)
|
||||
logger.debug(f'>>> HID CONTROL VIRTUAL CABLE UNPLUG, PDU: {hid_message.hex()}')
|
||||
self.send_pdu_on_ctrl(msg)
|
||||
self.send_pdu_on_ctrl(hid_message)
|
||||
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
|
||||
|
||||
class Device(HID):
|
||||
class GetSetReturn(enum.IntEnum):
|
||||
FAILURE = 0x00
|
||||
REPORT_ID_NOT_FOUND = 0x01
|
||||
ERR_UNSUPPORTED_REQUEST = 0x02
|
||||
ERR_UNKNOWN = 0x03
|
||||
ERR_INVALID_PARAMETER = 0x04
|
||||
SUCCESS = 0xFF
|
||||
|
||||
class GetSetStatus:
|
||||
def __init__(self) -> None:
|
||||
self.data = bytearray()
|
||||
self.status = 0
|
||||
|
||||
def __init__(self, device: device.Device) -> None:
|
||||
super().__init__(device, HID.Role.DEVICE)
|
||||
get_report_cb: Optional[Callable[[int, int, int], None]] = None
|
||||
set_report_cb: Optional[Callable[[int, int, int, bytes], None]] = None
|
||||
get_protocol_cb: Optional[Callable[[], None]] = None
|
||||
set_protocol_cb: Optional[Callable[[int], None]] = None
|
||||
|
||||
@override
|
||||
def on_ctrl_pdu(self, pdu: bytes) -> None:
|
||||
logger.debug(f'<<< HID CONTROL PDU: {pdu.hex()}')
|
||||
param = pdu[0] & 0x0F
|
||||
message_type = pdu[0] >> 4
|
||||
|
||||
if message_type == Message.MessageType.GET_REPORT:
|
||||
logger.debug('<<< HID GET REPORT')
|
||||
self.handle_get_report(pdu)
|
||||
elif message_type == Message.MessageType.SET_REPORT:
|
||||
logger.debug('<<< HID SET REPORT')
|
||||
self.handle_set_report(pdu)
|
||||
elif message_type == Message.MessageType.GET_PROTOCOL:
|
||||
logger.debug('<<< HID GET PROTOCOL')
|
||||
self.handle_get_protocol(pdu)
|
||||
elif message_type == Message.MessageType.SET_PROTOCOL:
|
||||
logger.debug('<<< HID SET PROTOCOL')
|
||||
self.handle_set_protocol(pdu)
|
||||
elif message_type == Message.MessageType.DATA:
|
||||
logger.debug('<<< HID CONTROL DATA')
|
||||
self.emit('control_data', pdu)
|
||||
elif message_type == Message.MessageType.CONTROL:
|
||||
if param == Message.ControlCommand.SUSPEND:
|
||||
logger.debug('<<< HID SUSPEND')
|
||||
self.emit('suspend')
|
||||
elif param == Message.ControlCommand.EXIT_SUSPEND:
|
||||
logger.debug('<<< HID EXIT SUSPEND')
|
||||
self.emit('exit_suspend')
|
||||
elif param == Message.ControlCommand.VIRTUAL_CABLE_UNPLUG:
|
||||
logger.debug('<<< HID VIRTUAL CABLE UNPLUG')
|
||||
self.emit('virtual_cable_unplug')
|
||||
else:
|
||||
logger.debug('<<< HID CONTROL OPERATION UNSUPPORTED')
|
||||
else:
|
||||
logger.debug('<<< HID MESSAGE TYPE UNSUPPORTED')
|
||||
self.send_handshake_message(Message.Handshake.ERR_UNSUPPORTED_REQUEST)
|
||||
|
||||
def send_handshake_message(self, result_code: int) -> None:
|
||||
msg = SendHandshakeMessage(result_code)
|
||||
hid_message = bytes(msg)
|
||||
logger.debug(f'>>> HID HANDSHAKE MESSAGE, PDU: {hid_message.hex()}')
|
||||
self.send_pdu_on_ctrl(hid_message)
|
||||
|
||||
def send_control_data(self, report_type: int, data: bytes):
|
||||
msg = SendControlData(report_type=report_type, data=data)
|
||||
hid_message = bytes(msg)
|
||||
logger.debug(f'>>> HID CONTROL DATA: {hid_message.hex()}')
|
||||
self.send_pdu_on_ctrl(hid_message)
|
||||
|
||||
def handle_get_report(self, pdu: bytes):
|
||||
if self.get_report_cb is None:
|
||||
logger.debug("GetReport callback not registered !!")
|
||||
self.send_handshake_message(Message.Handshake.ERR_UNSUPPORTED_REQUEST)
|
||||
return
|
||||
report_type = pdu[0] & 0x03
|
||||
buffer_flag = (pdu[0] & 0x08) >> 3
|
||||
report_id = pdu[1]
|
||||
logger.debug(f"buffer_flag: {buffer_flag}")
|
||||
if buffer_flag == 1:
|
||||
buffer_size = (pdu[3] << 8) | pdu[2]
|
||||
else:
|
||||
buffer_size = 0
|
||||
|
||||
ret = self.get_report_cb(report_id, report_type, buffer_size)
|
||||
assert ret is not None
|
||||
if ret.status == self.GetSetReturn.FAILURE:
|
||||
self.send_handshake_message(Message.Handshake.ERR_UNKNOWN)
|
||||
elif ret.status == self.GetSetReturn.SUCCESS:
|
||||
data = bytearray()
|
||||
data.append(report_id)
|
||||
data.extend(ret.data)
|
||||
if len(data) < self.l2cap_ctrl_channel.mtu: # type: ignore[union-attr]
|
||||
self.send_control_data(report_type=report_type, data=data)
|
||||
else:
|
||||
self.send_handshake_message(Message.Handshake.ERR_INVALID_PARAMETER)
|
||||
elif ret.status == self.GetSetReturn.REPORT_ID_NOT_FOUND:
|
||||
self.send_handshake_message(Message.Handshake.ERR_INVALID_REPORT_ID)
|
||||
elif ret.status == self.GetSetReturn.ERR_INVALID_PARAMETER:
|
||||
self.send_handshake_message(Message.Handshake.ERR_INVALID_PARAMETER)
|
||||
elif ret.status == self.GetSetReturn.ERR_UNSUPPORTED_REQUEST:
|
||||
self.send_handshake_message(Message.Handshake.ERR_UNSUPPORTED_REQUEST)
|
||||
|
||||
def register_get_report_cb(self, cb: Callable[[int, int, int], None]) -> None:
|
||||
self.get_report_cb = cb
|
||||
logger.debug("GetReport callback registered successfully")
|
||||
|
||||
def handle_set_report(self, pdu: bytes):
|
||||
if self.set_report_cb is None:
|
||||
logger.debug("SetReport callback not registered !!")
|
||||
self.send_handshake_message(Message.Handshake.ERR_UNSUPPORTED_REQUEST)
|
||||
return
|
||||
report_type = pdu[0] & 0x03
|
||||
report_id = pdu[1]
|
||||
report_data = pdu[2:]
|
||||
report_size = len(report_data) + 1
|
||||
ret = self.set_report_cb(report_id, report_type, report_size, report_data)
|
||||
assert ret is not None
|
||||
if ret.status == self.GetSetReturn.SUCCESS:
|
||||
self.send_handshake_message(Message.Handshake.SUCCESSFUL)
|
||||
elif ret.status == self.GetSetReturn.ERR_INVALID_PARAMETER:
|
||||
self.send_handshake_message(Message.Handshake.ERR_INVALID_PARAMETER)
|
||||
elif ret.status == self.GetSetReturn.REPORT_ID_NOT_FOUND:
|
||||
self.send_handshake_message(Message.Handshake.ERR_INVALID_REPORT_ID)
|
||||
else:
|
||||
self.send_handshake_message(Message.Handshake.ERR_UNSUPPORTED_REQUEST)
|
||||
|
||||
def register_set_report_cb(
|
||||
self, cb: Callable[[int, int, int, bytes], None]
|
||||
) -> None:
|
||||
self.set_report_cb = cb
|
||||
logger.debug("SetReport callback registered successfully")
|
||||
|
||||
def handle_get_protocol(self, pdu: bytes):
|
||||
if self.get_protocol_cb is None:
|
||||
logger.debug("GetProtocol callback not registered !!")
|
||||
self.send_handshake_message(Message.Handshake.ERR_UNSUPPORTED_REQUEST)
|
||||
return
|
||||
ret = self.get_protocol_cb()
|
||||
assert ret is not None
|
||||
if ret.status == self.GetSetReturn.SUCCESS:
|
||||
self.send_control_data(Message.ReportType.OTHER_REPORT, ret.data)
|
||||
else:
|
||||
self.send_handshake_message(Message.Handshake.ERR_UNSUPPORTED_REQUEST)
|
||||
|
||||
def register_get_protocol_cb(self, cb: Callable[[], None]) -> None:
|
||||
self.get_protocol_cb = cb
|
||||
logger.debug("GetProtocol callback registered successfully")
|
||||
|
||||
def handle_set_protocol(self, pdu: bytes):
|
||||
if self.set_protocol_cb is None:
|
||||
logger.debug("SetProtocol callback not registered !!")
|
||||
self.send_handshake_message(Message.Handshake.ERR_UNSUPPORTED_REQUEST)
|
||||
return
|
||||
ret = self.set_protocol_cb(pdu[0] & 0x01)
|
||||
assert ret is not None
|
||||
if ret.status == self.GetSetReturn.SUCCESS:
|
||||
self.send_handshake_message(Message.Handshake.SUCCESSFUL)
|
||||
else:
|
||||
self.send_handshake_message(Message.Handshake.ERR_UNSUPPORTED_REQUEST)
|
||||
|
||||
def register_set_protocol_cb(self, cb: Callable[[int], None]) -> None:
|
||||
self.set_protocol_cb = cb
|
||||
logger.debug("SetProtocol callback registered successfully")
|
||||
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
class Host(HID):
|
||||
def __init__(self, device: device.Device) -> None:
|
||||
super().__init__(device, HID.Role.HOST)
|
||||
|
||||
def get_report(self, report_type: int, report_id: int, buffer_size: int) -> None:
|
||||
msg = GetReportMessage(
|
||||
report_type=report_type, report_id=report_id, buffer_size=buffer_size
|
||||
)
|
||||
hid_message = bytes(msg)
|
||||
logger.debug(f'>>> HID CONTROL GET REPORT, PDU: {hid_message.hex()}')
|
||||
self.send_pdu_on_ctrl(hid_message)
|
||||
|
||||
def set_report(self, report_type: int, data: bytes) -> None:
|
||||
msg = SetReportMessage(report_type=report_type, data=data)
|
||||
hid_message = bytes(msg)
|
||||
logger.debug(f'>>> HID CONTROL SET REPORT, PDU:{hid_message.hex()}')
|
||||
self.send_pdu_on_ctrl(hid_message)
|
||||
|
||||
def get_protocol(self) -> None:
|
||||
msg = GetProtocolMessage()
|
||||
hid_message = bytes(msg)
|
||||
logger.debug(f'>>> HID CONTROL GET PROTOCOL, PDU: {hid_message.hex()}')
|
||||
self.send_pdu_on_ctrl(hid_message)
|
||||
|
||||
def set_protocol(self, protocol_mode: int) -> None:
|
||||
msg = SetProtocolMessage(protocol_mode=protocol_mode)
|
||||
hid_message = bytes(msg)
|
||||
logger.debug(f'>>> HID CONTROL SET PROTOCOL, PDU: {hid_message.hex()}')
|
||||
self.send_pdu_on_ctrl(hid_message)
|
||||
|
||||
def suspend(self) -> None:
|
||||
msg = Suspend()
|
||||
hid_message = bytes(msg)
|
||||
logger.debug(f'>>> HID CONTROL SUSPEND, PDU:{hid_message.hex()}')
|
||||
self.send_pdu_on_ctrl(hid_message)
|
||||
|
||||
def exit_suspend(self) -> None:
|
||||
msg = ExitSuspend()
|
||||
hid_message = bytes(msg)
|
||||
logger.debug(f'>>> HID CONTROL EXIT SUSPEND, PDU:{hid_message.hex()}')
|
||||
self.send_pdu_on_ctrl(hid_message)
|
||||
|
||||
@override
|
||||
def on_ctrl_pdu(self, pdu: bytes) -> None:
|
||||
logger.debug(f'<<< HID CONTROL PDU: {pdu.hex()}')
|
||||
param = pdu[0] & 0x0F
|
||||
message_type = pdu[0] >> 4
|
||||
if message_type == Message.MessageType.HANDSHAKE:
|
||||
logger.debug(f'<<< HID HANDSHAKE: {Message.Handshake(param).name}')
|
||||
self.emit('handshake', Message.Handshake(param))
|
||||
elif message_type == Message.MessageType.DATA:
|
||||
logger.debug('<<< HID CONTROL DATA')
|
||||
self.emit('control_data', pdu)
|
||||
elif message_type == Message.MessageType.CONTROL:
|
||||
if param == Message.ControlCommand.VIRTUAL_CABLE_UNPLUG:
|
||||
logger.debug('<<< HID VIRTUAL CABLE UNPLUG')
|
||||
self.emit('virtual_cable_unplug')
|
||||
else:
|
||||
logger.debug('<<< HID CONTROL OPERATION UNSUPPORTED')
|
||||
else:
|
||||
logger.debug('<<< HID MESSAGE TYPE UNSUPPORTED')
|
||||
|
||||
+131
-90
@@ -21,7 +21,7 @@ import collections
|
||||
import logging
|
||||
import struct
|
||||
|
||||
from typing import Optional, TYPE_CHECKING, Dict, Callable, Awaitable, cast
|
||||
from typing import Any, Awaitable, Callable, Deque, Dict, Optional, cast, TYPE_CHECKING
|
||||
|
||||
from bumble.colors import color
|
||||
from bumble.l2cap import L2CAP_PDU
|
||||
@@ -91,16 +91,49 @@ logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
# Constants
|
||||
# -----------------------------------------------------------------------------
|
||||
# fmt: off
|
||||
class AclPacketQueue:
|
||||
max_packet_size: int
|
||||
|
||||
HOST_DEFAULT_HC_LE_ACL_DATA_PACKET_LENGTH = 27
|
||||
HOST_HC_TOTAL_NUM_LE_ACL_DATA_PACKETS = 1
|
||||
HOST_DEFAULT_HC_ACL_DATA_PACKET_LENGTH = 27
|
||||
HOST_HC_TOTAL_NUM_ACL_DATA_PACKETS = 1
|
||||
def __init__(
|
||||
self,
|
||||
max_packet_size: int,
|
||||
max_in_flight: int,
|
||||
send: Callable[[HCI_Packet], None],
|
||||
) -> None:
|
||||
self.max_packet_size = max_packet_size
|
||||
self.max_in_flight = max_in_flight
|
||||
self.in_flight = 0
|
||||
self.send = send
|
||||
self.packets: Deque[HCI_AclDataPacket] = collections.deque()
|
||||
|
||||
# fmt: on
|
||||
def enqueue(self, packet: HCI_AclDataPacket) -> None:
|
||||
self.packets.appendleft(packet)
|
||||
self.check_queue()
|
||||
|
||||
if self.packets:
|
||||
logger.debug(
|
||||
f'{self.in_flight} ACL packets in flight, '
|
||||
f'{len(self.packets)} in queue'
|
||||
)
|
||||
|
||||
def check_queue(self) -> None:
|
||||
while self.packets and self.in_flight < self.max_in_flight:
|
||||
packet = self.packets.pop()
|
||||
self.send(packet)
|
||||
self.in_flight += 1
|
||||
|
||||
def on_packets_completed(self, packet_count: int) -> None:
|
||||
if packet_count > self.in_flight:
|
||||
logger.warning(
|
||||
color(
|
||||
'!!! {packet_count} completed but only '
|
||||
f'{self.in_flight} in flight'
|
||||
)
|
||||
)
|
||||
packet_count = self.in_flight
|
||||
|
||||
self.in_flight -= packet_count
|
||||
self.check_queue()
|
||||
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
@@ -111,6 +144,13 @@ class Connection:
|
||||
self.peer_address = peer_address
|
||||
self.assembler = HCI_AclDataPacketAssembler(self.on_acl_pdu)
|
||||
self.transport = transport
|
||||
acl_packet_queue: Optional[AclPacketQueue] = (
|
||||
host.le_acl_packet_queue
|
||||
if transport == BT_LE_TRANSPORT
|
||||
else host.acl_packet_queue
|
||||
)
|
||||
assert acl_packet_queue
|
||||
self.acl_packet_queue = acl_packet_queue
|
||||
|
||||
def on_hci_acl_data_packet(self, packet: HCI_AclDataPacket) -> None:
|
||||
self.assembler.feed_packet(packet)
|
||||
@@ -123,8 +163,10 @@ class Connection:
|
||||
# -----------------------------------------------------------------------------
|
||||
class Host(AbortableEventEmitter):
|
||||
connections: Dict[int, Connection]
|
||||
acl_packet_queue: collections.deque[HCI_AclDataPacket]
|
||||
hci_sink: TransportSink
|
||||
acl_packet_queue: Optional[AclPacketQueue] = None
|
||||
le_acl_packet_queue: Optional[AclPacketQueue] = None
|
||||
hci_sink: Optional[TransportSink] = None
|
||||
hci_metadata: Dict[str, Any]
|
||||
long_term_key_provider: Optional[
|
||||
Callable[[int, bytes, int], Awaitable[Optional[bytes]]]
|
||||
]
|
||||
@@ -137,18 +179,11 @@ class Host(AbortableEventEmitter):
|
||||
) -> None:
|
||||
super().__init__()
|
||||
|
||||
self.hci_metadata = None
|
||||
self.hci_metadata = {}
|
||||
self.ready = False # True when we can accept incoming packets
|
||||
self.reset_done = False
|
||||
self.connections = {} # Connections, by connection handle
|
||||
self.pending_command = None
|
||||
self.pending_response = None
|
||||
self.hc_le_acl_data_packet_length = HOST_DEFAULT_HC_LE_ACL_DATA_PACKET_LENGTH
|
||||
self.hc_total_num_le_acl_data_packets = HOST_HC_TOTAL_NUM_LE_ACL_DATA_PACKETS
|
||||
self.hc_acl_data_packet_length = HOST_DEFAULT_HC_ACL_DATA_PACKET_LENGTH
|
||||
self.hc_total_num_acl_data_packets = HOST_HC_TOTAL_NUM_ACL_DATA_PACKETS
|
||||
self.acl_packet_queue = collections.deque()
|
||||
self.acl_packets_in_flight = 0
|
||||
self.local_version = None
|
||||
self.local_supported_commands = bytes(64)
|
||||
self.local_le_features = 0
|
||||
@@ -162,10 +197,7 @@ class Host(AbortableEventEmitter):
|
||||
|
||||
# Connect to the source and sink if specified
|
||||
if controller_source:
|
||||
controller_source.set_packet_sink(self)
|
||||
self.hci_metadata = getattr(
|
||||
controller_source, 'metadata', self.hci_metadata
|
||||
)
|
||||
self.set_packet_source(controller_source)
|
||||
if controller_sink:
|
||||
self.set_packet_sink(controller_sink)
|
||||
|
||||
@@ -200,17 +232,21 @@ class Host(AbortableEventEmitter):
|
||||
self.ready = False
|
||||
await self.flush()
|
||||
|
||||
await self.send_command(HCI_Reset_Command(), check_result=True)
|
||||
self.ready = True
|
||||
|
||||
# Instantiate and init a driver for the host if needed.
|
||||
# NOTE: we don't keep a reference to the driver here, because we don't
|
||||
# currently have a need for the driver later on. But if the driver interface
|
||||
# evolves, it may be required, then, to store a reference to the driver in
|
||||
# an object property.
|
||||
reset_needed = True
|
||||
if driver_factory is not None:
|
||||
if driver := await driver_factory(self):
|
||||
await driver.init_controller()
|
||||
reset_needed = False
|
||||
|
||||
# Send a reset command unless a driver has already done so.
|
||||
if reset_needed:
|
||||
await self.send_command(HCI_Reset_Command(), check_result=True)
|
||||
self.ready = True
|
||||
|
||||
response = await self.send_command(
|
||||
HCI_Read_Local_Supported_Commands_Command(), check_result=True
|
||||
@@ -253,46 +289,54 @@ class Host(AbortableEventEmitter):
|
||||
response = await self.send_command(
|
||||
HCI_Read_Buffer_Size_Command(), check_result=True
|
||||
)
|
||||
self.hc_acl_data_packet_length = (
|
||||
hc_acl_data_packet_length = (
|
||||
response.return_parameters.hc_acl_data_packet_length
|
||||
)
|
||||
self.hc_total_num_acl_data_packets = (
|
||||
hc_total_num_acl_data_packets = (
|
||||
response.return_parameters.hc_total_num_acl_data_packets
|
||||
)
|
||||
|
||||
logger.debug(
|
||||
'HCI ACL flow control: '
|
||||
f'hc_acl_data_packet_length={self.hc_acl_data_packet_length},'
|
||||
f'hc_total_num_acl_data_packets={self.hc_total_num_acl_data_packets}'
|
||||
f'hc_acl_data_packet_length={hc_acl_data_packet_length},'
|
||||
f'hc_total_num_acl_data_packets={hc_total_num_acl_data_packets}'
|
||||
)
|
||||
|
||||
self.acl_packet_queue = AclPacketQueue(
|
||||
max_packet_size=hc_acl_data_packet_length,
|
||||
max_in_flight=hc_total_num_acl_data_packets,
|
||||
send=self.send_hci_packet,
|
||||
)
|
||||
|
||||
hc_le_acl_data_packet_length = 0
|
||||
hc_total_num_le_acl_data_packets = 0
|
||||
if self.supports_command(HCI_LE_READ_BUFFER_SIZE_COMMAND):
|
||||
response = await self.send_command(
|
||||
HCI_LE_Read_Buffer_Size_Command(), check_result=True
|
||||
)
|
||||
self.hc_le_acl_data_packet_length = (
|
||||
hc_le_acl_data_packet_length = (
|
||||
response.return_parameters.hc_le_acl_data_packet_length
|
||||
)
|
||||
self.hc_total_num_le_acl_data_packets = (
|
||||
hc_total_num_le_acl_data_packets = (
|
||||
response.return_parameters.hc_total_num_le_acl_data_packets
|
||||
)
|
||||
|
||||
logger.debug(
|
||||
'HCI LE ACL flow control: '
|
||||
f'hc_le_acl_data_packet_length={self.hc_le_acl_data_packet_length},'
|
||||
'hc_total_num_le_acl_data_packets='
|
||||
f'{self.hc_total_num_le_acl_data_packets}'
|
||||
f'hc_le_acl_data_packet_length={hc_le_acl_data_packet_length},'
|
||||
f'hc_total_num_le_acl_data_packets={hc_total_num_le_acl_data_packets}'
|
||||
)
|
||||
|
||||
if (
|
||||
response.return_parameters.hc_le_acl_data_packet_length == 0
|
||||
or response.return_parameters.hc_total_num_le_acl_data_packets == 0
|
||||
):
|
||||
# LE and Classic share the same values
|
||||
self.hc_le_acl_data_packet_length = self.hc_acl_data_packet_length
|
||||
self.hc_total_num_le_acl_data_packets = (
|
||||
self.hc_total_num_acl_data_packets
|
||||
)
|
||||
if hc_le_acl_data_packet_length == 0 or hc_total_num_le_acl_data_packets == 0:
|
||||
# LE and Classic share the same queue
|
||||
self.le_acl_packet_queue = self.acl_packet_queue
|
||||
else:
|
||||
# Create a separate queue for LE
|
||||
self.le_acl_packet_queue = AclPacketQueue(
|
||||
max_packet_size=hc_le_acl_data_packet_length,
|
||||
max_in_flight=hc_total_num_le_acl_data_packets,
|
||||
send=self.send_hci_packet,
|
||||
)
|
||||
|
||||
if self.supports_command(
|
||||
HCI_LE_READ_SUGGESTED_DEFAULT_DATA_LENGTH_COMMAND
|
||||
@@ -313,29 +357,31 @@ class Host(AbortableEventEmitter):
|
||||
)
|
||||
)
|
||||
|
||||
self.reset_done = True
|
||||
|
||||
@property
|
||||
def controller(self) -> TransportSink:
|
||||
def controller(self) -> Optional[TransportSink]:
|
||||
return self.hci_sink
|
||||
|
||||
@controller.setter
|
||||
def controller(self, controller):
|
||||
def controller(self, controller) -> None:
|
||||
self.set_packet_sink(controller)
|
||||
if controller:
|
||||
controller.set_packet_sink(self)
|
||||
|
||||
def set_packet_sink(self, sink: TransportSink) -> None:
|
||||
def set_packet_sink(self, sink: Optional[TransportSink]) -> None:
|
||||
self.hci_sink = sink
|
||||
|
||||
def set_packet_source(self, source: TransportSource) -> None:
|
||||
source.set_packet_sink(self)
|
||||
self.hci_metadata = getattr(source, 'metadata', self.hci_metadata)
|
||||
|
||||
def send_hci_packet(self, packet: HCI_Packet) -> None:
|
||||
logger.debug(f'{color("### HOST -> CONTROLLER", "blue")}: {packet}')
|
||||
if self.snooper:
|
||||
self.snooper.snoop(bytes(packet), Snooper.Direction.HOST_TO_CONTROLLER)
|
||||
self.hci_sink.on_packet(bytes(packet))
|
||||
if self.hci_sink:
|
||||
self.hci_sink.on_packet(bytes(packet))
|
||||
|
||||
async def send_command(self, command, check_result=False):
|
||||
logger.debug(f'{color("### HOST -> CONTROLLER", "blue")}: {command}')
|
||||
|
||||
# Wait until we can send (only one pending command at a time)
|
||||
async with self.command_semaphore:
|
||||
assert self.pending_command is None
|
||||
@@ -383,6 +429,17 @@ class Host(AbortableEventEmitter):
|
||||
asyncio.create_task(send_command(command))
|
||||
|
||||
def send_l2cap_pdu(self, connection_handle: int, cid: int, pdu: bytes) -> None:
|
||||
if not (connection := self.connections.get(connection_handle)):
|
||||
logger.warning(f'connection 0x{connection_handle:04X} not found')
|
||||
return
|
||||
packet_queue = connection.acl_packet_queue
|
||||
if packet_queue is None:
|
||||
logger.warning(
|
||||
f'no ACL packet queue for connection 0x{connection_handle:04X}'
|
||||
)
|
||||
return
|
||||
|
||||
# Create a PDU
|
||||
l2cap_pdu = bytes(L2CAP_PDU(cid, pdu))
|
||||
|
||||
# Send the data to the controller via ACL packets
|
||||
@@ -390,8 +447,7 @@ class Host(AbortableEventEmitter):
|
||||
offset = 0
|
||||
pb_flag = 0
|
||||
while bytes_remaining:
|
||||
# TODO: support different LE/Classic lengths
|
||||
data_total_length = min(bytes_remaining, self.hc_le_acl_data_packet_length)
|
||||
data_total_length = min(bytes_remaining, packet_queue.max_packet_size)
|
||||
acl_packet = HCI_AclDataPacket(
|
||||
connection_handle=connection_handle,
|
||||
pb_flag=pb_flag,
|
||||
@@ -399,34 +455,12 @@ class Host(AbortableEventEmitter):
|
||||
data_total_length=data_total_length,
|
||||
data=l2cap_pdu[offset : offset + data_total_length],
|
||||
)
|
||||
logger.debug(
|
||||
f'{color("### HOST -> CONTROLLER", "blue")}: (CID={cid}) {acl_packet}'
|
||||
)
|
||||
self.queue_acl_packet(acl_packet)
|
||||
logger.debug(f'>>> ACL packet enqueue: (CID={cid}) {acl_packet}')
|
||||
packet_queue.enqueue(acl_packet)
|
||||
pb_flag = 1
|
||||
offset += data_total_length
|
||||
bytes_remaining -= data_total_length
|
||||
|
||||
def queue_acl_packet(self, acl_packet: HCI_AclDataPacket) -> None:
|
||||
self.acl_packet_queue.appendleft(acl_packet)
|
||||
self.check_acl_packet_queue()
|
||||
|
||||
if len(self.acl_packet_queue):
|
||||
logger.debug(
|
||||
f'{self.acl_packets_in_flight} ACL packets in flight, '
|
||||
f'{len(self.acl_packet_queue)} in queue'
|
||||
)
|
||||
|
||||
def check_acl_packet_queue(self) -> None:
|
||||
# Send all we can (TODO: support different LE/Classic limits)
|
||||
while (
|
||||
len(self.acl_packet_queue) > 0
|
||||
and self.acl_packets_in_flight < self.hc_total_num_le_acl_data_packets
|
||||
):
|
||||
packet = self.acl_packet_queue.pop()
|
||||
self.send_hci_packet(packet)
|
||||
self.acl_packets_in_flight += 1
|
||||
|
||||
def supports_command(self, command):
|
||||
# Find the support flag position for this command
|
||||
for octet, flags in enumerate(HCI_SUPPORTED_COMMANDS_FLAGS):
|
||||
@@ -549,7 +583,7 @@ class Host(AbortableEventEmitter):
|
||||
# This is used just for the Num_HCI_Command_Packets field, not related to
|
||||
# an actual command
|
||||
logger.debug('no-command event')
|
||||
return None
|
||||
return
|
||||
|
||||
return self.on_command_processed(event)
|
||||
|
||||
@@ -557,18 +591,17 @@ class Host(AbortableEventEmitter):
|
||||
return self.on_command_processed(event)
|
||||
|
||||
def on_hci_number_of_completed_packets_event(self, event):
|
||||
total_packets = sum(event.num_completed_packets)
|
||||
if total_packets <= self.acl_packets_in_flight:
|
||||
self.acl_packets_in_flight -= total_packets
|
||||
self.check_acl_packet_queue()
|
||||
else:
|
||||
logger.warning(
|
||||
color(
|
||||
'!!! {total_packets} completed but only '
|
||||
f'{self.acl_packets_in_flight} in flight'
|
||||
for connection_handle, num_completed_packets in zip(
|
||||
event.connection_handles, event.num_completed_packets
|
||||
):
|
||||
if not (connection := self.connections.get(connection_handle)):
|
||||
logger.warning(
|
||||
'received packet completion event for unknown handle '
|
||||
f'0x{connection_handle:04X}'
|
||||
)
|
||||
)
|
||||
self.acl_packets_in_flight = 0
|
||||
continue
|
||||
|
||||
connection.acl_packet_queue.on_packets_completed(num_completed_packets)
|
||||
|
||||
# Classic only
|
||||
def on_hci_connection_request_event(self, event):
|
||||
@@ -721,6 +754,14 @@ class Host(AbortableEventEmitter):
|
||||
def on_hci_le_extended_advertising_report_event(self, event):
|
||||
self.on_hci_le_advertising_report_event(event)
|
||||
|
||||
def on_hci_le_advertising_set_terminated_event(self, event):
|
||||
self.emit(
|
||||
'advertising_set_termination',
|
||||
event.status,
|
||||
event.advertising_handle,
|
||||
event.connection_handle,
|
||||
)
|
||||
|
||||
def on_hci_le_cis_request_event(self, event):
|
||||
self.emit(
|
||||
'cis_request',
|
||||
|
||||
+10
-5
@@ -149,9 +149,10 @@ L2CAP_INVALID_CID_IN_REQUEST_REASON = 0x0002
|
||||
|
||||
L2CAP_LE_CREDIT_BASED_CONNECTION_MAX_CREDITS = 65535
|
||||
L2CAP_LE_CREDIT_BASED_CONNECTION_MIN_MTU = 23
|
||||
L2CAP_LE_CREDIT_BASED_CONNECTION_MAX_MTU = 65535
|
||||
L2CAP_LE_CREDIT_BASED_CONNECTION_MIN_MPS = 23
|
||||
L2CAP_LE_CREDIT_BASED_CONNECTION_MAX_MPS = 65533
|
||||
L2CAP_LE_CREDIT_BASED_CONNECTION_DEFAULT_MTU = 2046
|
||||
L2CAP_LE_CREDIT_BASED_CONNECTION_DEFAULT_MTU = 2048
|
||||
L2CAP_LE_CREDIT_BASED_CONNECTION_DEFAULT_MPS = 2048
|
||||
L2CAP_LE_CREDIT_BASED_CONNECTION_DEFAULT_INITIAL_CREDITS = 256
|
||||
|
||||
@@ -188,8 +189,11 @@ class LeCreditBasedChannelSpec:
|
||||
or self.max_credits > L2CAP_LE_CREDIT_BASED_CONNECTION_MAX_CREDITS
|
||||
):
|
||||
raise ValueError('max credits out of range')
|
||||
if self.mtu < L2CAP_LE_CREDIT_BASED_CONNECTION_MIN_MTU:
|
||||
raise ValueError('MTU too small')
|
||||
if (
|
||||
self.mtu < L2CAP_LE_CREDIT_BASED_CONNECTION_MIN_MTU
|
||||
or self.mtu > L2CAP_LE_CREDIT_BASED_CONNECTION_MAX_MTU
|
||||
):
|
||||
raise ValueError('MTU out of range')
|
||||
if (
|
||||
self.mps < L2CAP_LE_CREDIT_BASED_CONNECTION_MIN_MPS
|
||||
or self.mps > L2CAP_LE_CREDIT_BASED_CONNECTION_MAX_MPS
|
||||
@@ -1644,12 +1648,13 @@ class ChannelManager:
|
||||
|
||||
def send_pdu(self, connection, cid: int, pdu: Union[SupportsBytes, bytes]) -> None:
|
||||
pdu_str = pdu.hex() if isinstance(pdu, bytes) else str(pdu)
|
||||
pdu_bytes = bytes(pdu)
|
||||
logger.debug(
|
||||
f'{color(">>> Sending L2CAP PDU", "blue")} '
|
||||
f'on connection [0x{connection.handle:04X}] (CID={cid}) '
|
||||
f'{connection.peer_address}: {pdu_str}'
|
||||
f'{connection.peer_address}: {len(pdu_bytes)} bytes, {pdu_str}'
|
||||
)
|
||||
self.host.send_l2cap_pdu(connection.handle, cid, bytes(pdu))
|
||||
self.host.send_l2cap_pdu(connection.handle, cid, pdu_bytes)
|
||||
|
||||
def on_pdu(self, connection: Connection, cid: int, pdu: bytes) -> None:
|
||||
if cid in (L2CAP_SIGNALING_CID, L2CAP_LE_SIGNALING_CID):
|
||||
|
||||
+752
-1
@@ -23,13 +23,21 @@ import dataclasses
|
||||
import enum
|
||||
import struct
|
||||
import functools
|
||||
from typing import Optional, List, Union
|
||||
import logging
|
||||
from typing import Optional, List, Union, Type, Dict, Any, Tuple, cast
|
||||
|
||||
from bumble import colors
|
||||
from bumble import device
|
||||
from bumble import hci
|
||||
from bumble import gatt
|
||||
from bumble import gatt_client
|
||||
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
# Logging
|
||||
# -----------------------------------------------------------------------------
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
# Constants
|
||||
# -----------------------------------------------------------------------------
|
||||
@@ -220,6 +228,231 @@ class SupportedFrameDuration(enum.IntFlag):
|
||||
DURATION_10000_US_PREFERRED = 0b0010
|
||||
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
# ASE Operations
|
||||
# -----------------------------------------------------------------------------
|
||||
|
||||
|
||||
class ASE_Operation:
|
||||
'''
|
||||
See Audio Stream Control Service - 5 ASE Control operations.
|
||||
'''
|
||||
|
||||
classes: Dict[int, Type[ASE_Operation]] = {}
|
||||
op_code: int
|
||||
name: str
|
||||
fields: Optional[Sequence[Any]] = None
|
||||
ase_id: List[int]
|
||||
|
||||
class Opcode(enum.IntEnum):
|
||||
# fmt: off
|
||||
CONFIG_CODEC = 0x01
|
||||
CONFIG_QOS = 0x02
|
||||
ENABLE = 0x03
|
||||
RECEIVER_START_READY = 0x04
|
||||
DISABLE = 0x05
|
||||
RECEIVER_STOP_READY = 0x06
|
||||
UPDATE_METADATA = 0x07
|
||||
RELEASE = 0x08
|
||||
|
||||
@staticmethod
|
||||
def from_bytes(pdu: bytes) -> ASE_Operation:
|
||||
op_code = pdu[0]
|
||||
|
||||
cls = ASE_Operation.classes.get(op_code)
|
||||
if cls is None:
|
||||
instance = ASE_Operation(pdu)
|
||||
instance.name = ASE_Operation.Opcode(op_code).name
|
||||
instance.op_code = op_code
|
||||
return instance
|
||||
self = cls.__new__(cls)
|
||||
ASE_Operation.__init__(self, pdu)
|
||||
if self.fields is not None:
|
||||
self.init_from_bytes(pdu, 1)
|
||||
return self
|
||||
|
||||
@staticmethod
|
||||
def subclass(fields):
|
||||
def inner(cls: Type[ASE_Operation]):
|
||||
try:
|
||||
operation = ASE_Operation.Opcode[cls.__name__[4:].upper()]
|
||||
cls.name = operation.name
|
||||
cls.op_code = operation
|
||||
except:
|
||||
raise KeyError(f'PDU name {cls.name} not found in Ase_Operation.Opcode')
|
||||
cls.fields = fields
|
||||
|
||||
# Register a factory for this class
|
||||
ASE_Operation.classes[cls.op_code] = cls
|
||||
|
||||
return cls
|
||||
|
||||
return inner
|
||||
|
||||
def __init__(self, pdu: Optional[bytes] = None, **kwargs) -> None:
|
||||
if self.fields is not None and kwargs:
|
||||
hci.HCI_Object.init_from_fields(self, self.fields, kwargs)
|
||||
if pdu is None:
|
||||
pdu = bytes([self.op_code]) + hci.HCI_Object.dict_to_bytes(
|
||||
kwargs, self.fields
|
||||
)
|
||||
self.pdu = pdu
|
||||
|
||||
def init_from_bytes(self, pdu: bytes, offset: int):
|
||||
return hci.HCI_Object.init_from_bytes(self, pdu, offset, self.fields)
|
||||
|
||||
def __bytes__(self) -> bytes:
|
||||
return self.pdu
|
||||
|
||||
def __str__(self) -> str:
|
||||
result = f'{colors.color(self.name, "yellow")} '
|
||||
if fields := getattr(self, 'fields', None):
|
||||
result += ':\n' + hci.HCI_Object.format_fields(self.__dict__, fields, ' ')
|
||||
else:
|
||||
if len(self.pdu) > 1:
|
||||
result += f': {self.pdu.hex()}'
|
||||
return result
|
||||
|
||||
|
||||
@ASE_Operation.subclass(
|
||||
[
|
||||
[
|
||||
('ase_id', 1),
|
||||
('target_latency', 1),
|
||||
('target_phy', 1),
|
||||
('codec_id', hci.CodingFormat.parse_from_bytes),
|
||||
('codec_specific_configuration', 'v'),
|
||||
],
|
||||
]
|
||||
)
|
||||
class ASE_Config_Codec(ASE_Operation):
|
||||
'''
|
||||
See Audio Stream Control Service 5.1 - Config Codec Operation
|
||||
'''
|
||||
|
||||
target_latency: List[int]
|
||||
target_phy: List[int]
|
||||
codec_id: List[hci.CodingFormat]
|
||||
codec_specific_configuration: List[bytes]
|
||||
|
||||
|
||||
@ASE_Operation.subclass(
|
||||
[
|
||||
[
|
||||
('ase_id', 1),
|
||||
('cig_id', 1),
|
||||
('cis_id', 1),
|
||||
('sdu_interval', 3),
|
||||
('framing', 1),
|
||||
('phy', 1),
|
||||
('max_sdu', 2),
|
||||
('retransmission_number', 1),
|
||||
('max_transport_latency', 2),
|
||||
('presentation_delay', 3),
|
||||
],
|
||||
]
|
||||
)
|
||||
class ASE_Config_QOS(ASE_Operation):
|
||||
'''
|
||||
See Audio Stream Control Service 5.2 - Config Qos Operation
|
||||
'''
|
||||
|
||||
cig_id: List[int]
|
||||
cis_id: List[int]
|
||||
sdu_interval: List[int]
|
||||
framing: List[int]
|
||||
phy: List[int]
|
||||
max_sdu: List[int]
|
||||
retransmission_number: List[int]
|
||||
max_transport_latency: List[int]
|
||||
presentation_delay: List[int]
|
||||
|
||||
|
||||
@ASE_Operation.subclass([[('ase_id', 1), ('metadata', 'v')]])
|
||||
class ASE_Enable(ASE_Operation):
|
||||
'''
|
||||
See Audio Stream Control Service 5.3 - Enable Operation
|
||||
'''
|
||||
|
||||
metadata: bytes
|
||||
|
||||
|
||||
@ASE_Operation.subclass([[('ase_id', 1)]])
|
||||
class ASE_Receiver_Start_Ready(ASE_Operation):
|
||||
'''
|
||||
See Audio Stream Control Service 5.4 - Receiver Start Ready Operation
|
||||
'''
|
||||
|
||||
|
||||
@ASE_Operation.subclass([[('ase_id', 1)]])
|
||||
class ASE_Disable(ASE_Operation):
|
||||
'''
|
||||
See Audio Stream Control Service 5.5 - Disable Operation
|
||||
'''
|
||||
|
||||
|
||||
@ASE_Operation.subclass([[('ase_id', 1)]])
|
||||
class ASE_Receiver_Stop_Ready(ASE_Operation):
|
||||
'''
|
||||
See Audio Stream Control Service 5.6 - Receiver Stop Ready Operation
|
||||
'''
|
||||
|
||||
|
||||
@ASE_Operation.subclass([[('ase_id', 1), ('metadata', 'v')]])
|
||||
class ASE_Update_Metadata(ASE_Operation):
|
||||
'''
|
||||
See Audio Stream Control Service 5.7 - Update Metadata Operation
|
||||
'''
|
||||
|
||||
metadata: List[bytes]
|
||||
|
||||
|
||||
@ASE_Operation.subclass([[('ase_id', 1)]])
|
||||
class ASE_Release(ASE_Operation):
|
||||
'''
|
||||
See Audio Stream Control Service 5.8 - Release Operation
|
||||
'''
|
||||
|
||||
|
||||
class AseResponseCode(enum.IntEnum):
|
||||
# fmt: off
|
||||
SUCCESS = 0x00
|
||||
UNSUPPORTED_OPCODE = 0x01
|
||||
INVALID_LENGTH = 0x02
|
||||
INVALID_ASE_ID = 0x03
|
||||
INVALID_ASE_STATE_MACHINE_TRANSITION = 0x04
|
||||
INVALID_ASE_DIRECTION = 0x05
|
||||
UNSUPPORTED_AUDIO_CAPABILITIES = 0x06
|
||||
UNSUPPORTED_CONFIGURATION_PARAMETER_VALUE = 0x07
|
||||
REJECTED_CONFIGURATION_PARAMETER_VALUE = 0x08
|
||||
INVALID_CONFIGURATION_PARAMETER_VALUE = 0x09
|
||||
UNSUPPORTED_METADATA = 0x0A
|
||||
REJECTED_METADATA = 0x0B
|
||||
INVALID_METADATA = 0x0C
|
||||
INSUFFICIENT_RESOURCES = 0x0D
|
||||
UNSPECIFIED_ERROR = 0x0E
|
||||
|
||||
|
||||
class AseReasonCode(enum.IntEnum):
|
||||
# fmt: off
|
||||
NONE = 0x00
|
||||
CODEC_ID = 0x01
|
||||
CODEC_SPECIFIC_CONFIGURATION = 0x02
|
||||
SDU_INTERVAL = 0x03
|
||||
FRAMING = 0x04
|
||||
PHY = 0x05
|
||||
MAXIMUM_SDU_SIZE = 0x06
|
||||
RETRANSMISSION_NUMBER = 0x07
|
||||
MAX_TRANSPORT_LATENCY = 0x08
|
||||
PRESENTATION_DELAY = 0x09
|
||||
INVALID_ASE_CIS_MAPPING = 0x0A
|
||||
|
||||
|
||||
class AudioRole(enum.IntEnum):
|
||||
SINK = hci.HCI_LE_Setup_ISO_Data_Path_Command.Direction.CONTROLLER_TO_HOST
|
||||
SOURCE = hci.HCI_LE_Setup_ISO_Data_Path_Command.Direction.HOST_TO_CONTROLLER
|
||||
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
# Utils
|
||||
# -----------------------------------------------------------------------------
|
||||
@@ -325,6 +558,80 @@ class CodecSpecificCapabilities:
|
||||
)
|
||||
|
||||
|
||||
@dataclasses.dataclass
|
||||
class CodecSpecificConfiguration:
|
||||
'''See:
|
||||
* Bluetooth Assigned Numbers, 6.12.5 - Codec Specific Configuration LTV Structures
|
||||
* Basic Audio Profile, 4.3.2 - Codec_Specific_Capabilities LTV requirements
|
||||
'''
|
||||
|
||||
class Type(enum.IntEnum):
|
||||
# fmt: off
|
||||
SAMPLING_FREQUENCY = 0x01
|
||||
FRAME_DURATION = 0x02
|
||||
AUDIO_CHANNEL_ALLOCATION = 0x03
|
||||
OCTETS_PER_FRAME = 0x04
|
||||
CODEC_FRAMES_PER_SDU = 0x05
|
||||
|
||||
sampling_frequency: SamplingFrequency
|
||||
frame_duration: FrameDuration
|
||||
audio_channel_allocation: AudioLocation
|
||||
octets_per_codec_frame: int
|
||||
codec_frames_per_sdu: int
|
||||
|
||||
@classmethod
|
||||
def from_bytes(cls, data: bytes) -> CodecSpecificConfiguration:
|
||||
offset = 0
|
||||
# Allowed default values.
|
||||
audio_channel_allocation = AudioLocation.NOT_ALLOWED
|
||||
codec_frames_per_sdu = 1
|
||||
while offset < len(data):
|
||||
length, type = struct.unpack_from('BB', data, offset)
|
||||
offset += 2
|
||||
value = int.from_bytes(data[offset : offset + length - 1], 'little')
|
||||
offset += length - 1
|
||||
|
||||
if type == CodecSpecificConfiguration.Type.SAMPLING_FREQUENCY:
|
||||
sampling_frequency = SamplingFrequency(value)
|
||||
elif type == CodecSpecificConfiguration.Type.FRAME_DURATION:
|
||||
frame_duration = FrameDuration(value)
|
||||
elif type == CodecSpecificConfiguration.Type.AUDIO_CHANNEL_ALLOCATION:
|
||||
audio_channel_allocation = AudioLocation(value)
|
||||
elif type == CodecSpecificConfiguration.Type.OCTETS_PER_FRAME:
|
||||
octets_per_codec_frame = value
|
||||
elif type == CodecSpecificConfiguration.Type.CODEC_FRAMES_PER_SDU:
|
||||
codec_frames_per_sdu = value
|
||||
|
||||
# It is expected here that if some fields are missing, an error should be raised.
|
||||
return CodecSpecificConfiguration(
|
||||
sampling_frequency=sampling_frequency,
|
||||
frame_duration=frame_duration,
|
||||
audio_channel_allocation=audio_channel_allocation,
|
||||
octets_per_codec_frame=octets_per_codec_frame,
|
||||
codec_frames_per_sdu=codec_frames_per_sdu,
|
||||
)
|
||||
|
||||
def __bytes__(self) -> bytes:
|
||||
return struct.pack(
|
||||
'<BBBBBBBBIBBHBBB',
|
||||
2,
|
||||
CodecSpecificConfiguration.Type.SAMPLING_FREQUENCY,
|
||||
self.sampling_frequency,
|
||||
2,
|
||||
CodecSpecificConfiguration.Type.FRAME_DURATION,
|
||||
self.frame_duration,
|
||||
5,
|
||||
CodecSpecificConfiguration.Type.AUDIO_CHANNEL_ALLOCATION,
|
||||
self.audio_channel_allocation,
|
||||
3,
|
||||
CodecSpecificConfiguration.Type.OCTETS_PER_FRAME,
|
||||
self.octets_per_codec_frame,
|
||||
2,
|
||||
CodecSpecificConfiguration.Type.CODEC_FRAMES_PER_SDU,
|
||||
self.codec_frames_per_sdu,
|
||||
)
|
||||
|
||||
|
||||
@dataclasses.dataclass
|
||||
class PacRecord:
|
||||
coding_format: hci.CodingFormat
|
||||
@@ -452,6 +759,429 @@ class PublishedAudioCapabilitiesService(gatt.TemplateService):
|
||||
super().__init__(characteristics)
|
||||
|
||||
|
||||
class AseStateMachine(gatt.Characteristic):
|
||||
class State(enum.IntEnum):
|
||||
# fmt: off
|
||||
IDLE = 0x00
|
||||
CODEC_CONFIGURED = 0x01
|
||||
QOS_CONFIGURED = 0x02
|
||||
ENABLING = 0x03
|
||||
STREAMING = 0x04
|
||||
DISABLING = 0x05
|
||||
RELEASING = 0x06
|
||||
|
||||
cis_link: Optional[device.CisLink] = None
|
||||
|
||||
# Additional parameters in CODEC_CONFIGURED State
|
||||
preferred_framing = 0 # Unframed PDU supported
|
||||
preferred_phy = 0
|
||||
preferred_retransmission_number = 13
|
||||
preferred_max_transport_latency = 100
|
||||
supported_presentation_delay_min = 0
|
||||
supported_presentation_delay_max = 0
|
||||
preferred_presentation_delay_min = 0
|
||||
preferred_presentation_delay_max = 0
|
||||
codec_id = hci.CodingFormat(hci.CodecID.LC3)
|
||||
codec_specific_configuration: Union[CodecSpecificConfiguration, bytes] = b''
|
||||
|
||||
# Additional parameters in QOS_CONFIGURED State
|
||||
cig_id = 0
|
||||
cis_id = 0
|
||||
sdu_interval = 0
|
||||
framing = 0
|
||||
phy = 0
|
||||
max_sdu = 0
|
||||
retransmission_number = 0
|
||||
max_transport_latency = 0
|
||||
presentation_delay = 0
|
||||
|
||||
# Additional parameters in ENABLING, STREAMING, DISABLING State
|
||||
# TODO: Parse this
|
||||
metadata = b''
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
role: AudioRole,
|
||||
ase_id: int,
|
||||
service: AudioStreamControlService,
|
||||
) -> None:
|
||||
self.service = service
|
||||
self.ase_id = ase_id
|
||||
self._state = AseStateMachine.State.IDLE
|
||||
self.role = role
|
||||
|
||||
uuid = (
|
||||
gatt.GATT_SINK_ASE_CHARACTERISTIC
|
||||
if role == AudioRole.SINK
|
||||
else gatt.GATT_SOURCE_ASE_CHARACTERISTIC
|
||||
)
|
||||
super().__init__(
|
||||
uuid=uuid,
|
||||
properties=gatt.Characteristic.Properties.READ
|
||||
| gatt.Characteristic.Properties.NOTIFY,
|
||||
permissions=gatt.Characteristic.Permissions.READABLE,
|
||||
value=gatt.CharacteristicValue(read=self.on_read),
|
||||
)
|
||||
|
||||
self.service.device.on('cis_request', self.on_cis_request)
|
||||
self.service.device.on('cis_establishment', self.on_cis_establishment)
|
||||
|
||||
def on_cis_request(
|
||||
self,
|
||||
acl_connection: device.Connection,
|
||||
cis_handle: int,
|
||||
cig_id: int,
|
||||
cis_id: int,
|
||||
) -> None:
|
||||
if cis_id == self.cis_id and self.state == self.State.ENABLING:
|
||||
acl_connection.abort_on(
|
||||
'flush', self.service.device.accept_cis_request(cis_handle)
|
||||
)
|
||||
|
||||
def on_cis_establishment(self, cis_link: device.CisLink) -> None:
|
||||
if cis_link.cis_id == self.cis_id and self.state == self.State.ENABLING:
|
||||
self.state = self.State.STREAMING
|
||||
self.cis_link = cis_link
|
||||
|
||||
async def post_cis_established():
|
||||
await self.service.device.send_command(
|
||||
hci.HCI_LE_Setup_ISO_Data_Path_Command(
|
||||
connection_handle=cis_link.handle,
|
||||
data_path_direction=self.role,
|
||||
data_path_id=0x00, # Fixed HCI
|
||||
codec_id=hci.CodingFormat(hci.CodecID.TRANSPARENT),
|
||||
controller_delay=0,
|
||||
codec_configuration=b'',
|
||||
)
|
||||
)
|
||||
await self.service.device.notify_subscribers(self, self.value)
|
||||
|
||||
cis_link.acl_connection.abort_on('flush', post_cis_established())
|
||||
|
||||
def on_config_codec(
|
||||
self,
|
||||
target_latency: int,
|
||||
target_phy: int,
|
||||
codec_id: hci.CodingFormat,
|
||||
codec_specific_configuration: bytes,
|
||||
) -> Tuple[AseResponseCode, AseReasonCode]:
|
||||
if self.state not in (
|
||||
self.State.IDLE,
|
||||
self.State.CODEC_CONFIGURED,
|
||||
self.State.QOS_CONFIGURED,
|
||||
):
|
||||
return (
|
||||
AseResponseCode.INVALID_ASE_STATE_MACHINE_TRANSITION,
|
||||
AseReasonCode.NONE,
|
||||
)
|
||||
|
||||
self.max_transport_latency = target_latency
|
||||
self.phy = target_phy
|
||||
self.codec_id = codec_id
|
||||
if codec_id.codec_id == hci.CodecID.VENDOR_SPECIFIC:
|
||||
self.codec_specific_configuration = codec_specific_configuration
|
||||
else:
|
||||
self.codec_specific_configuration = CodecSpecificConfiguration.from_bytes(
|
||||
codec_specific_configuration
|
||||
)
|
||||
|
||||
self.state = self.State.CODEC_CONFIGURED
|
||||
|
||||
return (AseResponseCode.SUCCESS, AseReasonCode.NONE)
|
||||
|
||||
def on_config_qos(
|
||||
self,
|
||||
cig_id: int,
|
||||
cis_id: int,
|
||||
sdu_interval: int,
|
||||
framing: int,
|
||||
phy: int,
|
||||
max_sdu: int,
|
||||
retransmission_number: int,
|
||||
max_transport_latency: int,
|
||||
presentation_delay: int,
|
||||
) -> Tuple[AseResponseCode, AseReasonCode]:
|
||||
if self.state not in (
|
||||
AseStateMachine.State.CODEC_CONFIGURED,
|
||||
AseStateMachine.State.QOS_CONFIGURED,
|
||||
):
|
||||
return (
|
||||
AseResponseCode.INVALID_ASE_STATE_MACHINE_TRANSITION,
|
||||
AseReasonCode.NONE,
|
||||
)
|
||||
|
||||
self.cig_id = cig_id
|
||||
self.cis_id = cis_id
|
||||
self.sdu_interval = sdu_interval
|
||||
self.framing = framing
|
||||
self.phy = phy
|
||||
self.max_sdu = max_sdu
|
||||
self.retransmission_number = retransmission_number
|
||||
self.max_transport_latency = max_transport_latency
|
||||
self.presentation_delay = presentation_delay
|
||||
|
||||
self.state = self.State.QOS_CONFIGURED
|
||||
|
||||
return (AseResponseCode.SUCCESS, AseReasonCode.NONE)
|
||||
|
||||
def on_enable(self, metadata: bytes) -> Tuple[AseResponseCode, AseReasonCode]:
|
||||
if self.state != AseStateMachine.State.QOS_CONFIGURED:
|
||||
return (
|
||||
AseResponseCode.INVALID_ASE_STATE_MACHINE_TRANSITION,
|
||||
AseReasonCode.NONE,
|
||||
)
|
||||
|
||||
self.metadata = metadata
|
||||
self.state = self.State.ENABLING
|
||||
|
||||
return (AseResponseCode.SUCCESS, AseReasonCode.NONE)
|
||||
|
||||
def on_receiver_start_ready(self) -> Tuple[AseResponseCode, AseReasonCode]:
|
||||
if self.state != AseStateMachine.State.ENABLING:
|
||||
return (
|
||||
AseResponseCode.INVALID_ASE_STATE_MACHINE_TRANSITION,
|
||||
AseReasonCode.NONE,
|
||||
)
|
||||
self.state = self.State.STREAMING
|
||||
return (AseResponseCode.SUCCESS, AseReasonCode.NONE)
|
||||
|
||||
def on_disable(self) -> Tuple[AseResponseCode, AseReasonCode]:
|
||||
if self.state not in (
|
||||
AseStateMachine.State.ENABLING,
|
||||
AseStateMachine.State.STREAMING,
|
||||
):
|
||||
return (
|
||||
AseResponseCode.INVALID_ASE_STATE_MACHINE_TRANSITION,
|
||||
AseReasonCode.NONE,
|
||||
)
|
||||
self.state = self.State.DISABLING
|
||||
return (AseResponseCode.SUCCESS, AseReasonCode.NONE)
|
||||
|
||||
def on_receiver_stop_ready(self) -> Tuple[AseResponseCode, AseReasonCode]:
|
||||
if self.state != AseStateMachine.State.DISABLING:
|
||||
return (
|
||||
AseResponseCode.INVALID_ASE_STATE_MACHINE_TRANSITION,
|
||||
AseReasonCode.NONE,
|
||||
)
|
||||
self.state = self.State.QOS_CONFIGURED
|
||||
return (AseResponseCode.SUCCESS, AseReasonCode.NONE)
|
||||
|
||||
def on_update_metadata(
|
||||
self, metadata: bytes
|
||||
) -> Tuple[AseResponseCode, AseReasonCode]:
|
||||
if self.state not in (
|
||||
AseStateMachine.State.ENABLING,
|
||||
AseStateMachine.State.STREAMING,
|
||||
):
|
||||
return (
|
||||
AseResponseCode.INVALID_ASE_STATE_MACHINE_TRANSITION,
|
||||
AseReasonCode.NONE,
|
||||
)
|
||||
self.metadata = metadata
|
||||
return (AseResponseCode.SUCCESS, AseReasonCode.NONE)
|
||||
|
||||
def on_release(self) -> Tuple[AseResponseCode, AseReasonCode]:
|
||||
if self.state == AseStateMachine.State.IDLE:
|
||||
return (
|
||||
AseResponseCode.INVALID_ASE_STATE_MACHINE_TRANSITION,
|
||||
AseReasonCode.NONE,
|
||||
)
|
||||
self.state = self.State.RELEASING
|
||||
|
||||
async def remove_cis_async():
|
||||
await self.service.device.send_command(
|
||||
hci.HCI_LE_Remove_ISO_Data_Path_Command(
|
||||
connection_handle=self.cis_link.handle,
|
||||
data_path_direction=self.role,
|
||||
)
|
||||
)
|
||||
self.state = self.State.IDLE
|
||||
await self.service.device.notify_subscribers(self, self.value)
|
||||
|
||||
self.service.device.abort_on('flush', remove_cis_async())
|
||||
return (AseResponseCode.SUCCESS, AseReasonCode.NONE)
|
||||
|
||||
@property
|
||||
def state(self) -> State:
|
||||
return self._state
|
||||
|
||||
@state.setter
|
||||
def state(self, new_state: State) -> None:
|
||||
logger.debug(f'{self} state change -> {colors.color(new_state.name, "cyan")}')
|
||||
self._state = new_state
|
||||
|
||||
@property
|
||||
def value(self):
|
||||
'''Returns ASE_ID, ASE_STATE, and ASE Additional Parameters.'''
|
||||
|
||||
if self.state == self.State.CODEC_CONFIGURED:
|
||||
codec_specific_configuration_bytes = bytes(
|
||||
self.codec_specific_configuration
|
||||
)
|
||||
additional_parameters = (
|
||||
struct.pack(
|
||||
'<BBBH',
|
||||
self.preferred_framing,
|
||||
self.preferred_phy,
|
||||
self.preferred_retransmission_number,
|
||||
self.preferred_max_transport_latency,
|
||||
)
|
||||
+ self.supported_presentation_delay_min.to_bytes(3, 'little')
|
||||
+ self.supported_presentation_delay_max.to_bytes(3, 'little')
|
||||
+ self.preferred_presentation_delay_min.to_bytes(3, 'little')
|
||||
+ self.preferred_presentation_delay_max.to_bytes(3, 'little')
|
||||
+ bytes(self.codec_id)
|
||||
+ bytes([len(codec_specific_configuration_bytes)])
|
||||
+ codec_specific_configuration_bytes
|
||||
)
|
||||
elif self.state == self.State.QOS_CONFIGURED:
|
||||
additional_parameters = (
|
||||
bytes([self.cig_id, self.cis_id])
|
||||
+ self.sdu_interval.to_bytes(3, 'little')
|
||||
+ struct.pack(
|
||||
'<BBHBH',
|
||||
self.framing,
|
||||
self.phy,
|
||||
self.max_sdu,
|
||||
self.retransmission_number,
|
||||
self.max_transport_latency,
|
||||
)
|
||||
+ self.presentation_delay.to_bytes(3, 'little')
|
||||
)
|
||||
elif self.state in (
|
||||
self.State.ENABLING,
|
||||
self.State.STREAMING,
|
||||
self.State.DISABLING,
|
||||
):
|
||||
additional_parameters = (
|
||||
bytes([self.cig_id, self.cis_id, len(self.metadata)]) + self.metadata
|
||||
)
|
||||
else:
|
||||
additional_parameters = b''
|
||||
|
||||
return bytes([self.ase_id, self.state]) + additional_parameters
|
||||
|
||||
@value.setter
|
||||
def value(self, _new_value):
|
||||
# Readonly. Do nothing in the setter.
|
||||
pass
|
||||
|
||||
def on_read(self, _: device.Connection) -> bytes:
|
||||
return self.value
|
||||
|
||||
def __str__(self) -> str:
|
||||
return (
|
||||
f'AseStateMachine(id={self.ase_id}, role={self.role.name} '
|
||||
f'state={self._state.name})'
|
||||
)
|
||||
|
||||
|
||||
class AudioStreamControlService(gatt.TemplateService):
|
||||
UUID = gatt.GATT_AUDIO_STREAM_CONTROL_SERVICE
|
||||
|
||||
ase_state_machines: Dict[int, AseStateMachine]
|
||||
ase_control_point: gatt.Characteristic
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
device: device.Device,
|
||||
source_ase_id: Sequence[int] = [],
|
||||
sink_ase_id: Sequence[int] = [],
|
||||
) -> None:
|
||||
self.device = device
|
||||
self.ase_state_machines = {
|
||||
**{
|
||||
id: AseStateMachine(role=AudioRole.SINK, ase_id=id, service=self)
|
||||
for id in sink_ase_id
|
||||
},
|
||||
**{
|
||||
id: AseStateMachine(role=AudioRole.SOURCE, ase_id=id, service=self)
|
||||
for id in source_ase_id
|
||||
},
|
||||
} # ASE state machines, by ASE ID
|
||||
|
||||
self.ase_control_point = gatt.Characteristic(
|
||||
uuid=gatt.GATT_ASE_CONTROL_POINT_CHARACTERISTIC,
|
||||
properties=gatt.Characteristic.Properties.WRITE
|
||||
| gatt.Characteristic.Properties.WRITE_WITHOUT_RESPONSE
|
||||
| gatt.Characteristic.Properties.NOTIFY,
|
||||
permissions=gatt.Characteristic.Permissions.WRITEABLE,
|
||||
value=gatt.CharacteristicValue(write=self.on_write_ase_control_point),
|
||||
)
|
||||
|
||||
super().__init__([self.ase_control_point, *self.ase_state_machines.values()])
|
||||
|
||||
def on_operation(self, opcode: ASE_Operation.Opcode, ase_id: int, args):
|
||||
if ase := self.ase_state_machines.get(ase_id):
|
||||
handler = getattr(ase, 'on_' + opcode.name.lower())
|
||||
return (ase_id, *handler(*args))
|
||||
else:
|
||||
return (ase_id, AseResponseCode.INVALID_ASE_ID, AseReasonCode.NONE)
|
||||
|
||||
def on_write_ase_control_point(self, connection, data):
|
||||
operation = ASE_Operation.from_bytes(data)
|
||||
responses = []
|
||||
logger.debug(f'*** ASCS Write {operation} ***')
|
||||
|
||||
if operation.op_code == ASE_Operation.Opcode.CONFIG_CODEC:
|
||||
for ase_id, *args in zip(
|
||||
operation.ase_id,
|
||||
operation.target_latency,
|
||||
operation.target_phy,
|
||||
operation.codec_id,
|
||||
operation.codec_specific_configuration,
|
||||
):
|
||||
responses.append(self.on_operation(operation.op_code, ase_id, args))
|
||||
elif operation.op_code == ASE_Operation.Opcode.CONFIG_QOS:
|
||||
for ase_id, *args in zip(
|
||||
operation.ase_id,
|
||||
operation.cig_id,
|
||||
operation.cis_id,
|
||||
operation.sdu_interval,
|
||||
operation.framing,
|
||||
operation.phy,
|
||||
operation.max_sdu,
|
||||
operation.retransmission_number,
|
||||
operation.max_transport_latency,
|
||||
operation.presentation_delay,
|
||||
):
|
||||
responses.append(self.on_operation(operation.op_code, ase_id, args))
|
||||
elif operation.op_code in (
|
||||
ASE_Operation.Opcode.ENABLE,
|
||||
ASE_Operation.Opcode.UPDATE_METADATA,
|
||||
):
|
||||
for ase_id, *args in zip(
|
||||
operation.ase_id,
|
||||
operation.metadata,
|
||||
):
|
||||
responses.append(self.on_operation(operation.op_code, ase_id, args))
|
||||
elif operation.op_code in (
|
||||
ASE_Operation.Opcode.RECEIVER_START_READY,
|
||||
ASE_Operation.Opcode.DISABLE,
|
||||
ASE_Operation.Opcode.RECEIVER_STOP_READY,
|
||||
ASE_Operation.Opcode.RELEASE,
|
||||
):
|
||||
for ase_id in operation.ase_id:
|
||||
responses.append(self.on_operation(operation.op_code, ase_id, []))
|
||||
|
||||
control_point_notification = bytes(
|
||||
[operation.op_code, len(responses)]
|
||||
) + b''.join(map(bytes, responses))
|
||||
self.device.abort_on(
|
||||
'flush',
|
||||
self.device.notify_subscribers(
|
||||
self.ase_control_point, control_point_notification
|
||||
),
|
||||
)
|
||||
|
||||
for ase_id, *_ in responses:
|
||||
if ase := self.ase_state_machines.get(ase_id):
|
||||
self.device.abort_on(
|
||||
'flush',
|
||||
self.device.notify_subscribers(ase, ase.value),
|
||||
)
|
||||
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
# Client
|
||||
# -----------------------------------------------------------------------------
|
||||
@@ -494,3 +1224,24 @@ class PublishedAudioCapabilitiesServiceProxy(gatt_client.ProfileServiceProxy):
|
||||
gatt.GATT_SOURCE_AUDIO_LOCATION_CHARACTERISTIC
|
||||
):
|
||||
self.source_audio_locations = characteristics[0]
|
||||
|
||||
|
||||
class AudioStreamControlServiceProxy(gatt_client.ProfileServiceProxy):
|
||||
SERVICE_CLASS = AudioStreamControlService
|
||||
|
||||
sink_ase: List[gatt_client.CharacteristicProxy]
|
||||
source_ase: List[gatt_client.CharacteristicProxy]
|
||||
ase_control_point: gatt_client.CharacteristicProxy
|
||||
|
||||
def __init__(self, service_proxy: gatt_client.ServiceProxy):
|
||||
self.service_proxy = service_proxy
|
||||
|
||||
self.sink_ase = service_proxy.get_characteristics_by_uuid(
|
||||
gatt.GATT_SINK_ASE_CHARACTERISTIC
|
||||
)
|
||||
self.source_ase = service_proxy.get_characteristics_by_uuid(
|
||||
gatt.GATT_SOURCE_ASE_CHARACTERISTIC
|
||||
)
|
||||
self.ase_control_point = service_proxy.get_characteristics_by_uuid(
|
||||
gatt.GATT_ASE_CONTROL_POINT_CHARACTERISTIC
|
||||
)[0]
|
||||
|
||||
@@ -0,0 +1,52 @@
|
||||
# Copyright 2021-2023 Google LLC
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# https://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
# Imports
|
||||
# -----------------------------------------------------------------------------
|
||||
from __future__ import annotations
|
||||
|
||||
from bumble import gatt
|
||||
from bumble import gatt_client
|
||||
from bumble.profiles import csip
|
||||
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
# Server
|
||||
# -----------------------------------------------------------------------------
|
||||
class CommonAudioServiceService(gatt.TemplateService):
|
||||
UUID = gatt.GATT_COMMON_AUDIO_SERVICE
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
coordinated_set_identification_service: csip.CoordinatedSetIdentificationService,
|
||||
) -> None:
|
||||
self.coordinated_set_identification_service = (
|
||||
coordinated_set_identification_service
|
||||
)
|
||||
super().__init__(
|
||||
characteristics=[],
|
||||
included_services=[coordinated_set_identification_service],
|
||||
)
|
||||
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
# Client
|
||||
# -----------------------------------------------------------------------------
|
||||
class CommonAudioServiceServiceProxy(gatt_client.ProfileServiceProxy):
|
||||
SERVICE_CLASS = CommonAudioServiceService
|
||||
|
||||
def __init__(self, service_proxy: gatt_client.ServiceProxy) -> None:
|
||||
self.service_proxy = service_proxy
|
||||
+62
-4
@@ -21,6 +21,9 @@ import enum
|
||||
import struct
|
||||
from typing import Optional
|
||||
|
||||
from bumble import core
|
||||
from bumble import crypto
|
||||
from bumble import device
|
||||
from bumble import gatt
|
||||
from bumble import gatt_client
|
||||
|
||||
@@ -43,9 +46,43 @@ class MemberLock(enum.IntEnum):
|
||||
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
# Utils
|
||||
# Crypto Toolbox
|
||||
# -----------------------------------------------------------------------------
|
||||
# TODO: Implement RSI Generator
|
||||
def s1(m: bytes) -> bytes:
|
||||
'''
|
||||
Coordinated Set Identification Service - 4.3 s1 SALT generation function.
|
||||
'''
|
||||
return crypto.aes_cmac(m[::-1], bytes(16))[::-1]
|
||||
|
||||
|
||||
def k1(n: bytes, salt: bytes, p: bytes) -> bytes:
|
||||
'''
|
||||
Coordinated Set Identification Service - 4.4 k1 derivation function.
|
||||
'''
|
||||
t = crypto.aes_cmac(n[::-1], salt[::-1])
|
||||
return crypto.aes_cmac(p[::-1], t)[::-1]
|
||||
|
||||
|
||||
def sef(k: bytes, r: bytes) -> bytes:
|
||||
'''
|
||||
Coordinated Set Identification Service - 4.5 SIRK encryption function sef.
|
||||
'''
|
||||
return crypto.xor(k1(k, s1(b'SIRKenc'[::-1]), b'csis'[::-1]), r)
|
||||
|
||||
|
||||
def sih(k: bytes, r: bytes) -> bytes:
|
||||
'''
|
||||
Coordinated Set Identification Service - 4.7 Resolvable Set Identifier hash function sih.
|
||||
'''
|
||||
return crypto.e(k, r + bytes(13))[:3]
|
||||
|
||||
|
||||
def generate_rsi(sirk: bytes) -> bytes:
|
||||
'''
|
||||
Coordinated Set Identification Service - 4.8 Resolvable Set Identifier generation operation.
|
||||
'''
|
||||
prand = crypto.generate_prand()
|
||||
return sih(sirk, prand) + prand
|
||||
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
@@ -54,6 +91,7 @@ class MemberLock(enum.IntEnum):
|
||||
class CoordinatedSetIdentificationService(gatt.TemplateService):
|
||||
UUID = gatt.GATT_COORDINATED_SET_IDENTIFICATION_SERVICE
|
||||
|
||||
set_identity_resolving_key: bytes
|
||||
set_identity_resolving_key_characteristic: gatt.Characteristic
|
||||
coordinated_set_size_characteristic: Optional[gatt.Characteristic] = None
|
||||
set_member_lock_characteristic: Optional[gatt.Characteristic] = None
|
||||
@@ -62,19 +100,21 @@ class CoordinatedSetIdentificationService(gatt.TemplateService):
|
||||
def __init__(
|
||||
self,
|
||||
set_identity_resolving_key: bytes,
|
||||
set_identity_resolving_key_type: SirkType,
|
||||
coordinated_set_size: Optional[int] = None,
|
||||
set_member_lock: Optional[MemberLock] = None,
|
||||
set_member_rank: Optional[int] = None,
|
||||
) -> None:
|
||||
characteristics = []
|
||||
|
||||
self.set_identity_resolving_key = set_identity_resolving_key
|
||||
self.set_identity_resolving_key_type = set_identity_resolving_key_type
|
||||
self.set_identity_resolving_key_characteristic = gatt.Characteristic(
|
||||
uuid=gatt.GATT_SET_IDENTITY_RESOLVING_KEY_CHARACTERISTIC,
|
||||
properties=gatt.Characteristic.Properties.READ
|
||||
| gatt.Characteristic.Properties.NOTIFY,
|
||||
permissions=gatt.Characteristic.Permissions.READABLE,
|
||||
# TODO: Implement encrypted SIRK reader.
|
||||
value=struct.pack('B', SirkType.PLAINTEXT) + set_identity_resolving_key,
|
||||
value=gatt.CharacteristicValue(read=self.on_sirk_read),
|
||||
)
|
||||
characteristics.append(self.set_identity_resolving_key_characteristic)
|
||||
|
||||
@@ -112,6 +152,24 @@ class CoordinatedSetIdentificationService(gatt.TemplateService):
|
||||
|
||||
super().__init__(characteristics)
|
||||
|
||||
def on_sirk_read(self, _connection: device.Connection) -> bytes:
|
||||
if self.set_identity_resolving_key_type == SirkType.PLAINTEXT:
|
||||
return bytes([SirkType.PLAINTEXT]) + self.set_identity_resolving_key
|
||||
else:
|
||||
raise NotImplementedError('TODO: Pending async Characteristic read.')
|
||||
|
||||
def get_advertising_data(self) -> bytes:
|
||||
return bytes(
|
||||
core.AdvertisingData(
|
||||
[
|
||||
(
|
||||
core.AdvertisingData.RESOLVABLE_SET_IDENTIFIER,
|
||||
generate_rsi(self.set_identity_resolving_key),
|
||||
),
|
||||
]
|
||||
)
|
||||
)
|
||||
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
# Client
|
||||
|
||||
+31
-20
@@ -118,8 +118,8 @@ CRC_TABLE = bytes([
|
||||
0XBA, 0X2B, 0X59, 0XC8, 0XBD, 0X2C, 0X5E, 0XCF
|
||||
])
|
||||
|
||||
RFCOMM_DEFAULT_INITIAL_RX_CREDITS = 7
|
||||
RFCOMM_DEFAULT_PREFERRED_MTU = 1280
|
||||
RFCOMM_DEFAULT_WINDOW_SIZE = 16
|
||||
RFCOMM_DEFAULT_MAX_FRAME_SIZE = 2000
|
||||
|
||||
RFCOMM_DYNAMIC_CHANNEL_NUMBER_START = 1
|
||||
RFCOMM_DYNAMIC_CHANNEL_NUMBER_END = 30
|
||||
@@ -438,20 +438,24 @@ class DLC(EventEmitter):
|
||||
multiplexer: Multiplexer,
|
||||
dlci: int,
|
||||
max_frame_size: int,
|
||||
initial_tx_credits: int,
|
||||
window_size: int,
|
||||
) -> None:
|
||||
super().__init__()
|
||||
self.multiplexer = multiplexer
|
||||
self.dlci = dlci
|
||||
self.rx_credits = RFCOMM_DEFAULT_INITIAL_RX_CREDITS
|
||||
self.rx_threshold = self.rx_credits // 2
|
||||
self.tx_credits = initial_tx_credits
|
||||
self.max_frame_size = max_frame_size
|
||||
self.window_size = window_size
|
||||
self.rx_credits = window_size
|
||||
self.rx_threshold = window_size // 2
|
||||
self.tx_credits = window_size
|
||||
self.tx_buffer = b''
|
||||
self.state = DLC.State.INIT
|
||||
self.role = multiplexer.role
|
||||
self.c_r = 1 if self.role == Multiplexer.Role.INITIATOR else 0
|
||||
self.sink = None
|
||||
self.connection_result = None
|
||||
self.drained = asyncio.Event()
|
||||
self.drained.set()
|
||||
|
||||
# Compute the MTU
|
||||
max_overhead = 4 + 1 # header with 2-byte length + fcs
|
||||
@@ -537,11 +541,11 @@ class DLC(EventEmitter):
|
||||
if len(data) and self.sink:
|
||||
self.sink(data) # pylint: disable=not-callable
|
||||
|
||||
# Update the credits
|
||||
if self.rx_credits > 0:
|
||||
self.rx_credits -= 1
|
||||
else:
|
||||
logger.warning(color('!!! received frame with no rx credits', 'red'))
|
||||
# Update the credits
|
||||
if self.rx_credits > 0:
|
||||
self.rx_credits -= 1
|
||||
else:
|
||||
logger.warning(color('!!! received frame with no rx credits', 'red'))
|
||||
|
||||
# Check if there's anything to send (including credits)
|
||||
self.process_tx()
|
||||
@@ -580,9 +584,9 @@ class DLC(EventEmitter):
|
||||
cl=0xE0,
|
||||
priority=7,
|
||||
ack_timer=0,
|
||||
max_frame_size=RFCOMM_DEFAULT_PREFERRED_MTU,
|
||||
max_frame_size=self.max_frame_size,
|
||||
max_retransmissions=0,
|
||||
window_size=RFCOMM_DEFAULT_INITIAL_RX_CREDITS,
|
||||
window_size=self.window_size,
|
||||
)
|
||||
mcc = RFCOMM_Frame.make_mcc(mcc_type=RFCOMM_MCC_PN_TYPE, c_r=0, data=bytes(pn))
|
||||
logger.debug(f'>>> PN Response: {pn}')
|
||||
@@ -591,7 +595,7 @@ class DLC(EventEmitter):
|
||||
|
||||
def rx_credits_needed(self) -> int:
|
||||
if self.rx_credits <= self.rx_threshold:
|
||||
return RFCOMM_DEFAULT_INITIAL_RX_CREDITS - self.rx_credits
|
||||
return self.window_size - self.rx_credits
|
||||
|
||||
return 0
|
||||
|
||||
@@ -631,6 +635,8 @@ class DLC(EventEmitter):
|
||||
)
|
||||
|
||||
rx_credits_needed = 0
|
||||
if not self.tx_buffer:
|
||||
self.drained.set()
|
||||
|
||||
# Stream protocol
|
||||
def write(self, data: Union[bytes, str]) -> None:
|
||||
@@ -643,11 +649,11 @@ class DLC(EventEmitter):
|
||||
raise ValueError('write only accept bytes or strings')
|
||||
|
||||
self.tx_buffer += data
|
||||
self.drained.clear()
|
||||
self.process_tx()
|
||||
|
||||
def drain(self) -> None:
|
||||
# TODO
|
||||
pass
|
||||
async def drain(self) -> None:
|
||||
await self.drained.wait()
|
||||
|
||||
def __str__(self) -> str:
|
||||
return f'DLC(dlci={self.dlci},state={self.state.name})'
|
||||
@@ -843,7 +849,12 @@ class Multiplexer(EventEmitter):
|
||||
)
|
||||
await self.disconnection_result
|
||||
|
||||
async def open_dlc(self, channel: int) -> DLC:
|
||||
async def open_dlc(
|
||||
self,
|
||||
channel: int,
|
||||
max_frame_size: int = RFCOMM_DEFAULT_MAX_FRAME_SIZE,
|
||||
window_size: int = RFCOMM_DEFAULT_WINDOW_SIZE,
|
||||
) -> DLC:
|
||||
if self.state != Multiplexer.State.CONNECTED:
|
||||
if self.state == Multiplexer.State.OPENING:
|
||||
raise InvalidStateError('open already in progress')
|
||||
@@ -855,9 +866,9 @@ class Multiplexer(EventEmitter):
|
||||
cl=0xF0,
|
||||
priority=7,
|
||||
ack_timer=0,
|
||||
max_frame_size=RFCOMM_DEFAULT_PREFERRED_MTU,
|
||||
max_frame_size=max_frame_size,
|
||||
max_retransmissions=0,
|
||||
window_size=RFCOMM_DEFAULT_INITIAL_RX_CREDITS,
|
||||
window_size=window_size,
|
||||
)
|
||||
mcc = RFCOMM_Frame.make_mcc(mcc_type=RFCOMM_MCC_PN_TYPE, c_r=1, data=bytes(pn))
|
||||
logger.debug(f'>>> Sending MCC: {pn}')
|
||||
|
||||
@@ -18,6 +18,7 @@
|
||||
from contextlib import asynccontextmanager
|
||||
import logging
|
||||
import os
|
||||
from typing import Optional
|
||||
|
||||
from .common import Transport, AsyncPipeSink, SnoopingTransport
|
||||
from ..snoop import create_snooper
|
||||
@@ -52,8 +53,16 @@ def _wrap_transport(transport: Transport) -> Transport:
|
||||
async def open_transport(name: str) -> Transport:
|
||||
"""
|
||||
Open a transport by name.
|
||||
The name must be <type>:<parameters>
|
||||
Where <parameters> depend on the type (and may be empty for some types).
|
||||
The name must be <type>:<metadata><parameters>
|
||||
Where <parameters> depend on the type (and may be empty for some types), and
|
||||
<metadata> is either omitted, or a ,-separated list of <key>=<value> pairs,
|
||||
enclosed in [].
|
||||
If there are not metadata or parameter, the : after the <type> may be omitted.
|
||||
Examples:
|
||||
* usb:0
|
||||
* usb:[driver=rtk]0
|
||||
* android-netsim
|
||||
|
||||
The supported types are:
|
||||
* serial
|
||||
* udp
|
||||
@@ -71,87 +80,106 @@ async def open_transport(name: str) -> Transport:
|
||||
* android-netsim
|
||||
"""
|
||||
|
||||
return _wrap_transport(await _open_transport(name))
|
||||
scheme, *tail = name.split(':', 1)
|
||||
spec = tail[0] if tail else None
|
||||
if spec:
|
||||
# Metadata may precede the spec
|
||||
if spec.startswith('['):
|
||||
metadata_str, *tail = spec[1:].split(']')
|
||||
spec = tail[0] if tail else None
|
||||
metadata = dict([entry.split('=') for entry in metadata_str.split(',')])
|
||||
else:
|
||||
metadata = None
|
||||
|
||||
transport = await _open_transport(scheme, spec)
|
||||
if metadata:
|
||||
transport.source.metadata = { # type: ignore[attr-defined]
|
||||
**metadata,
|
||||
**getattr(transport.source, 'metadata', {}),
|
||||
}
|
||||
# pylint: disable=line-too-long
|
||||
logger.debug(f'HCI metadata: {transport.source.metadata}') # type: ignore[attr-defined]
|
||||
|
||||
return _wrap_transport(transport)
|
||||
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
async def _open_transport(name: str) -> Transport:
|
||||
async def _open_transport(scheme: str, spec: Optional[str]) -> Transport:
|
||||
# pylint: disable=import-outside-toplevel
|
||||
# pylint: disable=too-many-return-statements
|
||||
|
||||
scheme, *spec = name.split(':', 1)
|
||||
if scheme == 'serial' and spec:
|
||||
from .serial import open_serial_transport
|
||||
|
||||
return await open_serial_transport(spec[0])
|
||||
return await open_serial_transport(spec)
|
||||
|
||||
if scheme == 'udp' and spec:
|
||||
from .udp import open_udp_transport
|
||||
|
||||
return await open_udp_transport(spec[0])
|
||||
return await open_udp_transport(spec)
|
||||
|
||||
if scheme == 'tcp-client' and spec:
|
||||
from .tcp_client import open_tcp_client_transport
|
||||
|
||||
return await open_tcp_client_transport(spec[0])
|
||||
return await open_tcp_client_transport(spec)
|
||||
|
||||
if scheme == 'tcp-server' and spec:
|
||||
from .tcp_server import open_tcp_server_transport
|
||||
|
||||
return await open_tcp_server_transport(spec[0])
|
||||
return await open_tcp_server_transport(spec)
|
||||
|
||||
if scheme == 'ws-client' and spec:
|
||||
from .ws_client import open_ws_client_transport
|
||||
|
||||
return await open_ws_client_transport(spec[0])
|
||||
return await open_ws_client_transport(spec)
|
||||
|
||||
if scheme == 'ws-server' and spec:
|
||||
from .ws_server import open_ws_server_transport
|
||||
|
||||
return await open_ws_server_transport(spec[0])
|
||||
return await open_ws_server_transport(spec)
|
||||
|
||||
if scheme == 'pty':
|
||||
from .pty import open_pty_transport
|
||||
|
||||
return await open_pty_transport(spec[0] if spec else None)
|
||||
return await open_pty_transport(spec)
|
||||
|
||||
if scheme == 'file':
|
||||
from .file import open_file_transport
|
||||
|
||||
assert spec is not None
|
||||
return await open_file_transport(spec[0])
|
||||
return await open_file_transport(spec)
|
||||
|
||||
if scheme == 'vhci':
|
||||
from .vhci import open_vhci_transport
|
||||
|
||||
return await open_vhci_transport(spec[0] if spec else None)
|
||||
return await open_vhci_transport(spec)
|
||||
|
||||
if scheme == 'hci-socket':
|
||||
from .hci_socket import open_hci_socket_transport
|
||||
|
||||
return await open_hci_socket_transport(spec[0] if spec else None)
|
||||
return await open_hci_socket_transport(spec)
|
||||
|
||||
if scheme == 'usb':
|
||||
from .usb import open_usb_transport
|
||||
|
||||
assert spec is not None
|
||||
return await open_usb_transport(spec[0])
|
||||
assert spec
|
||||
return await open_usb_transport(spec)
|
||||
|
||||
if scheme == 'pyusb':
|
||||
from .pyusb import open_pyusb_transport
|
||||
|
||||
assert spec is not None
|
||||
return await open_pyusb_transport(spec[0])
|
||||
assert spec
|
||||
return await open_pyusb_transport(spec)
|
||||
|
||||
if scheme == 'android-emulator':
|
||||
from .android_emulator import open_android_emulator_transport
|
||||
|
||||
return await open_android_emulator_transport(spec[0] if spec else None)
|
||||
return await open_android_emulator_transport(spec)
|
||||
|
||||
if scheme == 'android-netsim':
|
||||
from .android_netsim import open_android_netsim_transport
|
||||
|
||||
return await open_android_netsim_transport(spec[0] if spec else None)
|
||||
return await open_android_netsim_transport(spec)
|
||||
|
||||
raise ValueError('unknown transport scheme')
|
||||
|
||||
|
||||
@@ -69,7 +69,7 @@ async def open_android_emulator_transport(spec: Optional[str]) -> Transport:
|
||||
mode = 'host'
|
||||
server_host = 'localhost'
|
||||
server_port = '8554'
|
||||
if spec is not None:
|
||||
if spec:
|
||||
params = spec.split(',')
|
||||
for param in params:
|
||||
if param.startswith('mode='):
|
||||
|
||||
@@ -21,7 +21,7 @@ import struct
|
||||
import asyncio
|
||||
import logging
|
||||
import io
|
||||
from typing import ContextManager, Tuple, Optional, Protocol, Dict
|
||||
from typing import Any, ContextManager, Tuple, Optional, Protocol, Dict
|
||||
|
||||
from bumble import hci
|
||||
from bumble.colors import color
|
||||
|
||||
@@ -59,10 +59,7 @@ async def open_hci_socket_transport(spec: Optional[str]) -> Transport:
|
||||
) from error
|
||||
|
||||
# Compute the adapter index
|
||||
if spec is None:
|
||||
adapter_index = 0
|
||||
else:
|
||||
adapter_index = int(spec)
|
||||
adapter_index = int(spec) if spec else 0
|
||||
|
||||
# Bind the socket
|
||||
# NOTE: since Python doesn't support binding with the required address format (yet),
|
||||
|
||||
@@ -108,7 +108,7 @@ async def open_usb_transport(spec: str) -> Transport:
|
||||
USB_DEVICE_PROTOCOL_BLUETOOTH_PRIMARY_CONTROLLER,
|
||||
)
|
||||
|
||||
READ_SIZE = 1024
|
||||
READ_SIZE = 4096
|
||||
|
||||
class UsbPacketSink:
|
||||
def __init__(self, device, acl_out):
|
||||
|
||||
@@ -7,16 +7,36 @@ throughput and/or latency between two devices.
|
||||
# General Usage
|
||||
|
||||
```
|
||||
Usage: bench.py [OPTIONS] COMMAND [ARGS]...
|
||||
Usage: bumble-bench [OPTIONS] COMMAND [ARGS]...
|
||||
|
||||
Options:
|
||||
--device-config FILENAME Device configuration file
|
||||
--role [sender|receiver|ping|pong]
|
||||
--mode [gatt-client|gatt-server|l2cap-client|l2cap-server|rfcomm-client|rfcomm-server]
|
||||
--att-mtu MTU GATT MTU (gatt-client mode) [23<=x<=517]
|
||||
-s, --packet-size SIZE Packet size (server role) [8<=x<=4096]
|
||||
-c, --packet-count COUNT Packet count (server role)
|
||||
-sd, --start-delay SECONDS Start delay (server role)
|
||||
--extended-data-length TEXT Request a data length upon connection,
|
||||
specified as tx_octets/tx_time
|
||||
--rfcomm-channel INTEGER RFComm channel to use
|
||||
--rfcomm-uuid TEXT RFComm service UUID to use (ignored if
|
||||
--rfcomm-channel is not 0)
|
||||
--l2cap-psm INTEGER L2CAP PSM to use
|
||||
--l2cap-mtu INTEGER L2CAP MTU to use
|
||||
--l2cap-mps INTEGER L2CAP MPS to use
|
||||
--l2cap-max-credits INTEGER L2CAP maximum number of credits allowed for
|
||||
the peer
|
||||
-s, --packet-size SIZE Packet size (client or ping role)
|
||||
[8<=x<=4096]
|
||||
-c, --packet-count COUNT Packet count (client or ping role)
|
||||
-sd, --start-delay SECONDS Start delay (client or ping role)
|
||||
--repeat N Repeat the run N times (client and ping
|
||||
roles)(0, which is the fault, to run just
|
||||
once)
|
||||
--repeat-delay SECONDS Delay, in seconds, between repeats
|
||||
--pace MILLISECONDS Wait N milliseconds between packets (0,
|
||||
which is the fault, to send as fast as
|
||||
possible)
|
||||
--linger Don't exit at the end of a run (server and
|
||||
pong roles)
|
||||
--help Show this message and exit.
|
||||
|
||||
Commands:
|
||||
@@ -35,17 +55,18 @@ Options:
|
||||
--connection-interval, --ci CONNECTION_INTERVAL
|
||||
Connection interval (in ms)
|
||||
--phy [1m|2m|coded] PHY to use
|
||||
--authenticate Authenticate (RFComm only)
|
||||
--encrypt Encrypt the connection (RFComm only)
|
||||
--help Show this message and exit.
|
||||
```
|
||||
|
||||
|
||||
To test once device against another, one of the two devices must be running
|
||||
To test once device against another, one of the two devices must be running
|
||||
the ``peripheral`` command and the other the ``central`` command. The device
|
||||
running the ``peripheral`` command will accept connections from the device
|
||||
running the ``central`` command.
|
||||
When using Bluetooth LE (all modes except for ``rfcomm-server`` and ``rfcomm-client``utils),
|
||||
the default addresses configured in the tool should be sufficient. But when using
|
||||
Bluetooth Classic, the address of the Peripheral must be specified on the Central
|
||||
the default addresses configured in the tool should be sufficient. But when using
|
||||
Bluetooth Classic, the address of the Peripheral must be specified on the Central
|
||||
using the ``--peripheral`` option. The address will be printed by the Peripheral when
|
||||
it starts.
|
||||
|
||||
@@ -83,7 +104,7 @@ the other on `usb:1`, and two consoles/terminals. We will run a command in each.
|
||||
$ bumble-bench central usb:1
|
||||
```
|
||||
|
||||
In this default configuration, the Central runs a Sender, as a GATT client,
|
||||
In this default configuration, the Central runs a Sender, as a GATT client,
|
||||
connecting to the Peripheral running a Receiver, as a GATT server.
|
||||
|
||||
!!! example "L2CAP Throughput"
|
||||
|
||||
@@ -5,6 +5,15 @@ Some Bluetooth controllers require a driver to function properly.
|
||||
This may include, for instance, loading a Firmware image or patch,
|
||||
loading a configuration.
|
||||
|
||||
By default, drivers will be automatically probed to determine if they should be
|
||||
used with particular HCI controller.
|
||||
When the transport for an HCI controller is instantiated from a transport name,
|
||||
a driver may also be forced by specifying ``driver=<driver-name>`` in the optional
|
||||
metadata portion of the transport name. For example,
|
||||
``usb:[driver=-rtk]0`` indicates that the ``rtk`` driver should be used with the
|
||||
first USB device, even if a normal probe would not have selected it based on the
|
||||
USB vendor ID and product ID.
|
||||
|
||||
Drivers included in the module are:
|
||||
|
||||
* [Realtek](realtek.md): Loading of Firmware and Config for Realtek USB dongles.
|
||||
@@ -1,13 +1,16 @@
|
||||
REALTEK DRIVER
|
||||
==============
|
||||
|
||||
This driver supports loading firmware images and optional config data to
|
||||
This driver supports loading firmware images and optional config data to
|
||||
USB dongles with a Realtek chipset.
|
||||
A number of USB dongles are supported, but likely not all.
|
||||
When using a USB dongle, the USB product ID and manufacturer ID are used
|
||||
When using a USB dongle, the USB product ID and vendor ID are used
|
||||
to find whether a matching set of firmware image and config data
|
||||
is needed for that specific model. If a match exists, the driver will try
|
||||
load the firmware image and, if needed, config data.
|
||||
Alternatively, the metadata property ``driver=rtk`` may be specified in a transport
|
||||
name to force that driver to be used (ex: ``usb:[driver=rtk]0`` instead of just
|
||||
``usb:0`` for the first USB device).
|
||||
The driver will look for those files by name, in order, in:
|
||||
|
||||
* The directory specified by the environment variable `BUMBLE_RTK_FIRMWARE_DIR`
|
||||
|
||||
@@ -0,0 +1,5 @@
|
||||
{
|
||||
"name": "Bumble HID Keyboard",
|
||||
"class_of_device": 9664,
|
||||
"keystore": "JsonKeyStore"
|
||||
}
|
||||
@@ -40,9 +40,9 @@
|
||||
}
|
||||
}
|
||||
function onMouseMove(event) {
|
||||
//console.log(event.clientX, event.clientY)
|
||||
mouseInfo.innerText = `MOUSE: x=${event.clientX}, y=${event.clientY}`
|
||||
send({ type:'mousemove', x: event.clientX, y: event.clientY })
|
||||
//console.log(event.movementX, event.movementY)
|
||||
mouseInfo.innerText = `MOUSE: x=${event.movementX}, y=${event.movementY}`
|
||||
send({ type:'mousemove', x: event.movementX, y: event.movementY })
|
||||
}
|
||||
|
||||
function onKeyDown(event) {
|
||||
|
||||
@@ -1,5 +1,6 @@
|
||||
{
|
||||
"name": "Bumble-LEA",
|
||||
"keystore": "JsonKeyStore",
|
||||
"address": "F0:F1:F2:F3:F4:FA",
|
||||
"advertising_interval": 100
|
||||
}
|
||||
|
||||
@@ -0,0 +1,116 @@
|
||||
# Copyright 2021-2023 Google LLC
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# https://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
# Imports
|
||||
# -----------------------------------------------------------------------------
|
||||
import asyncio
|
||||
import logging
|
||||
import sys
|
||||
import os
|
||||
import secrets
|
||||
|
||||
from bumble.core import AdvertisingData
|
||||
from bumble.device import Device
|
||||
from bumble.hci import (
|
||||
Address,
|
||||
OwnAddressType,
|
||||
HCI_LE_Set_Extended_Advertising_Parameters_Command,
|
||||
)
|
||||
from bumble.profiles.cap import CommonAudioServiceService
|
||||
from bumble.profiles.csip import CoordinatedSetIdentificationService, SirkType
|
||||
|
||||
from bumble.transport import open_transport_or_link
|
||||
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
async def main() -> None:
|
||||
if len(sys.argv) < 3:
|
||||
print(
|
||||
'Usage: run_cig_setup.py <config-file>'
|
||||
'<transport-spec-for-device-1> <transport-spec-for-device-2>'
|
||||
)
|
||||
print(
|
||||
'example: run_cig_setup.py device1.json'
|
||||
'tcp-client:127.0.0.1:6402 tcp-client:127.0.0.1:6402'
|
||||
)
|
||||
return
|
||||
|
||||
print('<<< connecting to HCI...')
|
||||
hci_transports = await asyncio.gather(
|
||||
open_transport_or_link(sys.argv[2]), open_transport_or_link(sys.argv[3])
|
||||
)
|
||||
print('<<< connected')
|
||||
|
||||
devices = [
|
||||
Device.from_config_file_with_hci(
|
||||
sys.argv[1], hci_transport.source, hci_transport.sink
|
||||
)
|
||||
for hci_transport in hci_transports
|
||||
]
|
||||
|
||||
sirk = secrets.token_bytes(16)
|
||||
|
||||
for i, device in enumerate(devices):
|
||||
device.random_address = Address(secrets.token_bytes(6))
|
||||
await device.power_on()
|
||||
csis = CoordinatedSetIdentificationService(
|
||||
set_identity_resolving_key=sirk,
|
||||
set_identity_resolving_key_type=SirkType.PLAINTEXT,
|
||||
coordinated_set_size=2,
|
||||
)
|
||||
device.add_service(CommonAudioServiceService(csis))
|
||||
advertising_data = (
|
||||
bytes(
|
||||
AdvertisingData(
|
||||
[
|
||||
(
|
||||
AdvertisingData.COMPLETE_LOCAL_NAME,
|
||||
bytes(f'Bumble LE Audio-{i}', 'utf-8'),
|
||||
),
|
||||
(
|
||||
AdvertisingData.FLAGS,
|
||||
bytes(
|
||||
[
|
||||
AdvertisingData.LE_GENERAL_DISCOVERABLE_MODE_FLAG
|
||||
| AdvertisingData.BR_EDR_HOST_FLAG
|
||||
| AdvertisingData.BR_EDR_CONTROLLER_FLAG
|
||||
]
|
||||
),
|
||||
),
|
||||
(
|
||||
AdvertisingData.INCOMPLETE_LIST_OF_16_BIT_SERVICE_CLASS_UUIDS,
|
||||
bytes(CoordinatedSetIdentificationService.UUID),
|
||||
),
|
||||
]
|
||||
)
|
||||
)
|
||||
+ csis.get_advertising_data()
|
||||
)
|
||||
await device.start_extended_advertising(
|
||||
advertising_properties=(
|
||||
HCI_LE_Set_Extended_Advertising_Parameters_Command.AdvertisingProperties.CONNECTABLE_ADVERTISING
|
||||
),
|
||||
own_address_type=OwnAddressType.RANDOM,
|
||||
advertising_data=advertising_data,
|
||||
)
|
||||
|
||||
await asyncio.gather(
|
||||
*[hci_transport.source.terminated for hci_transport in hci_transports]
|
||||
)
|
||||
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
logging.basicConfig(level=os.environ.get('BUMBLE_LOGLEVEL', 'DEBUG').upper())
|
||||
asyncio.run(main())
|
||||
@@ -0,0 +1,748 @@
|
||||
# Copyright 2021-2022 Google LLC
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# https://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
# Imports
|
||||
# -----------------------------------------------------------------------------
|
||||
import asyncio
|
||||
import sys
|
||||
import os
|
||||
import logging
|
||||
import json
|
||||
import websockets
|
||||
from bumble.colors import color
|
||||
|
||||
from bumble.device import Device
|
||||
from bumble.transport import open_transport_or_link
|
||||
from bumble.core import (
|
||||
BT_BR_EDR_TRANSPORT,
|
||||
BT_L2CAP_PROTOCOL_ID,
|
||||
BT_HUMAN_INTERFACE_DEVICE_SERVICE,
|
||||
BT_HIDP_PROTOCOL_ID,
|
||||
UUID,
|
||||
)
|
||||
from bumble.hci import Address
|
||||
from bumble.hid import (
|
||||
Device as HID_Device,
|
||||
HID_CONTROL_PSM,
|
||||
HID_INTERRUPT_PSM,
|
||||
Message,
|
||||
)
|
||||
from bumble.sdp import (
|
||||
Client as SDP_Client,
|
||||
DataElement,
|
||||
ServiceAttribute,
|
||||
SDP_PUBLIC_BROWSE_ROOT,
|
||||
SDP_PROTOCOL_DESCRIPTOR_LIST_ATTRIBUTE_ID,
|
||||
SDP_SERVICE_CLASS_ID_LIST_ATTRIBUTE_ID,
|
||||
SDP_BLUETOOTH_PROFILE_DESCRIPTOR_LIST_ATTRIBUTE_ID,
|
||||
SDP_ALL_ATTRIBUTES_RANGE,
|
||||
SDP_LANGUAGE_BASE_ATTRIBUTE_ID_LIST_ATTRIBUTE_ID,
|
||||
SDP_ADDITIONAL_PROTOCOL_DESCRIPTOR_LIST_ATTRIBUTE_ID,
|
||||
SDP_SERVICE_RECORD_HANDLE_ATTRIBUTE_ID,
|
||||
SDP_BROWSE_GROUP_LIST_ATTRIBUTE_ID,
|
||||
)
|
||||
from bumble.utils import AsyncRunner
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
# SDP attributes for Bluetooth HID devices
|
||||
SDP_HID_SERVICE_NAME_ATTRIBUTE_ID = 0x0100
|
||||
SDP_HID_SERVICE_DESCRIPTION_ATTRIBUTE_ID = 0x0101
|
||||
SDP_HID_PROVIDER_NAME_ATTRIBUTE_ID = 0x0102
|
||||
SDP_HID_DEVICE_RELEASE_NUMBER_ATTRIBUTE_ID = 0x0200 # [DEPRECATED]
|
||||
SDP_HID_PARSER_VERSION_ATTRIBUTE_ID = 0x0201
|
||||
SDP_HID_DEVICE_SUBCLASS_ATTRIBUTE_ID = 0x0202
|
||||
SDP_HID_COUNTRY_CODE_ATTRIBUTE_ID = 0x0203
|
||||
SDP_HID_VIRTUAL_CABLE_ATTRIBUTE_ID = 0x0204
|
||||
SDP_HID_RECONNECT_INITIATE_ATTRIBUTE_ID = 0x0205
|
||||
SDP_HID_DESCRIPTOR_LIST_ATTRIBUTE_ID = 0x0206
|
||||
SDP_HID_LANGID_BASE_LIST_ATTRIBUTE_ID = 0x0207
|
||||
SDP_HID_SDP_DISABLE_ATTRIBUTE_ID = 0x0208 # [DEPRECATED]
|
||||
SDP_HID_BATTERY_POWER_ATTRIBUTE_ID = 0x0209
|
||||
SDP_HID_REMOTE_WAKE_ATTRIBUTE_ID = 0x020A
|
||||
SDP_HID_PROFILE_VERSION_ATTRIBUTE_ID = 0x020B # DEPRECATED]
|
||||
SDP_HID_SUPERVISION_TIMEOUT_ATTRIBUTE_ID = 0x020C
|
||||
SDP_HID_NORMALLY_CONNECTABLE_ATTRIBUTE_ID = 0x020D
|
||||
SDP_HID_BOOT_DEVICE_ATTRIBUTE_ID = 0x020E
|
||||
SDP_HID_SSR_HOST_MAX_LATENCY_ATTRIBUTE_ID = 0x020F
|
||||
SDP_HID_SSR_HOST_MIN_TIMEOUT_ATTRIBUTE_ID = 0x0210
|
||||
|
||||
# Refer to HID profile specification v1.1.1, "5.3 Service Discovery Protocol (SDP)" for details
|
||||
# HID SDP attribute values
|
||||
LANGUAGE = 0x656E # 0x656E uint16 “en” (English)
|
||||
ENCODING = 0x6A # 0x006A uint16 UTF-8 encoding
|
||||
PRIMARY_LANGUAGE_BASE_ID = 0x100 # 0x0100 uint16 PrimaryLanguageBaseID
|
||||
VERSION_NUMBER = 0x0101 # 0x0101 uint16 version number (v1.1)
|
||||
SERVICE_NAME = b'Bumble HID'
|
||||
SERVICE_DESCRIPTION = b'Bumble'
|
||||
PROVIDER_NAME = b'Bumble'
|
||||
HID_PARSER_VERSION = 0x0111 # uint16 0x0111 (v1.1.1)
|
||||
HID_DEVICE_SUBCLASS = 0xC0 # Combo keyboard/pointing device
|
||||
HID_COUNTRY_CODE = 0x21 # 0x21 Uint8, USA
|
||||
HID_VIRTUAL_CABLE = True # Virtual cable enabled
|
||||
HID_RECONNECT_INITIATE = True # Reconnect initiate enabled
|
||||
REPORT_DESCRIPTOR_TYPE = 0x22 # 0x22 Type = Report Descriptor
|
||||
HID_LANGID_BASE_LANGUAGE = 0x0409 # 0x0409 Language = English (United States)
|
||||
HID_LANGID_BASE_BLUETOOTH_STRING_OFFSET = 0x100 # 0x0100 Default
|
||||
HID_BATTERY_POWER = True # Battery power enabled
|
||||
HID_REMOTE_WAKE = True # Remote wake enabled
|
||||
HID_SUPERVISION_TIMEOUT = 0xC80 # uint16 0xC80 (2s)
|
||||
HID_NORMALLY_CONNECTABLE = True # Normally connectable enabled
|
||||
HID_BOOT_DEVICE = True # Boot device support enabled
|
||||
HID_SSR_HOST_MAX_LATENCY = 0x640 # uint16 0x640 (1s)
|
||||
HID_SSR_HOST_MIN_TIMEOUT = 0xC80 # uint16 0xC80 (2s)
|
||||
HID_REPORT_MAP = bytes( # Text String, 50 Octet Report Descriptor
|
||||
# pylint: disable=line-too-long
|
||||
[
|
||||
0x05,
|
||||
0x01, # Usage Page (Generic Desktop Ctrls)
|
||||
0x09,
|
||||
0x06, # Usage (Keyboard)
|
||||
0xA1,
|
||||
0x01, # Collection (Application)
|
||||
0x85,
|
||||
0x01, # . Report ID (1)
|
||||
0x05,
|
||||
0x07, # . Usage Page (Kbrd/Keypad)
|
||||
0x19,
|
||||
0xE0, # . Usage Minimum (0xE0)
|
||||
0x29,
|
||||
0xE7, # . Usage Maximum (0xE7)
|
||||
0x15,
|
||||
0x00, # . Logical Minimum (0)
|
||||
0x25,
|
||||
0x01, # . Logical Maximum (1)
|
||||
0x75,
|
||||
0x01, # . Report Size (1)
|
||||
0x95,
|
||||
0x08, # . Report Count (8)
|
||||
0x81,
|
||||
0x02, # . Input (Data,Var,Abs,No Wrap,Linear,Preferred State,No Null Position)
|
||||
0x95,
|
||||
0x01, # . Report Count (1)
|
||||
0x75,
|
||||
0x08, # . Report Size (8)
|
||||
0x81,
|
||||
0x03, # . Input (Const,Var,Abs,No Wrap,Linear,Preferred State,No Null Position)
|
||||
0x95,
|
||||
0x05, # . Report Count (5)
|
||||
0x75,
|
||||
0x01, # . Report Size (1)
|
||||
0x05,
|
||||
0x08, # . Usage Page (LEDs)
|
||||
0x19,
|
||||
0x01, # . Usage Minimum (Num Lock)
|
||||
0x29,
|
||||
0x05, # . Usage Maximum (Kana)
|
||||
0x91,
|
||||
0x02, # . Output (Data,Var,Abs,No Wrap,Linear,Preferred State,No Null Position,Non-volatile)
|
||||
0x95,
|
||||
0x01, # . Report Count (1)
|
||||
0x75,
|
||||
0x03, # . Report Size (3)
|
||||
0x91,
|
||||
0x03, # . Output (Const,Var,Abs,No Wrap,Linear,Preferred State,No Null Position,Non-volatile)
|
||||
0x95,
|
||||
0x06, # . Report Count (6)
|
||||
0x75,
|
||||
0x08, # . Report Size (8)
|
||||
0x15,
|
||||
0x00, # . Logical Minimum (0)
|
||||
0x25,
|
||||
0x65, # . Logical Maximum (101)
|
||||
0x05,
|
||||
0x07, # . Usage Page (Kbrd/Keypad)
|
||||
0x19,
|
||||
0x00, # . Usage Minimum (0x00)
|
||||
0x29,
|
||||
0x65, # . Usage Maximum (0x65)
|
||||
0x81,
|
||||
0x00, # . Input (Data,Array,Abs,No Wrap,Linear,Preferred State,No Null Position)
|
||||
0xC0, # End Collection
|
||||
0x05,
|
||||
0x01, # Usage Page (Generic Desktop Ctrls)
|
||||
0x09,
|
||||
0x02, # Usage (Mouse)
|
||||
0xA1,
|
||||
0x01, # Collection (Application)
|
||||
0x85,
|
||||
0x02, # . Report ID (2)
|
||||
0x09,
|
||||
0x01, # . Usage (Pointer)
|
||||
0xA1,
|
||||
0x00, # . Collection (Physical)
|
||||
0x05,
|
||||
0x09, # . Usage Page (Button)
|
||||
0x19,
|
||||
0x01, # . Usage Minimum (0x01)
|
||||
0x29,
|
||||
0x03, # . Usage Maximum (0x03)
|
||||
0x15,
|
||||
0x00, # . Logical Minimum (0)
|
||||
0x25,
|
||||
0x01, # . Logical Maximum (1)
|
||||
0x95,
|
||||
0x03, # . Report Count (3)
|
||||
0x75,
|
||||
0x01, # . Report Size (1)
|
||||
0x81,
|
||||
0x02, # . Input (Data,Var,Abs,No Wrap,Linear,Preferred State,No Null Position)
|
||||
0x95,
|
||||
0x01, # . Report Count (1)
|
||||
0x75,
|
||||
0x05, # . Report Size (5)
|
||||
0x81,
|
||||
0x03, # . Input (Const,Var,Abs,No Wrap,Linear,Preferred State,No Null Position)
|
||||
0x05,
|
||||
0x01, # . Usage Page (Generic Desktop Ctrls)
|
||||
0x09,
|
||||
0x30, # . Usage (X)
|
||||
0x09,
|
||||
0x31, # . Usage (Y)
|
||||
0x15,
|
||||
0x81, # . Logical Minimum (-127)
|
||||
0x25,
|
||||
0x7F, # . Logical Maximum (127)
|
||||
0x75,
|
||||
0x08, # . Report Size (8)
|
||||
0x95,
|
||||
0x02, # . Report Count (2)
|
||||
0x81,
|
||||
0x06, # . Input (Data,Var,Rel,No Wrap,Linear,Preferred State,No Null Position)
|
||||
0xC0, # . End Collection
|
||||
0xC0, # End Collection
|
||||
]
|
||||
)
|
||||
|
||||
|
||||
# Default protocol mode set to report protocol
|
||||
protocol_mode = Message.ProtocolMode.REPORT_PROTOCOL
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
def sdp_records():
|
||||
service_record_handle = 0x00010002
|
||||
return {
|
||||
service_record_handle: [
|
||||
ServiceAttribute(
|
||||
SDP_SERVICE_RECORD_HANDLE_ATTRIBUTE_ID,
|
||||
DataElement.unsigned_integer_32(service_record_handle),
|
||||
),
|
||||
ServiceAttribute(
|
||||
SDP_BROWSE_GROUP_LIST_ATTRIBUTE_ID,
|
||||
DataElement.sequence([DataElement.uuid(SDP_PUBLIC_BROWSE_ROOT)]),
|
||||
),
|
||||
ServiceAttribute(
|
||||
SDP_SERVICE_CLASS_ID_LIST_ATTRIBUTE_ID,
|
||||
DataElement.sequence(
|
||||
[DataElement.uuid(BT_HUMAN_INTERFACE_DEVICE_SERVICE)]
|
||||
),
|
||||
),
|
||||
ServiceAttribute(
|
||||
SDP_PROTOCOL_DESCRIPTOR_LIST_ATTRIBUTE_ID,
|
||||
DataElement.sequence(
|
||||
[
|
||||
DataElement.sequence(
|
||||
[
|
||||
DataElement.uuid(BT_L2CAP_PROTOCOL_ID),
|
||||
DataElement.unsigned_integer_16(HID_CONTROL_PSM),
|
||||
]
|
||||
),
|
||||
DataElement.sequence(
|
||||
[
|
||||
DataElement.uuid(BT_HIDP_PROTOCOL_ID),
|
||||
]
|
||||
),
|
||||
]
|
||||
),
|
||||
),
|
||||
ServiceAttribute(
|
||||
SDP_LANGUAGE_BASE_ATTRIBUTE_ID_LIST_ATTRIBUTE_ID,
|
||||
DataElement.sequence(
|
||||
[
|
||||
DataElement.unsigned_integer_16(LANGUAGE),
|
||||
DataElement.unsigned_integer_16(ENCODING),
|
||||
DataElement.unsigned_integer_16(PRIMARY_LANGUAGE_BASE_ID),
|
||||
]
|
||||
),
|
||||
),
|
||||
ServiceAttribute(
|
||||
SDP_BLUETOOTH_PROFILE_DESCRIPTOR_LIST_ATTRIBUTE_ID,
|
||||
DataElement.sequence(
|
||||
[
|
||||
DataElement.sequence(
|
||||
[
|
||||
DataElement.uuid(BT_HUMAN_INTERFACE_DEVICE_SERVICE),
|
||||
DataElement.unsigned_integer_16(VERSION_NUMBER),
|
||||
]
|
||||
),
|
||||
]
|
||||
),
|
||||
),
|
||||
ServiceAttribute(
|
||||
SDP_ADDITIONAL_PROTOCOL_DESCRIPTOR_LIST_ATTRIBUTE_ID,
|
||||
DataElement.sequence(
|
||||
[
|
||||
DataElement.sequence(
|
||||
[
|
||||
DataElement.sequence(
|
||||
[
|
||||
DataElement.uuid(BT_L2CAP_PROTOCOL_ID),
|
||||
DataElement.unsigned_integer_16(
|
||||
HID_INTERRUPT_PSM
|
||||
),
|
||||
]
|
||||
),
|
||||
DataElement.sequence(
|
||||
[
|
||||
DataElement.uuid(BT_HIDP_PROTOCOL_ID),
|
||||
]
|
||||
),
|
||||
]
|
||||
),
|
||||
]
|
||||
),
|
||||
),
|
||||
ServiceAttribute(
|
||||
SDP_HID_SERVICE_NAME_ATTRIBUTE_ID,
|
||||
DataElement(DataElement.TEXT_STRING, SERVICE_NAME),
|
||||
),
|
||||
ServiceAttribute(
|
||||
SDP_HID_SERVICE_DESCRIPTION_ATTRIBUTE_ID,
|
||||
DataElement(DataElement.TEXT_STRING, SERVICE_DESCRIPTION),
|
||||
),
|
||||
ServiceAttribute(
|
||||
SDP_HID_PROVIDER_NAME_ATTRIBUTE_ID,
|
||||
DataElement(DataElement.TEXT_STRING, PROVIDER_NAME),
|
||||
),
|
||||
ServiceAttribute(
|
||||
SDP_HID_PARSER_VERSION_ATTRIBUTE_ID,
|
||||
DataElement.unsigned_integer_32(HID_PARSER_VERSION),
|
||||
),
|
||||
ServiceAttribute(
|
||||
SDP_HID_DEVICE_SUBCLASS_ATTRIBUTE_ID,
|
||||
DataElement.unsigned_integer_32(HID_DEVICE_SUBCLASS),
|
||||
),
|
||||
ServiceAttribute(
|
||||
SDP_HID_COUNTRY_CODE_ATTRIBUTE_ID,
|
||||
DataElement.unsigned_integer_32(HID_COUNTRY_CODE),
|
||||
),
|
||||
ServiceAttribute(
|
||||
SDP_HID_VIRTUAL_CABLE_ATTRIBUTE_ID,
|
||||
DataElement.boolean(HID_VIRTUAL_CABLE),
|
||||
),
|
||||
ServiceAttribute(
|
||||
SDP_HID_RECONNECT_INITIATE_ATTRIBUTE_ID,
|
||||
DataElement.boolean(HID_RECONNECT_INITIATE),
|
||||
),
|
||||
ServiceAttribute(
|
||||
SDP_HID_DESCRIPTOR_LIST_ATTRIBUTE_ID,
|
||||
DataElement.sequence(
|
||||
[
|
||||
DataElement.sequence(
|
||||
[
|
||||
DataElement.unsigned_integer_16(REPORT_DESCRIPTOR_TYPE),
|
||||
DataElement(DataElement.TEXT_STRING, HID_REPORT_MAP),
|
||||
]
|
||||
),
|
||||
]
|
||||
),
|
||||
),
|
||||
ServiceAttribute(
|
||||
SDP_HID_LANGID_BASE_LIST_ATTRIBUTE_ID,
|
||||
DataElement.sequence(
|
||||
[
|
||||
DataElement.sequence(
|
||||
[
|
||||
DataElement.unsigned_integer_16(
|
||||
HID_LANGID_BASE_LANGUAGE
|
||||
),
|
||||
DataElement.unsigned_integer_16(
|
||||
HID_LANGID_BASE_BLUETOOTH_STRING_OFFSET
|
||||
),
|
||||
]
|
||||
),
|
||||
]
|
||||
),
|
||||
),
|
||||
ServiceAttribute(
|
||||
SDP_HID_BATTERY_POWER_ATTRIBUTE_ID,
|
||||
DataElement.boolean(HID_BATTERY_POWER),
|
||||
),
|
||||
ServiceAttribute(
|
||||
SDP_HID_REMOTE_WAKE_ATTRIBUTE_ID,
|
||||
DataElement.boolean(HID_REMOTE_WAKE),
|
||||
),
|
||||
ServiceAttribute(
|
||||
SDP_HID_SUPERVISION_TIMEOUT_ATTRIBUTE_ID,
|
||||
DataElement.unsigned_integer_16(HID_SUPERVISION_TIMEOUT),
|
||||
),
|
||||
ServiceAttribute(
|
||||
SDP_HID_NORMALLY_CONNECTABLE_ATTRIBUTE_ID,
|
||||
DataElement.boolean(HID_NORMALLY_CONNECTABLE),
|
||||
),
|
||||
ServiceAttribute(
|
||||
SDP_HID_BOOT_DEVICE_ATTRIBUTE_ID,
|
||||
DataElement.boolean(HID_BOOT_DEVICE),
|
||||
),
|
||||
ServiceAttribute(
|
||||
SDP_HID_SSR_HOST_MAX_LATENCY_ATTRIBUTE_ID,
|
||||
DataElement.unsigned_integer_16(HID_SSR_HOST_MAX_LATENCY),
|
||||
),
|
||||
ServiceAttribute(
|
||||
SDP_HID_SSR_HOST_MIN_TIMEOUT_ATTRIBUTE_ID,
|
||||
DataElement.unsigned_integer_16(HID_SSR_HOST_MIN_TIMEOUT),
|
||||
),
|
||||
]
|
||||
}
|
||||
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
async def get_stream_reader(pipe) -> asyncio.StreamReader:
|
||||
loop = asyncio.get_event_loop()
|
||||
reader = asyncio.StreamReader(loop=loop)
|
||||
protocol = asyncio.StreamReaderProtocol(reader)
|
||||
await loop.connect_read_pipe(lambda: protocol, pipe)
|
||||
return reader
|
||||
|
||||
|
||||
class DeviceData:
|
||||
def __init__(self) -> None:
|
||||
self.keyboardData = bytearray(
|
||||
[0x01, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00]
|
||||
)
|
||||
self.mouseData = bytearray([0x02, 0x00, 0x00, 0x00])
|
||||
|
||||
|
||||
# Device's live data - Mouse and Keyboard will be stored in this
|
||||
deviceData = DeviceData()
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
async def keyboard_device(hid_device):
|
||||
|
||||
# Start a Websocket server to receive events from a web page
|
||||
async def serve(websocket, _path):
|
||||
global deviceData
|
||||
while True:
|
||||
try:
|
||||
message = await websocket.recv()
|
||||
print('Received: ', str(message))
|
||||
parsed = json.loads(message)
|
||||
message_type = parsed['type']
|
||||
if message_type == 'keydown':
|
||||
# Only deal with keys a to z for now
|
||||
key = parsed['key']
|
||||
if len(key) == 1:
|
||||
code = ord(key)
|
||||
if ord('a') <= code <= ord('z'):
|
||||
hid_code = 0x04 + code - ord('a')
|
||||
deviceData.keyboardData = bytearray(
|
||||
[
|
||||
0x01,
|
||||
0x00,
|
||||
0x00,
|
||||
hid_code,
|
||||
0x00,
|
||||
0x00,
|
||||
0x00,
|
||||
0x00,
|
||||
0x00,
|
||||
]
|
||||
)
|
||||
hid_device.send_data(deviceData.keyboardData)
|
||||
elif message_type == 'keyup':
|
||||
deviceData.keyboardData = bytearray(
|
||||
[0x01, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00]
|
||||
)
|
||||
hid_device.send_data(deviceData.keyboardData)
|
||||
elif message_type == "mousemove":
|
||||
# logical min and max values
|
||||
log_min = -127
|
||||
log_max = 127
|
||||
x = parsed['x']
|
||||
y = parsed['y']
|
||||
# limiting x and y values within logical max and min range
|
||||
x = max(log_min, min(log_max, x))
|
||||
y = max(log_min, min(log_max, y))
|
||||
x_cord = x.to_bytes(signed=True)
|
||||
y_cord = y.to_bytes(signed=True)
|
||||
deviceData.mouseData = bytearray([0x02, 0x00]) + x_cord + y_cord
|
||||
hid_device.send_data(deviceData.mouseData)
|
||||
except websockets.exceptions.ConnectionClosedOK:
|
||||
pass
|
||||
|
||||
# pylint: disable-next=no-member
|
||||
await websockets.serve(serve, 'localhost', 8989)
|
||||
await asyncio.get_event_loop().create_future()
|
||||
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
async def main():
|
||||
if len(sys.argv) < 3:
|
||||
print(
|
||||
'Usage: python run_hid_device.py <device-config> <transport-spec> <command>'
|
||||
' where <command> is one of:\n'
|
||||
' test-mode (run with menu enabled for testing)\n'
|
||||
' web (run a keyboard with keypress input from a web page, '
|
||||
'see keyboard.html'
|
||||
)
|
||||
print('example: python run_hid_device.py hid_keyboard.json usb:0 web')
|
||||
print('example: python run_hid_device.py hid_keyboard.json usb:0 test-mode')
|
||||
|
||||
return
|
||||
|
||||
async def handle_virtual_cable_unplug():
|
||||
hid_host_bd_addr = str(hid_device.remote_device_bd_address)
|
||||
await hid_device.disconnect_interrupt_channel()
|
||||
await hid_device.disconnect_control_channel()
|
||||
await device.keystore.delete(hid_host_bd_addr) # type: ignore
|
||||
connection = hid_device.connection
|
||||
if connection is not None:
|
||||
await connection.disconnect()
|
||||
|
||||
def on_hid_data_cb(pdu: bytes):
|
||||
print(f'Received Data, PDU: {pdu.hex()}')
|
||||
|
||||
def on_get_report_cb(report_id: int, report_type: int, buffer_size: int):
|
||||
retValue = hid_device.GetSetStatus()
|
||||
print(
|
||||
"GET_REPORT report_id: "
|
||||
+ str(report_id)
|
||||
+ "report_type: "
|
||||
+ str(report_type)
|
||||
+ "buffer_size:"
|
||||
+ str(buffer_size)
|
||||
)
|
||||
if report_type == Message.ReportType.INPUT_REPORT:
|
||||
if report_id == 1:
|
||||
retValue.data = deviceData.keyboardData[1:]
|
||||
retValue.status = hid_device.GetSetReturn.SUCCESS
|
||||
elif report_id == 2:
|
||||
retValue.data = deviceData.mouseData[1:]
|
||||
retValue.status = hid_device.GetSetReturn.SUCCESS
|
||||
else:
|
||||
retValue.status = hid_device.GetSetReturn.REPORT_ID_NOT_FOUND
|
||||
|
||||
if buffer_size:
|
||||
data_len = buffer_size - 1
|
||||
retValue.data = retValue.data[:data_len]
|
||||
elif report_type == Message.ReportType.OUTPUT_REPORT:
|
||||
# This sample app has nothing to do with the report received, to enable PTS
|
||||
# testing, we will return single byte random data.
|
||||
retValue.data = bytearray([0x11])
|
||||
retValue.status = hid_device.GetSetReturn.SUCCESS
|
||||
elif report_type == Message.ReportType.FEATURE_REPORT:
|
||||
retValue.status = hid_device.GetSetReturn.ERR_INVALID_PARAMETER
|
||||
elif report_type == Message.ReportType.OTHER_REPORT:
|
||||
if report_id == 3:
|
||||
retValue.status = hid_device.GetSetReturn.REPORT_ID_NOT_FOUND
|
||||
else:
|
||||
retValue.status = hid_device.GetSetReturn.FAILURE
|
||||
|
||||
return retValue
|
||||
|
||||
def on_set_report_cb(
|
||||
report_id: int, report_type: int, report_size: int, data: bytes
|
||||
):
|
||||
retValue = hid_device.GetSetStatus()
|
||||
print(
|
||||
"SET_REPORT report_id: "
|
||||
+ str(report_id)
|
||||
+ "report_type: "
|
||||
+ str(report_type)
|
||||
+ "report_size "
|
||||
+ str(report_size)
|
||||
+ "data:"
|
||||
+ str(data)
|
||||
)
|
||||
if report_type == Message.ReportType.FEATURE_REPORT:
|
||||
retValue.status = hid_device.GetSetReturn.ERR_INVALID_PARAMETER
|
||||
elif report_type == Message.ReportType.INPUT_REPORT:
|
||||
if report_id == 1 and report_size != len(deviceData.keyboardData):
|
||||
retValue.status = hid_device.GetSetReturn.ERR_INVALID_PARAMETER
|
||||
elif report_id == 2 and report_size != len(deviceData.mouseData):
|
||||
retValue.status = hid_device.GetSetReturn.ERR_INVALID_PARAMETER
|
||||
elif report_id == 3:
|
||||
retValue.status = hid_device.GetSetReturn.REPORT_ID_NOT_FOUND
|
||||
else:
|
||||
retValue.status = hid_device.GetSetReturn.SUCCESS
|
||||
else:
|
||||
retValue.status = hid_device.GetSetReturn.SUCCESS
|
||||
|
||||
return retValue
|
||||
|
||||
def on_get_protocol_cb():
|
||||
retValue = hid_device.GetSetStatus()
|
||||
retValue.data = protocol_mode.to_bytes()
|
||||
retValue.status = hid_device.GetSetReturn.SUCCESS
|
||||
return retValue
|
||||
|
||||
def on_set_protocol_cb(protocol: int):
|
||||
retValue = hid_device.GetSetStatus()
|
||||
# We do not support SET_PROTOCOL.
|
||||
print(f"SET_PROTOCOL report_id: {protocol}")
|
||||
retValue.status = hid_device.GetSetReturn.ERR_UNSUPPORTED_REQUEST
|
||||
return retValue
|
||||
|
||||
def on_virtual_cable_unplug_cb():
|
||||
print('Received Virtual Cable Unplug')
|
||||
asyncio.create_task(handle_virtual_cable_unplug())
|
||||
|
||||
print('<<< connecting to HCI...')
|
||||
async with await open_transport_or_link(sys.argv[2]) as (hci_source, hci_sink):
|
||||
print('<<< connected')
|
||||
|
||||
# Create a device
|
||||
device = Device.from_config_file_with_hci(sys.argv[1], hci_source, hci_sink)
|
||||
device.classic_enabled = True
|
||||
|
||||
# Create and register HID device
|
||||
hid_device = HID_Device(device)
|
||||
|
||||
# Register for call backs
|
||||
hid_device.on('interrupt_data', on_hid_data_cb)
|
||||
|
||||
hid_device.register_get_report_cb(on_get_report_cb)
|
||||
hid_device.register_set_report_cb(on_set_report_cb)
|
||||
hid_device.register_get_protocol_cb(on_get_protocol_cb)
|
||||
hid_device.register_set_protocol_cb(on_set_protocol_cb)
|
||||
|
||||
# Register for virtual cable unplug call back
|
||||
hid_device.on('virtual_cable_unplug', on_virtual_cable_unplug_cb)
|
||||
|
||||
# Setup the SDP to advertise HID Device service
|
||||
device.sdp_service_records = sdp_records()
|
||||
|
||||
# Start the controller
|
||||
await device.power_on()
|
||||
|
||||
# Start being discoverable and connectable
|
||||
await device.set_discoverable(True)
|
||||
await device.set_connectable(True)
|
||||
|
||||
async def menu():
|
||||
reader = await get_stream_reader(sys.stdin)
|
||||
while True:
|
||||
print(
|
||||
"\n************************ HID Device Menu *****************************\n"
|
||||
)
|
||||
print(" 1. Connect Control Channel")
|
||||
print(" 2. Connect Interrupt Channel")
|
||||
print(" 3. Disconnect Control Channel")
|
||||
print(" 4. Disconnect Interrupt Channel")
|
||||
print(" 5. Send Report on Interrupt Channel")
|
||||
print(" 6. Virtual Cable Unplug")
|
||||
print(" 7. Disconnect device")
|
||||
print(" 8. Delete Bonding")
|
||||
print(" 9. Re-connect to device")
|
||||
print("10. Exit ")
|
||||
print("\nEnter your choice : \n")
|
||||
|
||||
choice = await reader.readline()
|
||||
choice = choice.decode('utf-8').strip()
|
||||
|
||||
if choice == '1':
|
||||
await hid_device.connect_control_channel()
|
||||
|
||||
elif choice == '2':
|
||||
await hid_device.connect_interrupt_channel()
|
||||
|
||||
elif choice == '3':
|
||||
await hid_device.disconnect_control_channel()
|
||||
|
||||
elif choice == '4':
|
||||
await hid_device.disconnect_interrupt_channel()
|
||||
|
||||
elif choice == '5':
|
||||
print(" 1. Report ID 0x01")
|
||||
print(" 2. Report ID 0x02")
|
||||
print(" 3. Invalid Report ID")
|
||||
|
||||
choice1 = await reader.readline()
|
||||
choice1 = choice1.decode('utf-8').strip()
|
||||
|
||||
if choice1 == '1':
|
||||
data = bytearray(
|
||||
[0x01, 0x00, 0x00, 0x04, 0x00, 0x00, 0x00, 0x00, 0x00]
|
||||
)
|
||||
hid_device.send_data(data)
|
||||
data = bytearray(
|
||||
[0x01, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00]
|
||||
)
|
||||
hid_device.send_data(data)
|
||||
|
||||
elif choice1 == '2':
|
||||
data = bytearray([0x02, 0x00, 0x00, 0xF6])
|
||||
hid_device.send_data(data)
|
||||
data = bytearray([0x02, 0x00, 0x00, 0x00])
|
||||
hid_device.send_data(data)
|
||||
|
||||
elif choice1 == '3':
|
||||
data = bytearray([0x00, 0x00, 0x00, 0x00])
|
||||
hid_device.send_data(data)
|
||||
data = bytearray([0x00, 0x00, 0x00, 0x00])
|
||||
hid_device.send_data(data)
|
||||
|
||||
else:
|
||||
print('Incorrect option selected')
|
||||
|
||||
elif choice == '6':
|
||||
hid_device.virtual_cable_unplug()
|
||||
try:
|
||||
hid_host_bd_addr = str(hid_device.remote_device_bd_address)
|
||||
await device.keystore.delete(hid_host_bd_addr)
|
||||
except KeyError:
|
||||
print('Device not found or Device already unpaired.')
|
||||
|
||||
elif choice == '7':
|
||||
connection = hid_device.connection
|
||||
if connection is not None:
|
||||
await connection.disconnect()
|
||||
else:
|
||||
print("Already disconnected from device")
|
||||
|
||||
elif choice == '8':
|
||||
try:
|
||||
hid_host_bd_addr = str(hid_device.remote_device_bd_address)
|
||||
await device.keystore.delete(hid_host_bd_addr)
|
||||
except KeyError:
|
||||
print('Device NOT found or Device already unpaired.')
|
||||
|
||||
elif choice == '9':
|
||||
hid_host_bd_addr = str(hid_device.remote_device_bd_address)
|
||||
connection = await device.connect(
|
||||
hid_host_bd_addr, transport=BT_BR_EDR_TRANSPORT
|
||||
)
|
||||
await connection.authenticate()
|
||||
await connection.encrypt()
|
||||
|
||||
elif choice == '10':
|
||||
sys.exit("Exit successful")
|
||||
|
||||
else:
|
||||
print("Invalid option selected.")
|
||||
|
||||
if (len(sys.argv) > 3) and (sys.argv[3] == 'test-mode'):
|
||||
# Test mode for PTS/Unit testing
|
||||
await menu()
|
||||
else:
|
||||
# default option is using keyboard.html (web)
|
||||
print("Executing in Web mode")
|
||||
await keyboard_device(hid_device)
|
||||
|
||||
await hci_source.wait_for_termination()
|
||||
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
logging.basicConfig(level=os.environ.get('BUMBLE_LOGLEVEL', 'DEBUG').upper())
|
||||
asyncio.run(main())
|
||||
+51
-20
@@ -285,7 +285,10 @@ async def main():
|
||||
print('example: run_hid_host.py classic1.json usb:0 E1:CA:72:48:C4:E8/P')
|
||||
return
|
||||
|
||||
def on_hid_data_cb(pdu):
|
||||
def on_hid_control_data_cb(pdu: bytes):
|
||||
print(f'Received Control Data, PDU: {pdu.hex()}')
|
||||
|
||||
def on_hid_interrupt_data_cb(pdu: bytes):
|
||||
report_type = pdu[0] & 0x0F
|
||||
if len(pdu) == 1:
|
||||
print(color(f'Warning: No report received', 'yellow'))
|
||||
@@ -305,7 +308,7 @@ async def main():
|
||||
|
||||
if (report_length <= 1) or (report_id == 0):
|
||||
return
|
||||
|
||||
# Parse report over interrupt channel
|
||||
if report_type == Message.ReportType.INPUT_REPORT:
|
||||
ReportParser.parse_input_report(pdu[1:]) # type: ignore
|
||||
|
||||
@@ -313,7 +316,9 @@ async def main():
|
||||
await hid_host.disconnect_interrupt_channel()
|
||||
await hid_host.disconnect_control_channel()
|
||||
await device.keystore.delete(target_address) # type: ignore
|
||||
await connection.disconnect()
|
||||
connection = hid_host.connection
|
||||
if connection is not None:
|
||||
await connection.disconnect()
|
||||
|
||||
def on_hid_virtual_cable_unplug_cb():
|
||||
asyncio.create_task(handle_virtual_cable_unplug())
|
||||
@@ -325,6 +330,18 @@ async def main():
|
||||
# Create a device
|
||||
device = Device.from_config_file_with_hci(sys.argv[1], hci_source, hci_sink)
|
||||
device.classic_enabled = True
|
||||
|
||||
# Create HID host and start it
|
||||
print('@@@ Starting HID Host...')
|
||||
hid_host = Host(device)
|
||||
|
||||
# Register for HID data call back
|
||||
hid_host.on('interrupt_data', on_hid_interrupt_data_cb)
|
||||
hid_host.on('control_data', on_hid_control_data_cb)
|
||||
|
||||
# Register for virtual cable unplug call back
|
||||
hid_host.on('virtual_cable_unplug', on_hid_virtual_cable_unplug_cb)
|
||||
|
||||
await device.power_on()
|
||||
|
||||
# Connect to a peer
|
||||
@@ -345,16 +362,6 @@ async def main():
|
||||
|
||||
await get_hid_device_sdp_record(connection)
|
||||
|
||||
# Create HID host and start it
|
||||
print('@@@ Starting HID Host...')
|
||||
hid_host = Host(device, connection)
|
||||
|
||||
# Register for HID data call back
|
||||
hid_host.on('data', on_hid_data_cb)
|
||||
|
||||
# Register for virtual cable unplug call back
|
||||
hid_host.on('virtual_cable_unplug', on_hid_virtual_cable_unplug_cb)
|
||||
|
||||
async def menu():
|
||||
reader = await get_stream_reader(sys.stdin)
|
||||
while True:
|
||||
@@ -369,13 +376,14 @@ async def main():
|
||||
print(" 6. Set Report")
|
||||
print(" 7. Set Protocol Mode")
|
||||
print(" 8. Get Protocol Mode")
|
||||
print(" 9. Send Report")
|
||||
print(" 9. Send Report on Interrupt Channel")
|
||||
print("10. Suspend")
|
||||
print("11. Exit Suspend")
|
||||
print("12. Virtual Cable Unplug")
|
||||
print("13. Disconnect device")
|
||||
print("14. Delete Bonding")
|
||||
print("15. Re-connect to device")
|
||||
print("16. Exit")
|
||||
print("\nEnter your choice : \n")
|
||||
|
||||
choice = await reader.readline()
|
||||
@@ -394,21 +402,40 @@ async def main():
|
||||
await hid_host.disconnect_interrupt_channel()
|
||||
|
||||
elif choice == '5':
|
||||
print(" 1. Report ID 0x02")
|
||||
print(" 2. Report ID 0x03")
|
||||
print(" 3. Report ID 0x05")
|
||||
print(" 1. Input Report with ID 0x01")
|
||||
print(" 2. Input Report with ID 0x02")
|
||||
print(" 3. Input Report with ID 0x0F - Invalid ReportId")
|
||||
print(" 4. Output Report with ID 0x02")
|
||||
print(" 5. Feature Report with ID 0x05 - Unsupported Request")
|
||||
print(" 6. Input Report with ID 0x02, BufferSize 3")
|
||||
print(" 7. Output Report with ID 0x03, BufferSize 2")
|
||||
print(" 8. Feature Report with ID 0x05, BufferSize 3")
|
||||
choice1 = await reader.readline()
|
||||
choice1 = choice1.decode('utf-8').strip()
|
||||
|
||||
if choice1 == '1':
|
||||
hid_host.get_report(1, 2, 3)
|
||||
hid_host.get_report(1, 1, 0)
|
||||
|
||||
elif choice1 == '2':
|
||||
hid_host.get_report(2, 3, 2)
|
||||
hid_host.get_report(1, 2, 0)
|
||||
|
||||
elif choice1 == '3':
|
||||
hid_host.get_report(3, 5, 3)
|
||||
hid_host.get_report(1, 5, 0)
|
||||
|
||||
elif choice1 == '4':
|
||||
hid_host.get_report(2, 2, 0)
|
||||
|
||||
elif choice1 == '5':
|
||||
hid_host.get_report(3, 15, 0)
|
||||
|
||||
elif choice1 == '6':
|
||||
hid_host.get_report(1, 2, 3)
|
||||
|
||||
elif choice1 == '7':
|
||||
hid_host.get_report(2, 3, 2)
|
||||
|
||||
elif choice1 == '8':
|
||||
hid_host.get_report(3, 5, 3)
|
||||
else:
|
||||
print('Incorrect option selected')
|
||||
|
||||
@@ -484,6 +511,7 @@ async def main():
|
||||
hid_host.virtual_cable_unplug()
|
||||
try:
|
||||
await device.keystore.delete(target_address)
|
||||
print("Unpair successful")
|
||||
except KeyError:
|
||||
print('Device not found or Device already unpaired.')
|
||||
|
||||
@@ -513,6 +541,9 @@ async def main():
|
||||
await connection.authenticate()
|
||||
await connection.encrypt()
|
||||
|
||||
elif choice == '16':
|
||||
sys.exit("Exit successful")
|
||||
|
||||
else:
|
||||
print("Invalid option selected.")
|
||||
|
||||
|
||||
@@ -19,12 +19,15 @@ import asyncio
|
||||
import logging
|
||||
import sys
|
||||
import os
|
||||
import struct
|
||||
import secrets
|
||||
from bumble.core import AdvertisingData
|
||||
from bumble.device import Device
|
||||
from bumble.device import Device, CisLink
|
||||
from bumble.hci import (
|
||||
CodecID,
|
||||
CodingFormat,
|
||||
OwnAddressType,
|
||||
HCI_IsoDataPacket,
|
||||
HCI_LE_Set_Extended_Advertising_Parameters_Command,
|
||||
)
|
||||
from bumble.profiles.bap import (
|
||||
@@ -35,7 +38,10 @@ from bumble.profiles.bap import (
|
||||
SupportedFrameDuration,
|
||||
PacRecord,
|
||||
PublishedAudioCapabilitiesService,
|
||||
AudioStreamControlService,
|
||||
)
|
||||
from bumble.profiles.cap import CommonAudioServiceService
|
||||
from bumble.profiles.csip import CoordinatedSetIdentificationService, SirkType
|
||||
|
||||
from bumble.transport import open_transport_or_link
|
||||
|
||||
@@ -57,6 +63,11 @@ async def main() -> None:
|
||||
|
||||
await device.power_on()
|
||||
|
||||
csis = CoordinatedSetIdentificationService(
|
||||
set_identity_resolving_key=secrets.token_bytes(16),
|
||||
set_identity_resolving_key_type=SirkType.PLAINTEXT,
|
||||
)
|
||||
device.add_service(CommonAudioServiceService(csis))
|
||||
device.add_service(
|
||||
PublishedAudioCapabilitiesService(
|
||||
supported_source_context=ContextType.PROHIBITED,
|
||||
@@ -103,21 +114,71 @@ async def main() -> None:
|
||||
)
|
||||
)
|
||||
|
||||
advertising_data = bytes(
|
||||
AdvertisingData(
|
||||
[
|
||||
(
|
||||
AdvertisingData.COMPLETE_LOCAL_NAME,
|
||||
bytes('Bumble LE Audio', 'utf-8'),
|
||||
),
|
||||
(
|
||||
AdvertisingData.INCOMPLETE_LIST_OF_16_BIT_SERVICE_CLASS_UUIDS,
|
||||
bytes(PublishedAudioCapabilitiesService.UUID),
|
||||
),
|
||||
]
|
||||
device.add_service(AudioStreamControlService(device, sink_ase_id=[1, 2]))
|
||||
|
||||
advertising_data = (
|
||||
bytes(
|
||||
AdvertisingData(
|
||||
[
|
||||
(
|
||||
AdvertisingData.COMPLETE_LOCAL_NAME,
|
||||
bytes('Bumble LE Audio', 'utf-8'),
|
||||
),
|
||||
(
|
||||
AdvertisingData.FLAGS,
|
||||
bytes(
|
||||
[
|
||||
AdvertisingData.LE_GENERAL_DISCOVERABLE_MODE_FLAG
|
||||
| AdvertisingData.BR_EDR_HOST_FLAG
|
||||
| AdvertisingData.BR_EDR_CONTROLLER_FLAG
|
||||
]
|
||||
),
|
||||
),
|
||||
(
|
||||
AdvertisingData.INCOMPLETE_LIST_OF_16_BIT_SERVICE_CLASS_UUIDS,
|
||||
bytes(PublishedAudioCapabilitiesService.UUID),
|
||||
),
|
||||
]
|
||||
)
|
||||
)
|
||||
+ csis.get_advertising_data()
|
||||
)
|
||||
subprocess = await asyncio.create_subprocess_shell(
|
||||
f'dlc3 | ffplay pipe:0',
|
||||
stdin=asyncio.subprocess.PIPE,
|
||||
stdout=asyncio.subprocess.PIPE,
|
||||
stderr=asyncio.subprocess.PIPE,
|
||||
)
|
||||
|
||||
stdin = subprocess.stdin
|
||||
assert stdin
|
||||
|
||||
# Write a fake LC3 header to dlc3.
|
||||
stdin.write(
|
||||
bytes([0x1C, 0xCC]) # Header.
|
||||
+ struct.pack(
|
||||
'<HHHHHHI',
|
||||
18, # Header length.
|
||||
24000 // 100, # Sampling Rate(/100Hz).
|
||||
0, # Bitrate(unused).
|
||||
1, # Channels.
|
||||
10000 // 10, # Frame duration(/10us).
|
||||
0, # RFU.
|
||||
0x0FFFFFFF, # Frame counts.
|
||||
)
|
||||
)
|
||||
|
||||
def on_pdu(pdu: HCI_IsoDataPacket):
|
||||
# LC3 format: |frame_length(2)| + |frame(length)|.
|
||||
if pdu.iso_sdu_length:
|
||||
stdin.write(struct.pack('<H', pdu.iso_sdu_length))
|
||||
stdin.write(pdu.iso_sdu_fragment)
|
||||
|
||||
def on_cis(cis_link: CisLink):
|
||||
cis_link.on('pdu', on_pdu)
|
||||
|
||||
device.once('cis_establishment', on_cis)
|
||||
|
||||
await device.start_extended_advertising(
|
||||
advertising_properties=(
|
||||
HCI_LE_Set_Extended_Advertising_Parameters_Command.AdvertisingProperties.CONNECTABLE_ADVERTISING
|
||||
|
||||
+13
-7
@@ -28,8 +28,8 @@ private val Log = Logger.getLogger("btbench.l2cap-client")
|
||||
|
||||
class L2capClient(
|
||||
private val viewModel: AppViewModel,
|
||||
val bluetoothAdapter: BluetoothAdapter,
|
||||
val context: Context
|
||||
private val bluetoothAdapter: BluetoothAdapter,
|
||||
private val context: Context
|
||||
) {
|
||||
@SuppressLint("MissingPermission")
|
||||
fun run() {
|
||||
@@ -74,12 +74,18 @@ class L2capClient(
|
||||
gatt: BluetoothGatt?, status: Int, newState: Int
|
||||
) {
|
||||
if (gatt != null && newState == BluetoothProfile.STATE_CONNECTED) {
|
||||
gatt.setPreferredPhy(
|
||||
BluetoothDevice.PHY_LE_2M_MASK,
|
||||
BluetoothDevice.PHY_LE_2M_MASK,
|
||||
BluetoothDevice.PHY_OPTION_NO_PREFERRED
|
||||
)
|
||||
if (viewModel.use2mPhy) {
|
||||
gatt.setPreferredPhy(
|
||||
BluetoothDevice.PHY_LE_2M_MASK,
|
||||
BluetoothDevice.PHY_LE_2M_MASK,
|
||||
BluetoothDevice.PHY_OPTION_NO_PREFERRED
|
||||
)
|
||||
}
|
||||
gatt.readPhy()
|
||||
|
||||
// Request an MTU update, even though we don't use GATT, because Android
|
||||
// won't request a larger link layer maximum data length otherwise.
|
||||
gatt.requestMtu(517)
|
||||
}
|
||||
}
|
||||
},
|
||||
|
||||
+34
-33
@@ -23,19 +23,20 @@ import androidx.compose.runtime.setValue
|
||||
import androidx.lifecycle.ViewModel
|
||||
import java.util.UUID
|
||||
|
||||
val DEFAULT_RFCOMM_UUID = UUID.fromString("E6D55659-C8B4-4B85-96BB-B1143AF6D3AE")
|
||||
val DEFAULT_RFCOMM_UUID: UUID = UUID.fromString("E6D55659-C8B4-4B85-96BB-B1143AF6D3AE")
|
||||
const val DEFAULT_PEER_BLUETOOTH_ADDRESS = "AA:BB:CC:DD:EE:FF"
|
||||
const val DEFAULT_SENDER_PACKET_COUNT = 100
|
||||
const val DEFAULT_SENDER_PACKET_SIZE = 1024
|
||||
const val DEFAULT_PSM = 128
|
||||
|
||||
class AppViewModel : ViewModel() {
|
||||
private var preferences: SharedPreferences? = null
|
||||
var peerBluetoothAddress by mutableStateOf(DEFAULT_PEER_BLUETOOTH_ADDRESS)
|
||||
var l2capPsm by mutableStateOf(0)
|
||||
var l2capPsm by mutableIntStateOf(DEFAULT_PSM)
|
||||
var use2mPhy by mutableStateOf(true)
|
||||
var mtu by mutableStateOf(0)
|
||||
var rxPhy by mutableStateOf(0)
|
||||
var txPhy by mutableStateOf(0)
|
||||
var mtu by mutableIntStateOf(0)
|
||||
var rxPhy by mutableIntStateOf(0)
|
||||
var txPhy by mutableIntStateOf(0)
|
||||
var senderPacketCountSlider by mutableFloatStateOf(0.0F)
|
||||
var senderPacketSizeSlider by mutableFloatStateOf(0.0F)
|
||||
var senderPacketCount by mutableIntStateOf(DEFAULT_SENDER_PACKET_COUNT)
|
||||
@@ -79,18 +80,18 @@ class AppViewModel : ViewModel() {
|
||||
}
|
||||
|
||||
fun updateSenderPacketCountSlider() {
|
||||
if (senderPacketCount <= 10) {
|
||||
senderPacketCountSlider = 0.0F
|
||||
senderPacketCountSlider = if (senderPacketCount <= 10) {
|
||||
0.0F
|
||||
} else if (senderPacketCount <= 50) {
|
||||
senderPacketCountSlider = 0.2F
|
||||
0.2F
|
||||
} else if (senderPacketCount <= 100) {
|
||||
senderPacketCountSlider = 0.4F
|
||||
0.4F
|
||||
} else if (senderPacketCount <= 500) {
|
||||
senderPacketCountSlider = 0.6F
|
||||
0.6F
|
||||
} else if (senderPacketCount <= 1000) {
|
||||
senderPacketCountSlider = 0.8F
|
||||
0.8F
|
||||
} else {
|
||||
senderPacketCountSlider = 1.0F
|
||||
1.0F
|
||||
}
|
||||
|
||||
with(preferences!!.edit()) {
|
||||
@@ -100,18 +101,18 @@ class AppViewModel : ViewModel() {
|
||||
}
|
||||
|
||||
fun updateSenderPacketCount() {
|
||||
if (senderPacketCountSlider < 0.1F) {
|
||||
senderPacketCount = 10
|
||||
senderPacketCount = if (senderPacketCountSlider < 0.1F) {
|
||||
10
|
||||
} else if (senderPacketCountSlider < 0.3F) {
|
||||
senderPacketCount = 50
|
||||
50
|
||||
} else if (senderPacketCountSlider < 0.5F) {
|
||||
senderPacketCount = 100
|
||||
100
|
||||
} else if (senderPacketCountSlider < 0.7F) {
|
||||
senderPacketCount = 500
|
||||
500
|
||||
} else if (senderPacketCountSlider < 0.9F) {
|
||||
senderPacketCount = 1000
|
||||
1000
|
||||
} else {
|
||||
senderPacketCount = 10000
|
||||
10000
|
||||
}
|
||||
|
||||
with(preferences!!.edit()) {
|
||||
@@ -121,18 +122,18 @@ class AppViewModel : ViewModel() {
|
||||
}
|
||||
|
||||
fun updateSenderPacketSizeSlider() {
|
||||
if (senderPacketSize <= 16) {
|
||||
senderPacketSizeSlider = 0.0F
|
||||
senderPacketSizeSlider = if (senderPacketSize <= 16) {
|
||||
0.0F
|
||||
} else if (senderPacketSize <= 256) {
|
||||
senderPacketSizeSlider = 0.02F
|
||||
0.02F
|
||||
} else if (senderPacketSize <= 512) {
|
||||
senderPacketSizeSlider = 0.4F
|
||||
0.4F
|
||||
} else if (senderPacketSize <= 1024) {
|
||||
senderPacketSizeSlider = 0.6F
|
||||
0.6F
|
||||
} else if (senderPacketSize <= 2048) {
|
||||
senderPacketSizeSlider = 0.8F
|
||||
0.8F
|
||||
} else {
|
||||
senderPacketSizeSlider = 1.0F
|
||||
1.0F
|
||||
}
|
||||
|
||||
with(preferences!!.edit()) {
|
||||
@@ -142,18 +143,18 @@ class AppViewModel : ViewModel() {
|
||||
}
|
||||
|
||||
fun updateSenderPacketSize() {
|
||||
if (senderPacketSizeSlider < 0.1F) {
|
||||
senderPacketSize = 16
|
||||
senderPacketSize = if (senderPacketSizeSlider < 0.1F) {
|
||||
16
|
||||
} else if (senderPacketSizeSlider < 0.3F) {
|
||||
senderPacketSize = 256
|
||||
256
|
||||
} else if (senderPacketSizeSlider < 0.5F) {
|
||||
senderPacketSize = 512
|
||||
512
|
||||
} else if (senderPacketSizeSlider < 0.7F) {
|
||||
senderPacketSize = 1024
|
||||
1024
|
||||
} else if (senderPacketSizeSlider < 0.9F) {
|
||||
senderPacketSize = 2048
|
||||
2048
|
||||
} else {
|
||||
senderPacketSize = 4096
|
||||
4096
|
||||
}
|
||||
|
||||
with(preferences!!.edit()) {
|
||||
|
||||
+1
@@ -42,6 +42,7 @@ public class HciServer {
|
||||
try (ServerSocket serverSocket = new ServerSocket(mPort)) {
|
||||
mListener.onMessage("Waiting for connection on port " + serverSocket.getLocalPort());
|
||||
try (Socket clientSocket = serverSocket.accept()) {
|
||||
clientSocket.setTcpNoDelay(true);
|
||||
mListener.onHostConnectionState(true);
|
||||
mListener.onMessage("Connected");
|
||||
HciParser parser = new HciParser(mListener);
|
||||
|
||||
@@ -56,6 +56,7 @@ install_requires =
|
||||
|
||||
[options.entry_points]
|
||||
console_scripts =
|
||||
bumble-ble-rpa-tool = bumble.apps.ble_rpa_tool:main
|
||||
bumble-console = bumble.apps.console:main
|
||||
bumble-controller-info = bumble.apps.controller_info:main
|
||||
bumble-gatt-dump = bumble.apps.gatt_dump:main
|
||||
|
||||
@@ -17,6 +17,7 @@
|
||||
# -----------------------------------------------------------------------------
|
||||
import asyncio
|
||||
import os
|
||||
import functools
|
||||
import pytest
|
||||
import logging
|
||||
|
||||
@@ -24,11 +25,26 @@ from bumble import device
|
||||
from bumble.hci import CodecID, CodingFormat
|
||||
from bumble.profiles.bap import (
|
||||
AudioLocation,
|
||||
AseStateMachine,
|
||||
ASE_Operation,
|
||||
ASE_Config_Codec,
|
||||
ASE_Config_QOS,
|
||||
ASE_Disable,
|
||||
ASE_Enable,
|
||||
ASE_Receiver_Start_Ready,
|
||||
ASE_Receiver_Stop_Ready,
|
||||
ASE_Release,
|
||||
ASE_Update_Metadata,
|
||||
SupportedFrameDuration,
|
||||
SupportedSamplingFrequency,
|
||||
SamplingFrequency,
|
||||
FrameDuration,
|
||||
CodecSpecificCapabilities,
|
||||
CodecSpecificConfiguration,
|
||||
ContextType,
|
||||
PacRecord,
|
||||
AudioStreamControlService,
|
||||
AudioStreamControlServiceProxy,
|
||||
PublishedAudioCapabilitiesService,
|
||||
PublishedAudioCapabilitiesServiceProxy,
|
||||
)
|
||||
@@ -40,6 +56,13 @@ from .test_utils import TwoDevices
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
def basic_check(operation: ASE_Operation):
|
||||
serialized = bytes(operation)
|
||||
parsed = ASE_Operation.from_bytes(serialized)
|
||||
assert bytes(parsed) == serialized
|
||||
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
def test_codec_specific_capabilities() -> None:
|
||||
SAMPLE_FREQUENCY = SupportedSamplingFrequency.FREQ_16000
|
||||
@@ -85,6 +108,92 @@ def test_vendor_specific_pac_record() -> None:
|
||||
assert bytes(PacRecord.from_bytes(RAW_DATA)) == RAW_DATA
|
||||
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
def test_ASE_Config_Codec() -> None:
|
||||
operation = ASE_Config_Codec(
|
||||
ase_id=[1, 2],
|
||||
target_latency=[3, 4],
|
||||
target_phy=[5, 6],
|
||||
codec_id=[CodingFormat(CodecID.LC3), CodingFormat(CodecID.LC3)],
|
||||
codec_specific_configuration=[b'foo', b'bar'],
|
||||
)
|
||||
basic_check(operation)
|
||||
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
def test_ASE_Config_QOS() -> None:
|
||||
operation = ASE_Config_QOS(
|
||||
ase_id=[1, 2],
|
||||
cig_id=[1, 2],
|
||||
cis_id=[3, 4],
|
||||
sdu_interval=[5, 6],
|
||||
framing=[0, 1],
|
||||
phy=[2, 3],
|
||||
max_sdu=[4, 5],
|
||||
retransmission_number=[6, 7],
|
||||
max_transport_latency=[8, 9],
|
||||
presentation_delay=[10, 11],
|
||||
)
|
||||
basic_check(operation)
|
||||
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
def test_ASE_Enable() -> None:
|
||||
operation = ASE_Enable(
|
||||
ase_id=[1, 2],
|
||||
metadata=[b'foo', b'bar'],
|
||||
)
|
||||
basic_check(operation)
|
||||
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
def test_ASE_Update_Metadata() -> None:
|
||||
operation = ASE_Update_Metadata(
|
||||
ase_id=[1, 2],
|
||||
metadata=[b'foo', b'bar'],
|
||||
)
|
||||
basic_check(operation)
|
||||
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
def test_ASE_Disable() -> None:
|
||||
operation = ASE_Disable(ase_id=[1, 2])
|
||||
basic_check(operation)
|
||||
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
def test_ASE_Release() -> None:
|
||||
operation = ASE_Release(ase_id=[1, 2])
|
||||
basic_check(operation)
|
||||
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
def test_ASE_Receiver_Start_Ready() -> None:
|
||||
operation = ASE_Receiver_Start_Ready(ase_id=[1, 2])
|
||||
basic_check(operation)
|
||||
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
def test_ASE_Receiver_Stop_Ready() -> None:
|
||||
operation = ASE_Receiver_Stop_Ready(ase_id=[1, 2])
|
||||
basic_check(operation)
|
||||
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
def test_codec_specific_configuration() -> None:
|
||||
SAMPLE_FREQUENCY = SamplingFrequency.FREQ_16000
|
||||
FRAME_SURATION = FrameDuration.DURATION_10000_US
|
||||
AUDIO_LOCATION = AudioLocation.FRONT_LEFT
|
||||
config = CodecSpecificConfiguration(
|
||||
sampling_frequency=SAMPLE_FREQUENCY,
|
||||
frame_duration=FRAME_SURATION,
|
||||
audio_channel_allocation=AUDIO_LOCATION,
|
||||
octets_per_codec_frame=60,
|
||||
codec_frames_per_sdu=1,
|
||||
)
|
||||
assert CodecSpecificConfiguration.from_bytes(bytes(config)) == config
|
||||
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
@pytest.mark.asyncio
|
||||
async def test_pacs():
|
||||
@@ -140,6 +249,148 @@ async def test_pacs():
|
||||
)
|
||||
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
@pytest.mark.asyncio
|
||||
async def test_ascs():
|
||||
devices = TwoDevices()
|
||||
devices[0].add_service(
|
||||
AudioStreamControlService(device=devices[0], sink_ase_id=[1, 2])
|
||||
)
|
||||
|
||||
await devices.setup_connection()
|
||||
peer = device.Peer(devices.connections[1])
|
||||
ascs_client = await peer.discover_service_and_create_proxy(
|
||||
AudioStreamControlServiceProxy
|
||||
)
|
||||
|
||||
notifications = {1: asyncio.Queue(), 2: asyncio.Queue()}
|
||||
|
||||
def on_notification(data: bytes, ase_id: int):
|
||||
notifications[ase_id].put_nowait(data)
|
||||
|
||||
# Should be idle
|
||||
assert await ascs_client.sink_ase[0].read_value() == bytes(
|
||||
[1, AseStateMachine.State.IDLE]
|
||||
)
|
||||
assert await ascs_client.sink_ase[1].read_value() == bytes(
|
||||
[2, AseStateMachine.State.IDLE]
|
||||
)
|
||||
|
||||
# Subscribe
|
||||
await ascs_client.sink_ase[0].subscribe(
|
||||
functools.partial(on_notification, ase_id=1)
|
||||
)
|
||||
await ascs_client.sink_ase[1].subscribe(
|
||||
functools.partial(on_notification, ase_id=2)
|
||||
)
|
||||
|
||||
# Config Codec
|
||||
config = CodecSpecificConfiguration(
|
||||
sampling_frequency=SamplingFrequency.FREQ_48000,
|
||||
frame_duration=FrameDuration.DURATION_10000_US,
|
||||
audio_channel_allocation=AudioLocation.FRONT_LEFT,
|
||||
octets_per_codec_frame=120,
|
||||
codec_frames_per_sdu=1,
|
||||
)
|
||||
await ascs_client.ase_control_point.write_value(
|
||||
ASE_Config_Codec(
|
||||
ase_id=[1, 2],
|
||||
target_latency=[3, 4],
|
||||
target_phy=[5, 6],
|
||||
codec_id=[CodingFormat(CodecID.LC3), CodingFormat(CodecID.LC3)],
|
||||
codec_specific_configuration=[config, config],
|
||||
)
|
||||
)
|
||||
assert (await notifications[1].get())[:2] == bytes(
|
||||
[1, AseStateMachine.State.CODEC_CONFIGURED]
|
||||
)
|
||||
assert (await notifications[2].get())[:2] == bytes(
|
||||
[2, AseStateMachine.State.CODEC_CONFIGURED]
|
||||
)
|
||||
|
||||
# Config QOS
|
||||
await ascs_client.ase_control_point.write_value(
|
||||
ASE_Config_QOS(
|
||||
ase_id=[1, 2],
|
||||
cig_id=[1, 2],
|
||||
cis_id=[3, 4],
|
||||
sdu_interval=[5, 6],
|
||||
framing=[0, 1],
|
||||
phy=[2, 3],
|
||||
max_sdu=[4, 5],
|
||||
retransmission_number=[6, 7],
|
||||
max_transport_latency=[8, 9],
|
||||
presentation_delay=[10, 11],
|
||||
)
|
||||
)
|
||||
assert (await notifications[1].get())[:2] == bytes(
|
||||
[1, AseStateMachine.State.QOS_CONFIGURED]
|
||||
)
|
||||
assert (await notifications[2].get())[:2] == bytes(
|
||||
[2, AseStateMachine.State.QOS_CONFIGURED]
|
||||
)
|
||||
|
||||
# Enable
|
||||
await ascs_client.ase_control_point.write_value(
|
||||
ASE_Enable(
|
||||
ase_id=[1, 2],
|
||||
metadata=[b'foo', b'bar'],
|
||||
)
|
||||
)
|
||||
assert (await notifications[1].get())[:2] == bytes(
|
||||
[1, AseStateMachine.State.ENABLING]
|
||||
)
|
||||
assert (await notifications[2].get())[:2] == bytes(
|
||||
[2, AseStateMachine.State.ENABLING]
|
||||
)
|
||||
|
||||
# CIS establishment
|
||||
devices[0].emit(
|
||||
'cis_establishment',
|
||||
device.CisLink(
|
||||
device=devices[0],
|
||||
acl_connection=devices.connections[0],
|
||||
handle=5,
|
||||
cis_id=3,
|
||||
cig_id=1,
|
||||
),
|
||||
)
|
||||
devices[0].emit(
|
||||
'cis_establishment',
|
||||
device.CisLink(
|
||||
device=devices[0],
|
||||
acl_connection=devices.connections[0],
|
||||
handle=6,
|
||||
cis_id=4,
|
||||
cig_id=2,
|
||||
),
|
||||
)
|
||||
assert (await notifications[1].get())[:2] == bytes(
|
||||
[1, AseStateMachine.State.STREAMING]
|
||||
)
|
||||
assert (await notifications[2].get())[:2] == bytes(
|
||||
[2, AseStateMachine.State.STREAMING]
|
||||
)
|
||||
|
||||
# Release
|
||||
await ascs_client.ase_control_point.write_value(
|
||||
ASE_Release(
|
||||
ase_id=[1, 2],
|
||||
metadata=[b'foo', b'bar'],
|
||||
)
|
||||
)
|
||||
assert (await notifications[1].get())[:2] == bytes(
|
||||
[1, AseStateMachine.State.RELEASING]
|
||||
)
|
||||
assert (await notifications[2].get())[:2] == bytes(
|
||||
[2, AseStateMachine.State.RELEASING]
|
||||
)
|
||||
assert (await notifications[1].get())[:2] == bytes([1, AseStateMachine.State.IDLE])
|
||||
assert (await notifications[2].get())[:2] == bytes([2, AseStateMachine.State.IDLE])
|
||||
|
||||
await asyncio.sleep(0.001)
|
||||
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
async def run():
|
||||
await test_pacs()
|
||||
|
||||
@@ -0,0 +1,71 @@
|
||||
# Copyright 2021-2023 Google LLC
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# https://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
# Imports
|
||||
# -----------------------------------------------------------------------------
|
||||
import asyncio
|
||||
import os
|
||||
import pytest
|
||||
import logging
|
||||
|
||||
from bumble import device
|
||||
from bumble import gatt
|
||||
from bumble.profiles import cap
|
||||
from bumble.profiles import csip
|
||||
from .test_utils import TwoDevices
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
# Logging
|
||||
# -----------------------------------------------------------------------------
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
@pytest.mark.asyncio
|
||||
async def test_cas():
|
||||
SIRK = bytes.fromhex('2f62c8ae41867d1bb619e788a2605faa')
|
||||
|
||||
devices = TwoDevices()
|
||||
devices[0].add_service(
|
||||
cap.CommonAudioServiceService(
|
||||
csip.CoordinatedSetIdentificationService(
|
||||
set_identity_resolving_key=SIRK,
|
||||
set_identity_resolving_key_type=csip.SirkType.PLAINTEXT,
|
||||
)
|
||||
)
|
||||
)
|
||||
|
||||
await devices.setup_connection()
|
||||
peer = device.Peer(devices.connections[1])
|
||||
cas_client = await peer.discover_service_and_create_proxy(
|
||||
cap.CommonAudioServiceServiceProxy
|
||||
)
|
||||
|
||||
included_services = await peer.discover_included_services(cas_client.service_proxy)
|
||||
assert any(
|
||||
service.uuid == gatt.GATT_COORDINATED_SET_IDENTIFICATION_SERVICE
|
||||
for service in included_services
|
||||
)
|
||||
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
async def run():
|
||||
await test_cas()
|
||||
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
if __name__ == '__main__':
|
||||
logging.basicConfig(level=os.environ.get('BUMBLE_LOGLEVEL', 'INFO').upper())
|
||||
asyncio.run(run())
|
||||
@@ -31,6 +31,41 @@ from .test_utils import TwoDevices
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
def test_s1():
|
||||
assert (
|
||||
csip.s1(b'SIRKenc'[::-1])
|
||||
== bytes.fromhex('6901983f 18149e82 3c7d133a 7d774572')[::-1]
|
||||
)
|
||||
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
def test_k1():
|
||||
K = bytes.fromhex('676e1b9b d448696f 061ec622 3ce5ced9')[::-1]
|
||||
SALT = csip.s1(b'SIRKenc'[::-1])
|
||||
P = b'csis'[::-1]
|
||||
assert (
|
||||
csip.k1(K, SALT, P)
|
||||
== bytes.fromhex('5277453c c094d982 b0e8ee53 2f2d1f8b')[::-1]
|
||||
)
|
||||
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
def test_sih():
|
||||
SIRK = bytes.fromhex('457d7d09 21a1fd22 cecd8c86 dd72cccd')[::-1]
|
||||
PRAND = bytes.fromhex('69f563')[::-1]
|
||||
assert csip.sih(SIRK, PRAND) == bytes.fromhex('1948da')[::-1]
|
||||
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
def test_sef():
|
||||
SIRK = bytes.fromhex('457d7d09 21a1fd22 cecd8c86 dd72cccd')[::-1]
|
||||
K = bytes.fromhex('676e1b9b d448696f 061ec622 3ce5ced9')[::-1]
|
||||
assert (
|
||||
csip.sef(K, SIRK) == bytes.fromhex('170a3835 e13524a0 7e2562d5 f25fd346')[::-1]
|
||||
)
|
||||
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
@pytest.mark.asyncio
|
||||
async def test_csis():
|
||||
@@ -40,6 +75,7 @@ async def test_csis():
|
||||
devices[0].add_service(
|
||||
csip.CoordinatedSetIdentificationService(
|
||||
set_identity_resolving_key=SIRK,
|
||||
set_identity_resolving_key_type=csip.SirkType.PLAINTEXT,
|
||||
coordinated_set_size=2,
|
||||
set_member_lock=csip.MemberLock.UNLOCKED,
|
||||
set_member_rank=0,
|
||||
@@ -65,6 +101,7 @@ async def test_csis():
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
async def run():
|
||||
test_sih()
|
||||
await test_csis()
|
||||
|
||||
|
||||
|
||||
+182
-2
@@ -20,16 +20,23 @@ import logging
|
||||
import os
|
||||
from types import LambdaType
|
||||
import pytest
|
||||
from unittest import mock
|
||||
|
||||
from bumble.core import BT_BR_EDR_TRANSPORT
|
||||
from bumble.core import (
|
||||
BT_BR_EDR_TRANSPORT,
|
||||
BT_LE_TRANSPORT,
|
||||
BT_PERIPHERAL_ROLE,
|
||||
ConnectionParameters,
|
||||
)
|
||||
from bumble.device import Connection, Device
|
||||
from bumble.host import Host
|
||||
from bumble.host import AclPacketQueue, Host
|
||||
from bumble.hci import (
|
||||
HCI_ACCEPT_CONNECTION_REQUEST_COMMAND,
|
||||
HCI_COMMAND_STATUS_PENDING,
|
||||
HCI_CREATE_CONNECTION_COMMAND,
|
||||
HCI_SUCCESS,
|
||||
Address,
|
||||
OwnAddressType,
|
||||
HCI_Command_Complete_Event,
|
||||
HCI_Command_Status_Event,
|
||||
HCI_Connection_Complete_Event,
|
||||
@@ -66,6 +73,13 @@ async def test_device_connect_parallel():
|
||||
d1 = Device(host=Host(None, None))
|
||||
d2 = Device(host=Host(None, None))
|
||||
|
||||
def _send(packet):
|
||||
pass
|
||||
|
||||
d0.host.acl_packet_queue = AclPacketQueue(0, 0, _send)
|
||||
d1.host.acl_packet_queue = AclPacketQueue(0, 0, _send)
|
||||
d2.host.acl_packet_queue = AclPacketQueue(0, 0, _send)
|
||||
|
||||
# enable classic
|
||||
d0.classic_enabled = True
|
||||
d1.classic_enabled = True
|
||||
@@ -232,6 +246,172 @@ async def test_flush():
|
||||
pass
|
||||
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
@pytest.mark.asyncio
|
||||
async def test_legacy_advertising():
|
||||
device = Device(host=mock.AsyncMock(Host))
|
||||
|
||||
# Start advertising
|
||||
advertiser = await device.start_legacy_advertising()
|
||||
assert device.legacy_advertiser
|
||||
|
||||
# Stop advertising
|
||||
await advertiser.stop()
|
||||
assert not device.legacy_advertiser
|
||||
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
@pytest.mark.parametrize(
|
||||
'own_address_type,',
|
||||
(OwnAddressType.PUBLIC, OwnAddressType.RANDOM),
|
||||
)
|
||||
@pytest.mark.asyncio
|
||||
async def test_legacy_advertising_connection(own_address_type):
|
||||
device = Device(host=mock.AsyncMock(Host))
|
||||
peer_address = Address('F0:F1:F2:F3:F4:F5')
|
||||
|
||||
# Start advertising
|
||||
advertiser = await device.start_legacy_advertising()
|
||||
device.on_connection(
|
||||
0x0001,
|
||||
BT_LE_TRANSPORT,
|
||||
peer_address,
|
||||
BT_PERIPHERAL_ROLE,
|
||||
ConnectionParameters(0, 0, 0),
|
||||
)
|
||||
|
||||
if own_address_type == OwnAddressType.PUBLIC:
|
||||
assert device.lookup_connection(0x0001).self_address == device.public_address
|
||||
else:
|
||||
assert device.lookup_connection(0x0001).self_address == device.random_address
|
||||
|
||||
# For unknown reason, read_phy() in on_connection() would be killed at the end of
|
||||
# test, so we force scheduling here to avoid an warning.
|
||||
await asyncio.sleep(0.0001)
|
||||
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
@pytest.mark.parametrize(
|
||||
'auto_restart,',
|
||||
(True, False),
|
||||
)
|
||||
@pytest.mark.asyncio
|
||||
async def test_legacy_advertising_disconnection(auto_restart):
|
||||
device = Device(host=mock.AsyncMock(spec=Host))
|
||||
peer_address = Address('F0:F1:F2:F3:F4:F5')
|
||||
advertiser = await device.start_legacy_advertising(auto_restart=auto_restart)
|
||||
device.on_connection(
|
||||
0x0001,
|
||||
BT_LE_TRANSPORT,
|
||||
peer_address,
|
||||
BT_PERIPHERAL_ROLE,
|
||||
ConnectionParameters(0, 0, 0),
|
||||
)
|
||||
|
||||
device.start_legacy_advertising = mock.AsyncMock()
|
||||
|
||||
device.on_disconnection(0x0001, 0)
|
||||
|
||||
if auto_restart:
|
||||
device.start_legacy_advertising.assert_called_with(
|
||||
advertising_type=advertiser.advertising_type,
|
||||
own_address_type=advertiser.own_address_type,
|
||||
auto_restart=advertiser.auto_restart,
|
||||
advertising_data=advertiser.advertising_data,
|
||||
scan_response_data=advertiser.scan_response_data,
|
||||
)
|
||||
else:
|
||||
device.start_legacy_advertising.assert_not_called()
|
||||
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
@pytest.mark.asyncio
|
||||
async def test_extended_advertising():
|
||||
device = Device(host=mock.AsyncMock(Host))
|
||||
|
||||
# Start advertising
|
||||
advertiser = await device.start_extended_advertising()
|
||||
assert device.extended_advertisers
|
||||
|
||||
# Stop advertising
|
||||
await advertiser.stop()
|
||||
assert not device.extended_advertisers
|
||||
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
@pytest.mark.parametrize(
|
||||
'own_address_type,',
|
||||
(OwnAddressType.PUBLIC, OwnAddressType.RANDOM),
|
||||
)
|
||||
@pytest.mark.asyncio
|
||||
async def test_extended_advertising_connection(own_address_type):
|
||||
device = Device(host=mock.AsyncMock(spec=Host))
|
||||
peer_address = Address('F0:F1:F2:F3:F4:F5')
|
||||
advertiser = await device.start_extended_advertising(
|
||||
own_address_type=own_address_type
|
||||
)
|
||||
device.on_connection(
|
||||
0x0001,
|
||||
BT_LE_TRANSPORT,
|
||||
peer_address,
|
||||
BT_PERIPHERAL_ROLE,
|
||||
ConnectionParameters(0, 0, 0),
|
||||
)
|
||||
device.on_advertising_set_termination(
|
||||
HCI_SUCCESS,
|
||||
advertiser.handle,
|
||||
0x0001,
|
||||
)
|
||||
|
||||
if own_address_type == OwnAddressType.PUBLIC:
|
||||
assert device.lookup_connection(0x0001).self_address == device.public_address
|
||||
else:
|
||||
assert device.lookup_connection(0x0001).self_address == device.random_address
|
||||
|
||||
# For unknown reason, read_phy() in on_connection() would be killed at the end of
|
||||
# test, so we force scheduling here to avoid an warning.
|
||||
await asyncio.sleep(0.0001)
|
||||
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
@pytest.mark.parametrize(
|
||||
'auto_restart,',
|
||||
(True, False),
|
||||
)
|
||||
@pytest.mark.asyncio
|
||||
async def test_extended_advertising_disconnection(auto_restart):
|
||||
device = Device(host=mock.AsyncMock(spec=Host))
|
||||
peer_address = Address('F0:F1:F2:F3:F4:F5')
|
||||
advertiser = await device.start_extended_advertising(auto_restart=auto_restart)
|
||||
device.on_connection(
|
||||
0x0001,
|
||||
BT_LE_TRANSPORT,
|
||||
peer_address,
|
||||
BT_PERIPHERAL_ROLE,
|
||||
ConnectionParameters(0, 0, 0),
|
||||
)
|
||||
device.on_advertising_set_termination(
|
||||
HCI_SUCCESS,
|
||||
advertiser.handle,
|
||||
0x0001,
|
||||
)
|
||||
|
||||
device.start_extended_advertising = mock.AsyncMock()
|
||||
|
||||
device.on_disconnection(0x0001, 0)
|
||||
|
||||
if auto_restart:
|
||||
device.start_extended_advertising.assert_called_with(
|
||||
advertising_properties=advertiser.advertising_properties,
|
||||
own_address_type=advertiser.own_address_type,
|
||||
auto_restart=advertiser.auto_restart,
|
||||
advertising_data=advertiser.advertising_data,
|
||||
scan_response_data=advertiser.scan_response_data,
|
||||
)
|
||||
else:
|
||||
device.start_extended_advertising.assert_not_called()
|
||||
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
def test_gatt_services_with_gas():
|
||||
device = Device(host=Host(None, None))
|
||||
|
||||
Reference in New Issue
Block a user