diff --git a/bumble/device.py b/bumble/device.py index b2b535d9..8d1d1019 100644 --- a/bumble/device.py +++ b/bumble/device.py @@ -18,6 +18,7 @@ import json import asyncio import logging +from contextlib import asynccontextmanager, AsyncExitStack from .hci import * from .host import Host @@ -148,10 +149,24 @@ class Peer: await service.discover_characteristics() return self.create_service_proxy(proxy_class) + async def sustain(self, timeout=None): + await self.connection.sustain(timeout) + # [Classic only] async def request_name(self): return await self.connection.request_remote_name() + async def __aenter__(self): + await self.discover_services() + for service in self.services: + await self.discover_characteristics() + + return self + + async def __aexit__(self, exc_type, exc_value, traceback): + pass + + def __str__(self): return f'{self.connection.peer_address} as {self.connection.role_name}' @@ -232,6 +247,21 @@ class Connection(CompositeEventEmitter): async def encrypt(self): return await self.device.encrypt(self) + async def sustain(self, timeout=None): + """ Idles the current task waiting for a disconnect or timeout """ + + abort = asyncio.get_running_loop().create_future() + self.on('disconnection', abort.set_result) + self.on('disconnection_failure', abort.set_exception) + + try: + await asyncio.wait_for(abort, timeout) + except asyncio.TimeoutError: + pass + + self.remove_listener('disconnection', abort.set_result) + self.remove_listener('disconnection_failure', abort.set_exception) + async def update_parameters( self, conn_interval_min, @@ -251,6 +281,18 @@ class Connection(CompositeEventEmitter): async def request_remote_name(self): return await self.device.request_remote_name(self) + async def __aenter__(self): + return self + + async def __aexit__(self, exc_type, exc_value, traceback): + if exc_type is None: + try: + await self.disconnect() + except HCI_StatusError as e: + # Invalid parameter means the connection is no longer valid + if e.error_code != HCI_INVALID_HCI_COMMAND_PARAMETERS_ERROR: + raise + def __str__(self): return f'Connection(handle=0x{self.handle:04X}, role={self.role_name}, address={self.peer_address})' @@ -705,7 +747,7 @@ class Device(CompositeEventEmitter): )) if response.status != HCI_Command_Status_Event.PENDING: self.discovering = False - raise RuntimeError(f'HCI_Inquiry command failed: {HCI_Constant.status_name(response.status)} ({response.status})') + raise HCI_StatusError(response) self.discovering = True @@ -825,16 +867,25 @@ class Device(CompositeEventEmitter): try: if result.status != HCI_Command_Status_Event.PENDING: - raise RuntimeError(f'HCI_LE_Create_Connection_Command failed: {HCI_Constant.status_name(result.status)} ({result.status})') + raise HCI_StatusError(result) # Wait for the connection process to complete self.connecting = True return await pending_connection + finally: self.remove_listener('connection', pending_connection.set_result) self.remove_listener('connection_failure', pending_connection.set_exception) self.connecting = False + @asynccontextmanager + async def connect_as_gatt(self, peer_address): + async with AsyncExitStack() as stack: + connection = await stack.enter_async_context(await self.connect(peer_address)) + peer = await stack.enter_async_context(Peer(connection)) + + yield peer + @property def is_connecting(self): return self.connecting @@ -859,7 +910,7 @@ class Device(CompositeEventEmitter): try: if result.status != HCI_Command_Status_Event.PENDING: - raise RuntimeError(f'HCI_Disconnect_Command failed: {HCI_Constant.status_name(result.status)} ({result.status})') + raise HCI_StatusError(result) # Wait for the disconnection process to complete self.disconnecting = True @@ -1011,7 +1062,7 @@ class Device(CompositeEventEmitter): ) if result.status != HCI_COMMAND_STATUS_PENDING: logger.warn(f'HCI_Authentication_Requested_Command failed: {HCI_Constant.error_name(result.status)}') - raise HCI_Error(result.status) + raise HCI_StatusError(result) # Wait for the authentication to complete await pending_authentication @@ -1068,7 +1119,7 @@ class Device(CompositeEventEmitter): if result.status != HCI_COMMAND_STATUS_PENDING: logger.warn(f'HCI_LE_Enable_Encryption_Command failed: {HCI_Constant.error_name(result.status)}') - raise HCI_Error(result.status) + raise HCI_StatusError(result) else: result = await self.send_command( HCI_Set_Connection_Encryption_Command( @@ -1079,7 +1130,7 @@ class Device(CompositeEventEmitter): if result.status != HCI_COMMAND_STATUS_PENDING: logger.warn(f'HCI_Set_Connection_Encryption_Command failed: {HCI_Constant.error_name(result.status)}') - raise HCI_Error(result.status) + raise HCI_StatusError(result) # Wait for the result await pending_encryption @@ -1113,7 +1164,7 @@ class Device(CompositeEventEmitter): if result.status != HCI_COMMAND_STATUS_PENDING: logger.warn(f'HCI_Set_Connection_Encryption_Command failed: {HCI_Constant.error_name(result.status)}') - raise HCI_Error(result.status) + raise HCI_StatusError(result) # Wait for the result return await pending_name diff --git a/bumble/hci.py b/bumble/hci.py index 529eeb1a..bbac8430 100644 --- a/bumble/hci.py +++ b/bumble/hci.py @@ -1373,6 +1373,13 @@ class HCI_Error(ProtocolError): super().__init__(error_code, 'hci', HCI_Constant.error_name(error_code)) +class HCI_StatusError(ProtocolError): + def __init__(self, response): + super().__init__(response.status, + error_namespace=HCI_Command.command_name(response.command_opcode), + error_name=HCI_Constant.status_name(response.status)) + + # ----------------------------------------------------------------------------- # Generic HCI object # ----------------------------------------------------------------------------- diff --git a/examples/battery_client.py b/examples/battery_client.py index 297e9f42..f545f129 100644 --- a/examples/battery_client.py +++ b/examples/battery_client.py @@ -43,28 +43,24 @@ async def main(): # Connect to the peer target_address = sys.argv[2] print(f'=== Connecting to {target_address}...') - connection = await device.connect(target_address) - print(f'=== Connected to {connection}') + async with device.connect_as_gatt(target_address) as peer: + print(f'=== Connected to {peer}') + battery_service = peer.create_service_proxy(BatteryServiceProxy) - # Discover the Battery Service - peer = Peer(connection) - print('=== Discovering Battery Service') - battery_service = await peer.discover_service_and_create_proxy(BatteryServiceProxy) + # Check that the service was found + if not battery_service: + print('!!! Service not found') + return - # Check that the service was found - if not battery_service: - print('!!! Service not found') - return + # Subscribe to and read the battery level + if battery_service.battery_level: + await battery_service.battery_level.subscribe( + lambda value: print(f'{color("Battery Level Update:", "green")} {value}') + ) + value = await battery_service.battery_level.read_value() + print(f'{color("Initial Battery Level:", "green")} {value}') - # Subscribe to and read the battery level - if battery_service.battery_level: - await battery_service.battery_level.subscribe( - lambda value: print(f'{color("Battery Level Update:", "green")} {value}') - ) - value = await battery_service.battery_level.read_value() - print(f'{color("Initial Battery Level:", "green")} {value}') - - await hci_source.wait_for_termination() + await peer.sustain() # ----------------------------------------------------------------------------- diff --git a/examples/heart_rate_client.py b/examples/heart_rate_client.py index c58dd983..130aac0c 100644 --- a/examples/heart_rate_client.py +++ b/examples/heart_rate_client.py @@ -43,31 +43,28 @@ async def main(): # Connect to the peer target_address = sys.argv[2] print(f'=== Connecting to {target_address}...') - connection = await device.connect(target_address) - print(f'=== Connected to {connection}') + async with device.connect_as_gatt(target_address) as peer: + print(f'=== Connected to {peer}') - # Discover the Heart Rate Service - peer = Peer(connection) - print('=== Discovering Heart Rate Service') - heart_rate_service = await peer.discover_service_and_create_proxy(HeartRateServiceProxy) + heart_rate_service = peer.create_service_proxy(HeartRateServiceProxy) - # Check that the service was found - if not heart_rate_service: - print('!!! Service not found') - return + # Check that the service was found + if not heart_rate_service: + print('!!! Service not found') + return - # Read the body sensor location - if heart_rate_service.body_sensor_location: - location = await heart_rate_service.body_sensor_location.read_value() - print(color('Sensor Location:', 'green'), location) + # Read the body sensor location + if heart_rate_service.body_sensor_location: + location = await heart_rate_service.body_sensor_location.read_value() + print(color('Sensor Location:', 'green'), location) - # Subscribe to the heart rate measurement - if heart_rate_service.heart_rate_measurement: - await heart_rate_service.heart_rate_measurement.subscribe( - lambda value: print(f'{color("Heart Rate Measurement:", "green")} {value}') - ) + # Subscribe to the heart rate measurement + if heart_rate_service.heart_rate_measurement: + await heart_rate_service.heart_rate_measurement.subscribe( + lambda value: print(f'{color("Heart Rate Measurement:", "green")} {value}') + ) - await hci_source.wait_for_termination() + await peer.sustain() # -----------------------------------------------------------------------------