forked from auracaster/bumble_mirror
Compare commits
49 Commits
gbg/fix-cl
...
v0.0.134
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
297246fa4c | ||
|
|
52db1cfcc1 | ||
|
|
29f9a79502 | ||
|
|
c86125de4f | ||
|
|
697d5df3f8 | ||
|
|
87aa4f617e | ||
|
|
a8eff737e6 | ||
|
|
4417eb636c | ||
|
|
f4e5e61bbb | ||
|
|
ba7a60025f | ||
|
|
d92b7e9b74 | ||
|
|
b0336adf1c | ||
|
|
691450c7de | ||
|
|
99a0eb21c1 | ||
|
|
ab4859bd94 | ||
|
|
0d70cbde64 | ||
|
|
f41d0682b2 | ||
|
|
062dc1e53d | ||
|
|
662704e551 | ||
|
|
02a474c44e | ||
|
|
a1c7aec492 | ||
|
|
6112f00049 | ||
|
|
f56ac14f2c | ||
|
|
a739fc71ce | ||
|
|
b89f9030a0 | ||
|
|
9e5a85bd10 | ||
|
|
b437bd8619 | ||
|
|
a3e4674819 | ||
|
|
5f1d57fcb0 | ||
|
|
ae0b739e4a | ||
|
|
0570b59796 | ||
|
|
22218627f6 | ||
|
|
1c72242264 | ||
|
|
9c133706e6 | ||
|
|
4988a31487 | ||
|
|
e6c062117f | ||
|
|
f2133235d5 | ||
|
|
867e8c13dc | ||
|
|
25ce38c3f5 | ||
|
|
c0810230a6 | ||
|
|
27c46eef9d | ||
|
|
c140876157 | ||
|
|
d743656f09 | ||
|
|
b91d0e24c1 | ||
|
|
eb46f60c87 | ||
|
|
8bbba7c84c | ||
|
|
ee54df2d08 | ||
|
|
6549e53398 | ||
|
|
6e1baf0344 |
16
.github/workflows/python-build-test.yml
vendored
16
.github/workflows/python-build-test.yml
vendored
@@ -14,6 +14,10 @@ jobs:
|
||||
build:
|
||||
|
||||
runs-on: ubuntu-latest
|
||||
strategy:
|
||||
matrix:
|
||||
python-version: ["3.8", "3.9", "3.10"]
|
||||
fail-fast: false
|
||||
|
||||
steps:
|
||||
- name: Check out from Git
|
||||
@@ -21,18 +25,18 @@ jobs:
|
||||
- name: Get history and tags for SCM versioning to work
|
||||
run: |
|
||||
git fetch --prune --unshallow
|
||||
git fetch --depth=1 origin +refs/tags/*:refs/tags/*
|
||||
- name: Set up Python 3.10
|
||||
uses: actions/setup-python@v3
|
||||
git fetch --depth=1 origin +refs/tags/*:refs/tags/*
|
||||
- name: Set up Python ${{ matrix.python-version }}
|
||||
uses: actions/setup-python@v4
|
||||
with:
|
||||
python-version: "3.10"
|
||||
python-version: ${{ matrix.python-version }}
|
||||
- name: Install dependencies
|
||||
run: |
|
||||
python -m pip install --upgrade pip
|
||||
python -m pip install ".[build,test,development,documentation]"
|
||||
- name: Test with pytest
|
||||
- name: Test
|
||||
run: |
|
||||
pytest
|
||||
invoke test
|
||||
- name: Build
|
||||
run: |
|
||||
inv build
|
||||
|
||||
5
.gitignore
vendored
5
.gitignore
vendored
@@ -3,9 +3,6 @@ build/
|
||||
dist/
|
||||
*.egg-info/
|
||||
*~
|
||||
bumble/__pycache__
|
||||
docs/mkdocs/site
|
||||
tests/__pycache__
|
||||
test-results.xml
|
||||
bumble/transport/__pycache__
|
||||
bumble/profiles/__pycache__
|
||||
__pycache__
|
||||
|
||||
@@ -443,7 +443,10 @@ class ConsoleApp:
|
||||
# Discover all services, characteristics and descriptors
|
||||
self.append_to_output('discovering services...')
|
||||
await self.connected_peer.discover_services()
|
||||
self.append_to_output(f'found {len(self.connected_peer.services)} services, discovering charateristics...')
|
||||
self.append_to_output(
|
||||
f'found {len(self.connected_peer.services)} services,'
|
||||
' discovering characteristics...'
|
||||
)
|
||||
await self.connected_peer.discover_characteristics()
|
||||
self.append_to_output('found characteristics, discovering descriptors...')
|
||||
for service in self.connected_peer.services:
|
||||
@@ -902,7 +905,7 @@ class LogHandler(logging.Handler):
|
||||
def __init__(self, app):
|
||||
super().__init__()
|
||||
self.app = app
|
||||
self.setFormatter("[%(asctime)s][%(pathname)s:%(lineno)d][%(levelname)s] %(message)s")
|
||||
self.setFormatter(logging.Formatter('[%(asctime)s][%(levelname)s] %(message)s'))
|
||||
|
||||
def emit(self, record):
|
||||
message = self.format(record)
|
||||
|
||||
@@ -58,6 +58,12 @@ def padded_bytes(buffer, size):
|
||||
return buffer + bytes(padding_size)
|
||||
|
||||
|
||||
def get_dict_key_by_value(dictionary, value):
|
||||
for key, val in dictionary.items():
|
||||
if val == value:
|
||||
return key
|
||||
return None
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
# Exceptions
|
||||
# -----------------------------------------------------------------------------
|
||||
@@ -135,7 +141,7 @@ class UUID:
|
||||
else:
|
||||
uuid_str = uuid_str_or_int
|
||||
if len(uuid_str) != 32 and len(uuid_str) != 8 and len(uuid_str) != 4:
|
||||
raise ValueError('invalid UUID format')
|
||||
raise ValueError(f"invalid UUID format: {uuid_str}")
|
||||
self.uuid_bytes = bytes(reversed(bytes.fromhex(uuid_str)))
|
||||
self.name = name
|
||||
|
||||
|
||||
247
bumble/device.py
247
bumble/device.py
@@ -294,8 +294,8 @@ class Peer:
|
||||
async def discover_attributes(self):
|
||||
return await self.gatt_client.discover_attributes()
|
||||
|
||||
async def subscribe(self, characteristic, subscriber=None):
|
||||
return await self.gatt_client.subscribe(characteristic, subscriber)
|
||||
async def subscribe(self, characteristic, subscriber=None, prefer_notify=True):
|
||||
return await self.gatt_client.subscribe(characteristic, subscriber, prefer_notify)
|
||||
|
||||
async def unsubscribe(self, characteristic, subscriber=None):
|
||||
return await self.gatt_client.unsubscribe(characteristic, subscriber)
|
||||
@@ -394,6 +394,7 @@ class Connection(CompositeEventEmitter):
|
||||
device,
|
||||
handle,
|
||||
transport,
|
||||
self_address,
|
||||
peer_address,
|
||||
peer_resolvable_address,
|
||||
role,
|
||||
@@ -404,6 +405,7 @@ class Connection(CompositeEventEmitter):
|
||||
self.device = device
|
||||
self.handle = handle
|
||||
self.transport = transport
|
||||
self.self_address = self_address
|
||||
self.peer_address = peer_address
|
||||
self.peer_resolvable_address = peer_resolvable_address
|
||||
self.peer_name = None # Classic only
|
||||
@@ -411,12 +413,36 @@ class Connection(CompositeEventEmitter):
|
||||
self.parameters = parameters
|
||||
self.encryption = 0
|
||||
self.authenticated = False
|
||||
self.sc = False
|
||||
self.link_key_type = None
|
||||
self.authenticating = False
|
||||
self.phy = phy
|
||||
self.att_mtu = ATT_DEFAULT_MTU
|
||||
self.data_length = DEVICE_DEFAULT_DATA_LENGTH
|
||||
self.gatt_client = None # Per-connection client
|
||||
self.gatt_server = device.gatt_server # By default, use the device's shared server
|
||||
|
||||
# [Classic only]
|
||||
@classmethod
|
||||
def incomplete(cls, device, peer_address):
|
||||
"""
|
||||
Instantiate an incomplete connection (ie. one waiting for a HCI Connection Complete event).
|
||||
Once received it shall be completed using the `.complete` method.
|
||||
"""
|
||||
return cls(device, None, BT_BR_EDR_TRANSPORT, device.public_address, peer_address, None, None, None, None)
|
||||
|
||||
# [Classic only]
|
||||
def complete(self, handle, peer_resolvable_address, role, parameters):
|
||||
"""
|
||||
Finish an incomplete connection upon completion.
|
||||
"""
|
||||
assert self.handle is None
|
||||
assert self.transport == BT_BR_EDR_TRANSPORT
|
||||
self.handle = handle
|
||||
self.peer_resolvable_address = peer_resolvable_address
|
||||
self.role = role
|
||||
self.parameters = parameters
|
||||
|
||||
@property
|
||||
def role_name(self):
|
||||
return 'CENTRAL' if self.role == BT_CENTRAL_ROLE else 'PERIPHERAL'
|
||||
@@ -538,6 +564,7 @@ class DeviceConfiguration:
|
||||
)
|
||||
self.irk = bytes(16) # This really must be changed for any level of security
|
||||
self.keystore = None
|
||||
self.gatt_services = []
|
||||
|
||||
def load_from_dict(self, config):
|
||||
# Load simple properties
|
||||
@@ -554,6 +581,7 @@ class DeviceConfiguration:
|
||||
self.classic_accept_any = config.get('classic_accept_any', self.classic_accept_any)
|
||||
self.connectable = config.get('connectable', self.connectable)
|
||||
self.discoverable = config.get('discoverable', self.discoverable)
|
||||
self.gatt_services = config.get('gatt_services', self.gatt_services)
|
||||
|
||||
# Load or synthesize an IRK
|
||||
irk = config.get('irk')
|
||||
@@ -587,7 +615,7 @@ def with_connection_from_handle(function):
|
||||
@functools.wraps(function)
|
||||
def wrapper(self, connection_handle, *args, **kwargs):
|
||||
if (connection := self.lookup_connection(connection_handle)) is None:
|
||||
raise ValueError('no connection for handle')
|
||||
raise ValueError(f"no connection for handle: 0x{connection_handle:04x}")
|
||||
return function(self, connection, *args, **kwargs)
|
||||
return wrapper
|
||||
|
||||
@@ -596,6 +624,8 @@ def with_connection_from_handle(function):
|
||||
def with_connection_from_address(function):
|
||||
@functools.wraps(function)
|
||||
def wrapper(self, address, *args, **kwargs):
|
||||
if (connection := self.pending_connections.get(address, False)):
|
||||
return function(self, connection, *args, **kwargs)
|
||||
for connection in self.connections.values():
|
||||
if connection.peer_address == address:
|
||||
return function(self, connection, *args, **kwargs)
|
||||
@@ -607,6 +637,8 @@ def with_connection_from_address(function):
|
||||
def try_with_connection_from_address(function):
|
||||
@functools.wraps(function)
|
||||
def wrapper(self, address, *args, **kwargs):
|
||||
if (connection := self.pending_connections.get(address, False)):
|
||||
return function(self, connection, address, *args, **kwargs)
|
||||
for connection in self.connections.values():
|
||||
if connection.peer_address == address:
|
||||
return function(self, connection, address, *args, **kwargs)
|
||||
@@ -694,11 +726,16 @@ class Device(CompositeEventEmitter):
|
||||
self.le_connecting = False
|
||||
self.disconnecting = False
|
||||
self.connections = {} # Connections, by connection handle
|
||||
self.pending_connections = {} # Connections, by BD address (BR/EDR only)
|
||||
self.classic_enabled = False
|
||||
self.inquiry_response = None
|
||||
self.address_resolver = None
|
||||
self.classic_pending_accepts = {Address.ANY: []} # Futures, by BD address OR [Futures] for Address.ANY
|
||||
|
||||
# Own address type cache
|
||||
self.advertising_own_address_type = None
|
||||
self.connect_own_address_type = None
|
||||
|
||||
# Use the initial config or a default
|
||||
self.public_address = Address('00:00:00:00:00:00')
|
||||
if config is None:
|
||||
@@ -720,6 +757,26 @@ class Device(CompositeEventEmitter):
|
||||
self.connectable = config.connectable
|
||||
self.classic_accept_any = config.classic_accept_any
|
||||
|
||||
for service in config.gatt_services:
|
||||
characteristics = []
|
||||
for characteristic in service.get("characteristics", []):
|
||||
descriptors = []
|
||||
for descriptor in characteristic.get("descriptors", []):
|
||||
new_descriptor = Descriptor(
|
||||
descriptor_type=descriptor["descriptor_type"],
|
||||
permissions=descriptor["permission"],
|
||||
)
|
||||
descriptors.append(new_descriptor)
|
||||
new_characteristic = Characteristic(
|
||||
uuid=characteristic["uuid"],
|
||||
properties=characteristic["properties"],
|
||||
permissions=int(characteristic["permissions"], 0),
|
||||
descriptors=descriptors,
|
||||
)
|
||||
characteristics.append(new_characteristic)
|
||||
new_service = Service(uuid=service["uuid"], characteristics=characteristics)
|
||||
self.gatt_server.add_service(new_service)
|
||||
|
||||
# If a name is passed, override the name from the config
|
||||
if name:
|
||||
self.name = name
|
||||
@@ -731,8 +788,7 @@ class Device(CompositeEventEmitter):
|
||||
self.random_address = address
|
||||
|
||||
# Setup SMP
|
||||
# TODO: allow using a public address
|
||||
self.smp_manager = smp.Manager(self, self.random_address)
|
||||
self.smp_manager = smp.Manager(self)
|
||||
self.l2cap_channel_manager.register_fixed_channel(
|
||||
smp.SMP_CID, self.on_smp_pdu)
|
||||
self.l2cap_channel_manager.register_fixed_channel(
|
||||
@@ -741,9 +797,7 @@ class Device(CompositeEventEmitter):
|
||||
# Register the SDP server with the L2CAP Channel Manager
|
||||
self.sdp_server.register(self.l2cap_channel_manager)
|
||||
|
||||
# Add a GAP Service if requested
|
||||
if generic_access_service:
|
||||
self.gatt_server.add_service(GenericAccessService(self.name))
|
||||
self.add_default_services(generic_access_service)
|
||||
self.l2cap_channel_manager.register_fixed_channel(ATT_CID, self.on_gatt_pdu)
|
||||
|
||||
# Forward some events
|
||||
@@ -791,7 +845,7 @@ class Device(CompositeEventEmitter):
|
||||
|
||||
def find_connection_by_bd_addr(self, bd_addr, transport=None, check_address_type=False):
|
||||
for connection in self.connections.values():
|
||||
if connection.peer_address.get_bytes() == bd_addr.get_bytes():
|
||||
if connection.peer_address.to_bytes() == bd_addr.to_bytes():
|
||||
if check_address_type and connection.peer_address.address_type != bd_addr.address_type:
|
||||
continue
|
||||
if transport is None or connection.transport == transport:
|
||||
@@ -928,7 +982,7 @@ class Device(CompositeEventEmitter):
|
||||
self,
|
||||
advertising_type=AdvertisingType.UNDIRECTED_CONNECTABLE_SCANNABLE,
|
||||
target=None,
|
||||
own_address_type=Address.RANDOM_DEVICE_ADDRESS,
|
||||
own_address_type=OwnAddressType.RANDOM,
|
||||
auto_restart=False
|
||||
):
|
||||
# If we're advertising, stop first
|
||||
@@ -975,9 +1029,10 @@ class Device(CompositeEventEmitter):
|
||||
advertising_enable = 1
|
||||
), check_result=True)
|
||||
|
||||
self.auto_restart_advertising = auto_restart
|
||||
self.advertising_type = advertising_type
|
||||
self.advertising = True
|
||||
self.advertising_own_address_type = own_address_type
|
||||
self.auto_restart_advertising = auto_restart
|
||||
self.advertising_type = advertising_type
|
||||
self.advertising = True
|
||||
|
||||
async def stop_advertising(self):
|
||||
# Disable advertising
|
||||
@@ -986,9 +1041,10 @@ class Device(CompositeEventEmitter):
|
||||
advertising_enable = 0
|
||||
), check_result=True)
|
||||
|
||||
self.advertising = False
|
||||
self.advertising_type = None
|
||||
self.auto_restart_advertising = False
|
||||
self.advertising_own_address_type = None
|
||||
self.advertising = False
|
||||
self.advertising_type = None
|
||||
self.auto_restart_advertising = False
|
||||
|
||||
@property
|
||||
def is_advertising(self):
|
||||
@@ -1000,7 +1056,7 @@ class Device(CompositeEventEmitter):
|
||||
active=True,
|
||||
scan_interval=DEVICE_DEFAULT_SCAN_INTERVAL, # Scan interval in ms
|
||||
scan_window=DEVICE_DEFAULT_SCAN_WINDOW, # Scan window in ms
|
||||
own_address_type=Address.RANDOM_DEVICE_ADDRESS,
|
||||
own_address_type=OwnAddressType.RANDOM,
|
||||
filter_duplicates=False,
|
||||
scanning_phys=(HCI_LE_1M_PHY, HCI_LE_CODED_PHY)
|
||||
):
|
||||
@@ -1181,7 +1237,7 @@ class Device(CompositeEventEmitter):
|
||||
peer_address,
|
||||
transport=BT_LE_TRANSPORT,
|
||||
connection_parameters_preferences=None,
|
||||
own_address_type=Address.RANDOM_DEVICE_ADDRESS,
|
||||
own_address_type=OwnAddressType.RANDOM,
|
||||
timeout=DEVICE_DEFAULT_CONNECT_TIMEOUT
|
||||
):
|
||||
'''
|
||||
@@ -1251,6 +1307,8 @@ class Device(CompositeEventEmitter):
|
||||
HCI_LE_CODED_PHY: ConnectionParametersPreferences.default
|
||||
}
|
||||
|
||||
self.connect_own_address_type = own_address_type
|
||||
|
||||
if self.host.supports_command(HCI_LE_EXTENDED_CREATE_CONNECTION_COMMAND):
|
||||
# Only keep supported PHYs
|
||||
phys = sorted(list(set(filter(self.supports_le_phy, connection_parameters_preferences.keys()))))
|
||||
@@ -1314,6 +1372,9 @@ class Device(CompositeEventEmitter):
|
||||
max_ce_length = int(prefs.max_ce_length / 0.625),
|
||||
))
|
||||
else:
|
||||
# Save pending connection
|
||||
self.pending_connections[peer_address] = Connection.incomplete(self, peer_address)
|
||||
|
||||
# TODO: allow passing other settings
|
||||
result = await self.send_command(HCI_Create_Connection_Command(
|
||||
bd_addr = peer_address,
|
||||
@@ -1350,6 +1411,9 @@ class Device(CompositeEventEmitter):
|
||||
self.remove_listener('connection_failure', on_connection_failure)
|
||||
if transport == BT_LE_TRANSPORT:
|
||||
self.le_connecting = False
|
||||
self.connect_own_address_type = None
|
||||
else:
|
||||
self.pending_connections.pop(peer_address, None)
|
||||
|
||||
async def accept(
|
||||
self,
|
||||
@@ -1363,7 +1427,7 @@ class Device(CompositeEventEmitter):
|
||||
Notes:
|
||||
* A `connect` to the same peer will also complete this call.
|
||||
* The `timeout` parameter is only handled while waiting for the connection request,
|
||||
once received and accepeted, the controller shall issue a connection complete event.
|
||||
once received and accepted, the controller shall issue a connection complete event.
|
||||
'''
|
||||
|
||||
if type(peer_address) is str:
|
||||
@@ -1419,6 +1483,9 @@ class Device(CompositeEventEmitter):
|
||||
self.on('connection', on_connection)
|
||||
self.on('connection_failure', on_connection_failure)
|
||||
|
||||
# Save pending connection
|
||||
self.pending_connections[peer_address] = Connection.incomplete(self, peer_address)
|
||||
|
||||
try:
|
||||
# Accept connection request
|
||||
await self.send_command(HCI_Accept_Connection_Request_Command(
|
||||
@@ -1432,6 +1499,7 @@ class Device(CompositeEventEmitter):
|
||||
finally:
|
||||
self.remove_listener('connection', on_connection)
|
||||
self.remove_listener('connection_failure', on_connection_failure)
|
||||
self.pending_connections.pop(peer_address, None)
|
||||
|
||||
@asynccontextmanager
|
||||
async def connect_as_gatt(self, peer_address):
|
||||
@@ -1675,9 +1743,13 @@ class Device(CompositeEventEmitter):
|
||||
logger.warn(f'HCI_Authentication_Requested_Command failed: {HCI_Constant.error_name(result.status)}')
|
||||
raise HCI_StatusError(result)
|
||||
|
||||
# Save in connection we are trying to authenticate
|
||||
connection.authenticating = True
|
||||
|
||||
# Wait for the authentication to complete
|
||||
await pending_authentication
|
||||
finally:
|
||||
connection.authenticating = False
|
||||
connection.remove_listener('connection_authentication', on_authentication)
|
||||
connection.remove_listener('connection_authentication_failure', on_authentication_failure)
|
||||
|
||||
@@ -1754,28 +1826,18 @@ class Device(CompositeEventEmitter):
|
||||
# Set up event handlers
|
||||
pending_name = asyncio.get_running_loop().create_future()
|
||||
|
||||
if type(remote) == Address:
|
||||
peer_address = remote
|
||||
handler = self.on(
|
||||
'remote_name',
|
||||
lambda address, remote_name:
|
||||
pending_name.set_result(remote_name) if address == remote else None
|
||||
)
|
||||
failure_handler = self.on(
|
||||
'remote_name_failure',
|
||||
lambda address, error_code:
|
||||
pending_name.set_exception(HCI_Error(error_code)) if address == remote else None
|
||||
)
|
||||
else:
|
||||
peer_address = remote.peer_address
|
||||
handler = remote.on(
|
||||
'remote_name',
|
||||
lambda: pending_name.set_result(remote.peer_name)
|
||||
)
|
||||
failure_handler = remote.on(
|
||||
'remote_name_failure',
|
||||
lambda error_code: pending_name.set_exception(HCI_Error(error_code))
|
||||
)
|
||||
peer_address = remote if type(remote) == Address else remote.peer_address
|
||||
|
||||
handler = self.on(
|
||||
'remote_name',
|
||||
lambda address, remote_name:
|
||||
pending_name.set_result(remote_name) if address == peer_address else None
|
||||
)
|
||||
failure_handler = self.on(
|
||||
'remote_name_failure',
|
||||
lambda address, error_code:
|
||||
pending_name.set_exception(HCI_Error(error_code)) if address == peer_address else None
|
||||
)
|
||||
|
||||
try:
|
||||
result = await self.send_command(
|
||||
@@ -1794,12 +1856,8 @@ class Device(CompositeEventEmitter):
|
||||
# Wait for the result
|
||||
return await pending_name
|
||||
finally:
|
||||
if type(remote) == Address:
|
||||
self.remove_listener('remote_name', handler)
|
||||
self.remove_listener('remote_name_failure', failure_handler)
|
||||
else:
|
||||
remote.remove_listener('remote_name', handler)
|
||||
remote.remove_listener('remote_name_failure', failure_handler)
|
||||
self.remove_listener('remote_name', handler)
|
||||
self.remove_listener('remote_name_failure', failure_handler)
|
||||
|
||||
# [Classic only]
|
||||
@host_event_handler
|
||||
@@ -1817,12 +1875,20 @@ class Device(CompositeEventEmitter):
|
||||
|
||||
asyncio.create_task(store_keys())
|
||||
|
||||
if (connection := self.find_connection_by_bd_addr(bd_addr, transport=BT_BR_EDR_TRANSPORT)):
|
||||
connection.link_key_type = key_type
|
||||
|
||||
def add_service(self, service):
|
||||
self.gatt_server.add_service(service)
|
||||
|
||||
def add_services(self, services):
|
||||
self.gatt_server.add_services(services)
|
||||
|
||||
def add_default_services(self, generic_access_service=True):
|
||||
# Add a GAP Service if requested
|
||||
if generic_access_service:
|
||||
self.gatt_server.add_service(GenericAccessService(self.name))
|
||||
|
||||
async def notify_subscriber(self, connection, attribute, value=None, force=False):
|
||||
await self.gatt_server.notify_subscriber(connection, attribute, value, force)
|
||||
|
||||
@@ -1843,20 +1909,12 @@ class Device(CompositeEventEmitter):
|
||||
|
||||
if transport == BT_BR_EDR_TRANSPORT:
|
||||
# Create a new connection
|
||||
connection = Connection(
|
||||
self,
|
||||
connection_handle,
|
||||
transport,
|
||||
peer_address,
|
||||
peer_resolvable_address,
|
||||
role,
|
||||
connection_parameters,
|
||||
phy=None
|
||||
)
|
||||
connection: Connection = self.pending_connections.pop(peer_address)
|
||||
connection.complete(connection_handle, peer_resolvable_address, role, connection_parameters)
|
||||
self.connections[connection_handle] = connection
|
||||
|
||||
# We may have an accept ongoing waiting for a connection request for `peer_address`.
|
||||
# Typicaly happen when using `connect` to the same `peer_address` we are waiting with
|
||||
# Typically happen when using `connect` to the same `peer_address` we are waiting with
|
||||
# an `accept` for.
|
||||
# In this case, set the completed `connection` to the `accept` future result.
|
||||
if peer_address in self.classic_pending_accepts:
|
||||
@@ -1875,8 +1933,17 @@ class Device(CompositeEventEmitter):
|
||||
peer_resolvable_address = peer_address
|
||||
peer_address = resolved_address
|
||||
|
||||
# Guess which own address type is used for this connection.
|
||||
# This logic is somewhat correct but may need to be improved
|
||||
# when multiple advertising are run simultaneously.
|
||||
if self.connect_own_address_type is not None:
|
||||
own_address_type = self.connect_own_address_type
|
||||
else:
|
||||
own_address_type = self.advertising_own_address_type
|
||||
|
||||
# We are no longer advertising
|
||||
self.advertising = False
|
||||
self.advertising_own_address_type = None
|
||||
self.advertising = False
|
||||
|
||||
# Create and notify of the new connection asynchronously
|
||||
async def new_connection():
|
||||
@@ -1890,11 +1957,16 @@ class Device(CompositeEventEmitter):
|
||||
else:
|
||||
phy = ConnectionPHY(HCI_LE_1M_PHY, HCI_LE_1M_PHY)
|
||||
|
||||
self_address = self.random_address
|
||||
if own_address_type in (OwnAddressType.PUBLIC, OwnAddressType.RESOLVABLE_OR_PUBLIC):
|
||||
self_address = self.public_address
|
||||
|
||||
# Create a new connection
|
||||
connection = Connection(
|
||||
self,
|
||||
connection_handle,
|
||||
transport,
|
||||
self_address,
|
||||
peer_address,
|
||||
peer_resolvable_address,
|
||||
role,
|
||||
@@ -1914,7 +1986,8 @@ class Device(CompositeEventEmitter):
|
||||
|
||||
# For directed advertising, this means a timeout
|
||||
if transport == BT_LE_TRANSPORT and self.advertising and self.advertising_type.is_directed:
|
||||
self.advertising = False
|
||||
self.advertising_own_address_type = None
|
||||
self.advertising = False
|
||||
|
||||
# Notify listeners
|
||||
error = ConnectionError(
|
||||
@@ -1943,6 +2016,9 @@ class Device(CompositeEventEmitter):
|
||||
|
||||
# device configuration is set to accept any incoming connection
|
||||
elif self.classic_accept_any:
|
||||
# Save pending connection
|
||||
self.pending_connections[bd_addr] = Connection.incomplete(self, bd_addr)
|
||||
|
||||
self.host.send_command_sync(
|
||||
HCI_Accept_Connection_Request_Command(
|
||||
bd_addr = bd_addr,
|
||||
@@ -2016,6 +2092,17 @@ class Device(CompositeEventEmitter):
|
||||
logger.debug(f'*** Connection Authentication Failure: [0x{connection.handle:04X}] {connection.peer_address} as {connection.role_name}, error={error}')
|
||||
connection.emit('connection_authentication_failure', error)
|
||||
|
||||
@host_event_handler
|
||||
@with_connection_from_address
|
||||
def on_ssp_complete(self, connection):
|
||||
# On Secure Simple Pairing complete, in case:
|
||||
# - Connection isn't already authenticated
|
||||
# - AND we are not the initiator of the authentication
|
||||
# We must trigger authentication to known if we are truly authenticated
|
||||
if not connection.authenticating and not connection.authenticated:
|
||||
logger.debug(f'*** Trigger Connection Authentication: [0x{connection.handle:04X}] {connection.peer_address}')
|
||||
asyncio.create_task(connection.authenticate())
|
||||
|
||||
# [Classic only]
|
||||
@host_event_handler
|
||||
@with_connection_from_address
|
||||
@@ -2067,13 +2154,13 @@ class Device(CompositeEventEmitter):
|
||||
# Ask what the pairing config should be for this connection
|
||||
pairing_config = self.pairing_config_factory(connection)
|
||||
|
||||
can_confirm = pairing_config.delegate.io_capability not in {
|
||||
can_compare = pairing_config.delegate.io_capability not in {
|
||||
smp.SMP_NO_INPUT_NO_OUTPUT_IO_CAPABILITY,
|
||||
smp.SMP_DISPLAY_ONLY_IO_CAPABILITY
|
||||
}
|
||||
|
||||
# Respond
|
||||
if can_confirm and pairing_config.delegate:
|
||||
if can_compare:
|
||||
async def compare_numbers():
|
||||
numbers_match = await pairing_config.delegate.compare_numbers(code, digits=6)
|
||||
if numbers_match:
|
||||
@@ -2087,9 +2174,18 @@ class Device(CompositeEventEmitter):
|
||||
|
||||
asyncio.create_task(compare_numbers())
|
||||
else:
|
||||
self.host.send_command_sync(
|
||||
HCI_User_Confirmation_Request_Reply_Command(bd_addr=connection.peer_address)
|
||||
)
|
||||
async def confirm():
|
||||
confirm = await pairing_config.delegate.confirm()
|
||||
if confirm:
|
||||
self.host.send_command_sync(
|
||||
HCI_User_Confirmation_Request_Reply_Command(bd_addr=connection.peer_address)
|
||||
)
|
||||
else:
|
||||
self.host.send_command_sync(
|
||||
HCI_User_Confirmation_Request_Negative_Reply_Command(bd_addr=connection.peer_address)
|
||||
)
|
||||
|
||||
asyncio.create_task(confirm())
|
||||
|
||||
# [Classic only]
|
||||
@host_event_handler
|
||||
@@ -2104,7 +2200,7 @@ class Device(CompositeEventEmitter):
|
||||
}
|
||||
|
||||
# Respond
|
||||
if can_input and pairing_config.delegate:
|
||||
if can_input:
|
||||
async def get_number():
|
||||
number = await pairing_config.delegate.get_number()
|
||||
if number is not None:
|
||||
@@ -2124,6 +2220,15 @@ class Device(CompositeEventEmitter):
|
||||
HCI_User_Passkey_Request_Negative_Reply_Command(bd_addr=connection.peer_address)
|
||||
)
|
||||
|
||||
# [Classic only]
|
||||
@host_event_handler
|
||||
@with_connection_from_address
|
||||
def on_authentication_user_passkey_notification(self, connection, passkey):
|
||||
# Ask what the pairing config should be for this connection
|
||||
pairing_config = self.pairing_config_factory(connection)
|
||||
|
||||
asyncio.create_task(pairing_config.delegate.display_number(passkey))
|
||||
|
||||
# [Classic only]
|
||||
@host_event_handler
|
||||
@try_with_connection_from_address
|
||||
@@ -2134,8 +2239,7 @@ class Device(CompositeEventEmitter):
|
||||
if connection:
|
||||
connection.peer_name = remote_name
|
||||
connection.emit('remote_name')
|
||||
else:
|
||||
self.emit('remote_name', address, remote_name)
|
||||
self.emit('remote_name', address, remote_name)
|
||||
except UnicodeDecodeError as error:
|
||||
logger.warning('peer name is not valid UTF-8')
|
||||
if connection:
|
||||
@@ -2149,8 +2253,7 @@ class Device(CompositeEventEmitter):
|
||||
def on_remote_name_failure(self, connection, address, error):
|
||||
if connection:
|
||||
connection.emit('remote_name_failure', error)
|
||||
else:
|
||||
self.emit('remote_name_failure', address, error)
|
||||
self.emit('remote_name_failure', address, error)
|
||||
|
||||
@host_event_handler
|
||||
@with_connection_from_handle
|
||||
@@ -2216,7 +2319,9 @@ class Device(CompositeEventEmitter):
|
||||
connection.emit('pairing_start')
|
||||
|
||||
@with_connection_from_handle
|
||||
def on_pairing(self, connection, keys):
|
||||
def on_pairing(self, connection, keys, sc):
|
||||
connection.sc = sc
|
||||
connection.authenticated = True
|
||||
connection.emit('pairing', keys)
|
||||
|
||||
@with_connection_from_handle
|
||||
|
||||
@@ -22,6 +22,7 @@
|
||||
# -----------------------------------------------------------------------------
|
||||
# Imports
|
||||
# -----------------------------------------------------------------------------
|
||||
from __future__ import annotations
|
||||
import asyncio
|
||||
import enum
|
||||
import types
|
||||
@@ -151,6 +152,14 @@ GATT_HEART_RATE_CONTROL_POINT_CHARACTERISTIC = UUID.from_16_bits(0x2A39, 'Heart
|
||||
# Battery Service
|
||||
GATT_BATTERY_LEVEL_CHARACTERISTIC = UUID.from_16_bits(0x2A19, 'Battery Level')
|
||||
|
||||
# ASHA Service
|
||||
GATT_ASHA_SERVICE = UUID.from_16_bits(0xFDF0, 'Audio Streaming for Hearing Aid')
|
||||
GATT_ASHA_READ_ONLY_PROPERTIES_CHARACTERISTIC = UUID('6333651e-c481-4a3e-9169-7c902aad37bb', 'ReadOnlyProperties')
|
||||
GATT_ASHA_AUDIO_CONTROL_POINT_CHARACTERISTIC = UUID('f0d4de7e-4a88-476c-9d9f-1937b0996cc0', 'AudioControlPoint')
|
||||
GATT_ASHA_AUDIO_STATUS_CHARACTERISTIC = UUID('38663f1a-e711-4cac-b641-326b56404837', 'AudioStatus')
|
||||
GATT_ASHA_VOLUME_CHARACTERISTIC = UUID('00e4ca9e-ab14-41e4-8823-f9e70c7e91df', 'Volume')
|
||||
GATT_ASHA_LE_PSM_OUT_CHARACTERISTIC = UUID('2d410339-82b6-42aa-b34e-e2e01df8cc1a', 'LE_PSM_OUT')
|
||||
|
||||
# Misc
|
||||
GATT_DEVICE_NAME_CHARACTERISTIC = UUID.from_16_bits(0x2A00, 'Device Name')
|
||||
GATT_APPEARANCE_CHARACTERISTIC = UUID.from_16_bits(0x2A01, 'Appearance')
|
||||
@@ -187,7 +196,7 @@ class Service(Attribute):
|
||||
See Vol 3, Part G - 3.1 SERVICE DEFINITION
|
||||
'''
|
||||
|
||||
def __init__(self, uuid, characteristics, primary=True):
|
||||
def __init__(self, uuid, characteristics: list[Characteristic], primary=True):
|
||||
# Convert the uuid to a UUID object if it isn't already
|
||||
if type(uuid) is str:
|
||||
uuid = UUID(uuid)
|
||||
@@ -202,6 +211,14 @@ class Service(Attribute):
|
||||
self.characteristics = characteristics[:]
|
||||
self.primary = primary
|
||||
|
||||
def get_advertising_data(self):
|
||||
"""
|
||||
Get Service specific advertising data
|
||||
Defined by each Service, default value is empty
|
||||
:return Service data for advertising
|
||||
"""
|
||||
return None
|
||||
|
||||
def __str__(self):
|
||||
return f'Service(handle=0x{self.handle:04X}, end=0x{self.end_group_handle:04X}, uuid={self.uuid}){"" if self.primary else "*"}'
|
||||
|
||||
@@ -256,10 +273,21 @@ class Characteristic(Attribute):
|
||||
if properties & p
|
||||
])
|
||||
|
||||
def __init__(self, uuid, properties, permissions, value = b'', descriptors = []):
|
||||
@staticmethod
|
||||
def string_to_properties(properties_str: str):
|
||||
return functools.reduce(
|
||||
lambda x, y: x | get_dict_key_by_value(Characteristic.PROPERTY_NAMES, y),
|
||||
properties_str.split(","),
|
||||
0,
|
||||
)
|
||||
|
||||
def __init__(self, uuid, properties, permissions, value = b'', descriptors: list[Descriptor] = []):
|
||||
super().__init__(uuid, permissions, value)
|
||||
self.uuid = self.type
|
||||
self.properties = properties
|
||||
if type(properties) is str:
|
||||
self.properties = Characteristic.string_to_properties(properties)
|
||||
else:
|
||||
self.properties = properties
|
||||
self.descriptors = descriptors
|
||||
|
||||
def get_descriptor(self, descriptor_type):
|
||||
@@ -271,6 +299,24 @@ class Characteristic(Attribute):
|
||||
return f'Characteristic(handle=0x{self.handle:04X}, end=0x{self.end_group_handle:04X}, uuid={self.uuid}, properties={Characteristic.properties_as_string(self.properties)})'
|
||||
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
class CharacteristicDeclaration(Attribute):
|
||||
'''
|
||||
See Vol 3, Part G - 3.3.1 CHARACTERISTIC DECLARATION
|
||||
'''
|
||||
def __init__(self, characteristic, value_handle):
|
||||
declaration_bytes = struct.pack(
|
||||
'<BH',
|
||||
characteristic.properties,
|
||||
value_handle
|
||||
) + characteristic.uuid.to_pdu_bytes()
|
||||
super().__init__(GATT_CHARACTERISTIC_ATTRIBUTE_TYPE, Attribute.READABLE, declaration_bytes)
|
||||
self.value_handle = value_handle
|
||||
self.characteristic = characteristic
|
||||
|
||||
def __str__(self):
|
||||
return f'CharacteristicDeclaration(handle=0x{self.handle:04X}, value_handle=0x{self.value_handle:04X}, uuid={self.characteristic.uuid}, properties={Characteristic.properties_as_string(self.characteristic.properties)})'
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
class CharacteristicValue:
|
||||
'''
|
||||
|
||||
@@ -26,20 +26,17 @@
|
||||
import asyncio
|
||||
import logging
|
||||
import struct
|
||||
|
||||
from colors import color
|
||||
|
||||
from .core import ProtocolError, TimeoutError
|
||||
from .hci import *
|
||||
from .att import *
|
||||
from .gatt import (
|
||||
GATT_CLIENT_CHARACTERISTIC_CONFIGURATION_DESCRIPTOR,
|
||||
GATT_REQUEST_TIMEOUT,
|
||||
GATT_PRIMARY_SERVICE_ATTRIBUTE_TYPE,
|
||||
GATT_SECONDARY_SERVICE_ATTRIBUTE_TYPE,
|
||||
GATT_CHARACTERISTIC_ATTRIBUTE_TYPE,
|
||||
Characteristic,
|
||||
ClientCharacteristicConfigurationBits
|
||||
)
|
||||
from .core import InvalidStateError, ProtocolError, TimeoutError
|
||||
from .gatt import (GATT_CHARACTERISTIC_ATTRIBUTE_TYPE,
|
||||
GATT_CLIENT_CHARACTERISTIC_CONFIGURATION_DESCRIPTOR,
|
||||
GATT_PRIMARY_SERVICE_ATTRIBUTE_TYPE, GATT_REQUEST_TIMEOUT,
|
||||
GATT_SECONDARY_SERVICE_ATTRIBUTE_TYPE, Characteristic,
|
||||
ClientCharacteristicConfigurationBits)
|
||||
from .hci import *
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
# Logging
|
||||
@@ -115,7 +112,7 @@ class CharacteristicProxy(AttributeProxy):
|
||||
async def discover_descriptors(self):
|
||||
return await self.client.discover_descriptors(self)
|
||||
|
||||
async def subscribe(self, subscriber=None):
|
||||
async def subscribe(self, subscriber=None, prefer_notify=True):
|
||||
if subscriber is not None:
|
||||
if subscriber in self.subscribers:
|
||||
# We already have a proxy subscriber
|
||||
@@ -129,7 +126,7 @@ class CharacteristicProxy(AttributeProxy):
|
||||
self.subscribers[subscriber] = on_change
|
||||
subscriber = on_change
|
||||
|
||||
return await self.client.subscribe(self, subscriber)
|
||||
return await self.client.subscribe(self, subscriber, prefer_notify)
|
||||
|
||||
async def unsubscribe(self, subscriber=None):
|
||||
if subscriber in self.subscribers:
|
||||
@@ -547,7 +544,7 @@ class Client:
|
||||
|
||||
return attributes
|
||||
|
||||
async def subscribe(self, characteristic, subscriber=None):
|
||||
async def subscribe(self, characteristic, subscriber=None, prefer_notify=True):
|
||||
# If we haven't already discovered the descriptors for this characteristic, do it now
|
||||
if not characteristic.descriptors_discovered:
|
||||
await self.discover_descriptors(characteristic)
|
||||
@@ -558,23 +555,32 @@ class Client:
|
||||
logger.warning('subscribing to characteristic with no CCCD descriptor')
|
||||
return
|
||||
|
||||
# Set the subscription bits and select the subscriber set
|
||||
bits = ClientCharacteristicConfigurationBits.DEFAULT
|
||||
subscriber_sets = []
|
||||
if characteristic.properties & Characteristic.NOTIFY:
|
||||
bits |= ClientCharacteristicConfigurationBits.NOTIFICATION
|
||||
subscriber_sets.append(self.notification_subscribers.setdefault(characteristic.handle, set()))
|
||||
if characteristic.properties & Characteristic.INDICATE:
|
||||
bits |= ClientCharacteristicConfigurationBits.INDICATION
|
||||
subscriber_sets.append(self.indication_subscribers.setdefault(characteristic.handle, set()))
|
||||
if (
|
||||
characteristic.properties & Characteristic.NOTIFY
|
||||
and characteristic.properties & Characteristic.INDICATE
|
||||
):
|
||||
if prefer_notify:
|
||||
bits = ClientCharacteristicConfigurationBits.NOTIFICATION
|
||||
subscribers = self.notification_subscribers
|
||||
else:
|
||||
bits = ClientCharacteristicConfigurationBits.INDICATION
|
||||
subscribers = self.indication_subscribers
|
||||
elif characteristic.properties & Characteristic.NOTIFY:
|
||||
bits = ClientCharacteristicConfigurationBits.NOTIFICATION
|
||||
subscribers = self.notification_subscribers
|
||||
elif characteristic.properties & Characteristic.INDICATE:
|
||||
bits = ClientCharacteristicConfigurationBits.INDICATION
|
||||
subscribers = self.indication_subscribers
|
||||
else:
|
||||
raise InvalidStateError("characteristic is not notify or indicate")
|
||||
|
||||
# Add subscribers to the sets
|
||||
for subscriber_set in subscriber_sets:
|
||||
if subscriber is not None:
|
||||
subscriber_set.add(subscriber)
|
||||
# Add the characteristic as a subscriber, which will result in the characteristic
|
||||
# emitting an 'update' event when a notification or indication is received
|
||||
subscriber_set.add(characteristic)
|
||||
subscriber_set = subscribers.setdefault(characteristic.handle, set())
|
||||
if subscriber is not None:
|
||||
subscriber_set.add(subscriber)
|
||||
# Add the characteristic as a subscriber, which will result in the characteristic
|
||||
# emitting an 'update' event when a notification or indication is received
|
||||
subscriber_set.add(characteristic)
|
||||
|
||||
await self.write_value(cccd, struct.pack('<H', bits), with_response=True)
|
||||
|
||||
|
||||
@@ -26,6 +26,7 @@
|
||||
import asyncio
|
||||
import logging
|
||||
from collections import defaultdict
|
||||
from typing import Tuple, Optional
|
||||
from pyee import EventEmitter
|
||||
from colors import color
|
||||
|
||||
@@ -60,12 +61,21 @@ class Server(EventEmitter):
|
||||
self.indication_semaphores = defaultdict(lambda: asyncio.Semaphore(1))
|
||||
self.pending_confirmations = defaultdict(lambda: None)
|
||||
|
||||
def __str__(self):
|
||||
return "\n".join(map(str, self.attributes))
|
||||
|
||||
def send_gatt_pdu(self, connection_handle, pdu):
|
||||
self.device.send_l2cap_pdu(connection_handle, ATT_CID, pdu)
|
||||
|
||||
def next_handle(self):
|
||||
return 1 + len(self.attributes)
|
||||
|
||||
def get_advertising_service_data(self):
|
||||
return {
|
||||
attribute: data for attribute in self.attributes
|
||||
if isinstance(attribute, Service) and (data := attribute.get_advertising_data())
|
||||
}
|
||||
|
||||
def get_attribute(self, handle):
|
||||
attribute = self.attributes_by_handle.get(handle)
|
||||
if attribute:
|
||||
@@ -79,6 +89,63 @@ class Server(EventEmitter):
|
||||
return attribute
|
||||
return None
|
||||
|
||||
def get_service_attribute(self, service_uuid: UUID) -> Optional[Service]:
|
||||
return next(
|
||||
(
|
||||
attribute
|
||||
for attribute in self.attributes
|
||||
if attribute.type == GATT_PRIMARY_SERVICE_ATTRIBUTE_TYPE
|
||||
and attribute.uuid == service_uuid
|
||||
),
|
||||
None,
|
||||
)
|
||||
|
||||
def get_characteristic_attributes(
|
||||
self, service_uuid: UUID, characteristic_uuid: UUID
|
||||
) -> Optional[Tuple[CharacteristicDeclaration, Characteristic]]:
|
||||
service_handle = self.get_service_attribute(service_uuid)
|
||||
if not service_handle:
|
||||
return None
|
||||
|
||||
return next(
|
||||
(
|
||||
(attribute, self.get_attribute(attribute.characteristic.handle))
|
||||
for attribute in map(
|
||||
self.get_attribute,
|
||||
range(service_handle.handle, service_handle.end_group_handle + 1),
|
||||
)
|
||||
if attribute.type == GATT_CHARACTERISTIC_ATTRIBUTE_TYPE
|
||||
and attribute.characteristic.uuid == characteristic_uuid
|
||||
),
|
||||
None,
|
||||
)
|
||||
|
||||
def get_descriptor_attribute(
|
||||
self, service_uuid: UUID, characteristic_uuid: UUID, descriptor_uuid: UUID
|
||||
) -> Optional[Descriptor]:
|
||||
characteristics = self.get_characteristic_attributes(
|
||||
service_uuid, characteristic_uuid
|
||||
)
|
||||
if not characteristics:
|
||||
return None
|
||||
|
||||
(_, characteristic_value) = characteristics
|
||||
|
||||
return next(
|
||||
(
|
||||
attribute
|
||||
for attribute in map(
|
||||
self.get_attribute,
|
||||
range(
|
||||
characteristic_value.handle + 1,
|
||||
characteristic_value.end_group_handle + 1,
|
||||
),
|
||||
)
|
||||
if attribute.type == descriptor_uuid
|
||||
),
|
||||
None,
|
||||
)
|
||||
|
||||
def add_attribute(self, attribute):
|
||||
# Assign a handle to this attribute
|
||||
attribute.handle = self.next_handle()
|
||||
@@ -87,7 +154,7 @@ class Server(EventEmitter):
|
||||
# Add this attribute to the list
|
||||
self.attributes.append(attribute)
|
||||
|
||||
def add_service(self, service):
|
||||
def add_service(self, service: Service):
|
||||
# Add the service attribute to the DB
|
||||
self.add_attribute(service)
|
||||
|
||||
@@ -95,16 +162,9 @@ class Server(EventEmitter):
|
||||
|
||||
# Add all characteristics
|
||||
for characteristic in service.characteristics:
|
||||
# Add a Characteristic Declaration (Vol 3, Part G - 3.3.1 Characteristic Declaration)
|
||||
declaration_bytes = struct.pack(
|
||||
'<BH',
|
||||
characteristic.properties,
|
||||
self.next_handle() + 1, # The value will be the next attribute after this declaration
|
||||
) + characteristic.uuid.to_pdu_bytes()
|
||||
characteristic_declaration = Attribute(
|
||||
GATT_CHARACTERISTIC_ATTRIBUTE_TYPE,
|
||||
Attribute.READABLE,
|
||||
declaration_bytes
|
||||
# Add a Characteristic Declaration
|
||||
characteristic_declaration = CharacteristicDeclaration(
|
||||
characteristic, self.next_handle() + 1
|
||||
)
|
||||
self.add_attribute(characteristic_declaration)
|
||||
|
||||
|
||||
@@ -1652,16 +1652,6 @@ class Address:
|
||||
|
||||
ADDRESS_TYPE_SPEC = {'size': 1, 'mapper': lambda x: Address.address_type_name(x)}
|
||||
|
||||
@classmethod
|
||||
@property
|
||||
def ANY(cls):
|
||||
return cls(b"\xff\xff\xff\xff\xff\xff", cls.PUBLIC_DEVICE_ADDRESS)
|
||||
|
||||
@classmethod
|
||||
@property
|
||||
def NIL(cls):
|
||||
return cls(b"\x00\x00\x00\x00\x00\x00", cls.PUBLIC_DEVICE_ADDRESS)
|
||||
|
||||
@staticmethod
|
||||
def address_type_name(address_type):
|
||||
return name_or_number(Address.ADDRESS_TYPE_NAMES, address_type)
|
||||
@@ -1753,9 +1743,36 @@ class Address:
|
||||
'''
|
||||
String representation of the address, MSB first
|
||||
'''
|
||||
return ':'.join([f'{x:02X}' for x in reversed(self.address_bytes)])
|
||||
str = ':'.join([f'{x:02X}' for x in reversed(self.address_bytes)])
|
||||
if not self.is_public:
|
||||
return str
|
||||
return str + '/P'
|
||||
|
||||
|
||||
# Predefined address values
|
||||
Address.NIL = Address(b"\xff\xff\xff\xff\xff\xff", Address.PUBLIC_DEVICE_ADDRESS)
|
||||
Address.ANY = Address(b"\x00\x00\x00\x00\x00\x00", Address.PUBLIC_DEVICE_ADDRESS)
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
class OwnAddressType:
|
||||
PUBLIC = 0
|
||||
RANDOM = 1
|
||||
RESOLVABLE_OR_PUBLIC = 2
|
||||
RESOLVABLE_OR_RANDOM = 3
|
||||
|
||||
TYPE_NAMES = {
|
||||
PUBLIC: 'PUBLIC',
|
||||
RANDOM: 'RANDOM',
|
||||
RESOLVABLE_OR_PUBLIC: 'RESOLVABLE_OR_PUBLIC',
|
||||
RESOLVABLE_OR_RANDOM: 'RESOLVABLE_OR_RANDOM'
|
||||
}
|
||||
|
||||
@staticmethod
|
||||
def type_name(type):
|
||||
return name_or_number(OwnAddressType.TYPE_NAMES, type)
|
||||
|
||||
TYPE_SPEC = {'size': 1, 'mapper': lambda x: OwnAddressType.type_name(x)}
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
class HCI_Packet:
|
||||
'''
|
||||
@@ -2848,7 +2865,7 @@ class HCI_LE_Set_Random_Address_Command(HCI_Command):
|
||||
('advertising_interval_min', 2),
|
||||
('advertising_interval_max', 2),
|
||||
('advertising_type', {'size': 1, 'mapper': lambda x: HCI_LE_Set_Advertising_Parameters_Command.advertising_type_name(x)}),
|
||||
('own_address_type', Address.ADDRESS_TYPE_SPEC),
|
||||
('own_address_type', OwnAddressType.TYPE_SPEC),
|
||||
('peer_address_type', Address.ADDRESS_TYPE_SPEC),
|
||||
('peer_address', Address.parse_address_preceded_by_type),
|
||||
('advertising_channel_map', 1),
|
||||
@@ -2927,7 +2944,7 @@ class HCI_LE_Set_Advertising_Enable_Command(HCI_Command):
|
||||
('le_scan_type', 1),
|
||||
('le_scan_interval', 2),
|
||||
('le_scan_window', 2),
|
||||
('own_address_type', Address.ADDRESS_TYPE_SPEC),
|
||||
('own_address_type', OwnAddressType.TYPE_SPEC),
|
||||
('scanning_filter_policy', 1)
|
||||
])
|
||||
class HCI_LE_Set_Scan_Parameters_Command(HCI_Command):
|
||||
@@ -2961,7 +2978,7 @@ class HCI_LE_Set_Scan_Enable_Command(HCI_Command):
|
||||
('initiator_filter_policy', 1),
|
||||
('peer_address_type', Address.ADDRESS_TYPE_SPEC),
|
||||
('peer_address', Address.parse_address_preceded_by_type),
|
||||
('own_address_type', Address.ADDRESS_TYPE_SPEC),
|
||||
('own_address_type', OwnAddressType.TYPE_SPEC),
|
||||
('connection_interval_min', 2),
|
||||
('connection_interval_max', 2),
|
||||
('max_latency', 2),
|
||||
@@ -3283,7 +3300,7 @@ class HCI_LE_Set_Advertising_Set_Random_Address_Command(HCI_Command):
|
||||
('primary_advertising_interval_min', 3),
|
||||
('primary_advertising_interval_max', 3),
|
||||
('primary_advertising_channel_map', {'size': 1, 'mapper': lambda x: HCI_LE_Set_Extended_Advertising_Parameters_Command.channel_map_string(x)}),
|
||||
('own_address_type', Address.ADDRESS_TYPE_SPEC),
|
||||
('own_address_type', OwnAddressType.TYPE_SPEC),
|
||||
('peer_address_type', Address.ADDRESS_TYPE_SPEC),
|
||||
('peer_address', Address.parse_address_preceded_by_type),
|
||||
('advertising_filter_policy', 1),
|
||||
@@ -3687,7 +3704,7 @@ class HCI_LE_Extended_Create_Connection_Command(HCI_Command):
|
||||
initiating_phys_strs = bit_flags_to_strings(self.initiating_phys, HCI_LE_PHY_BIT_NAMES)
|
||||
fields = [
|
||||
('initiator_filter_policy:', self.initiator_filter_policy),
|
||||
('own_address_type: ', Address.address_type_name(self.own_address_type)),
|
||||
('own_address_type: ', OwnAddressType.type_name(self.own_address_type)),
|
||||
('peer_address_type: ', Address.address_type_name(self.peer_address_type)),
|
||||
('peer_address: ', str(self.peer_address)),
|
||||
('initiating_phys: ', ','.join(initiating_phys_strs)),
|
||||
@@ -4855,6 +4872,17 @@ class HCI_Link_Supervision_Timeout_Changed_Event(HCI_Event):
|
||||
'''
|
||||
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
@HCI_Event.event([
|
||||
('bd_addr', Address.parse_address),
|
||||
('passkey', 4)
|
||||
])
|
||||
class HCI_User_Passkey_Notification_Event(HCI_Event):
|
||||
'''
|
||||
See Bluetooth spec @ 7.7.48 User Passkey Notification Event
|
||||
'''
|
||||
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
@HCI_Event.event([
|
||||
('bd_addr', Address.parse_address),
|
||||
|
||||
@@ -121,24 +121,28 @@ class Host(EventEmitter):
|
||||
self.hc_acl_data_packet_length = response.return_parameters.hc_acl_data_packet_length
|
||||
self.hc_total_num_acl_data_packets = response.return_parameters.hc_total_num_acl_data_packets
|
||||
|
||||
logger.debug(
|
||||
f'HCI ACL flow control: hc_acl_data_packet_length={self.hc_acl_data_packet_length},'
|
||||
f'hc_total_num_acl_data_packets={self.hc_total_num_acl_data_packets}'
|
||||
)
|
||||
|
||||
if self.supports_command(HCI_LE_READ_BUFFER_SIZE_COMMAND):
|
||||
response = await self.send_command(HCI_LE_Read_Buffer_Size_Command(), check_result=True)
|
||||
self.hc_le_acl_data_packet_length = response.return_parameters.hc_le_acl_data_packet_length
|
||||
self.hc_total_num_le_acl_data_packets = response.return_parameters.hc_total_num_le_acl_data_packets
|
||||
|
||||
if response.return_parameters.hc_le_acl_data_packet_length == 0 or response.return_parameters.hc_total_num_le_acl_data_packets == 0:
|
||||
# LE and Classic share the same values
|
||||
self.hc_le_acl_data_packet_length = self.hc_acl_data_packet_length
|
||||
self.hc_total_num_le_acl_data_packets = self.hc_total_num_acl_data_packets
|
||||
logger.debug(
|
||||
f'HCI LE ACL flow control: hc_le_acl_data_packet_length={self.hc_le_acl_data_packet_length},'
|
||||
f'hc_total_num_le_acl_data_packets={self.hc_total_num_le_acl_data_packets}'
|
||||
)
|
||||
|
||||
logger.debug(
|
||||
f'HCI ACL flow control: hc_acl_data_packet_length={self.hc_acl_data_packet_length},'
|
||||
f'hc_total_num_acl_data_packets={self.hc_total_num_acl_data_packets}'
|
||||
)
|
||||
logger.debug(
|
||||
f'HCI LE ACL flow control: hc_le_acl_data_packet_length={self.hc_le_acl_data_packet_length},'
|
||||
f'hc_total_num_le_acl_data_packets={self.hc_total_num_le_acl_data_packets}'
|
||||
)
|
||||
if (
|
||||
response.return_parameters.hc_le_acl_data_packet_length == 0 or
|
||||
response.return_parameters.hc_total_num_le_acl_data_packets == 0
|
||||
):
|
||||
# LE and Classic share the same values
|
||||
self.hc_le_acl_data_packet_length = self.hc_acl_data_packet_length
|
||||
self.hc_total_num_le_acl_data_packets = self.hc_total_num_acl_data_packets
|
||||
|
||||
if (
|
||||
self.supports_command(HCI_LE_READ_SUGGESTED_DEFAULT_DATA_LENGTH_COMMAND) and
|
||||
@@ -595,6 +599,9 @@ class Host(EventEmitter):
|
||||
|
||||
def on_hci_simple_pairing_complete_event(self, event):
|
||||
logger.debug(f'simple pairing complete for {event.bd_addr}: status={HCI_Constant.status_name(event.status)}')
|
||||
# Notify the client
|
||||
if event.status == HCI_SUCCESS:
|
||||
self.emit('ssp_complete', event.bd_addr)
|
||||
|
||||
def on_hci_pin_code_request_event(self, event):
|
||||
# For now, just refuse all requests
|
||||
@@ -638,6 +645,9 @@ class Host(EventEmitter):
|
||||
def on_hci_user_passkey_request_event(self, event):
|
||||
self.emit('authentication_user_passkey_request', event.bd_addr)
|
||||
|
||||
def on_hci_user_passkey_notification_event(self, event):
|
||||
self.emit('authentication_user_passkey_notification', event.bd_addr, event.passkey)
|
||||
|
||||
def on_hci_inquiry_complete_event(self, event):
|
||||
self.emit('inquiry_complete')
|
||||
|
||||
|
||||
@@ -20,6 +20,7 @@
|
||||
# -----------------------------------------------------------------------------
|
||||
# Imports
|
||||
# -----------------------------------------------------------------------------
|
||||
import asyncio
|
||||
import logging
|
||||
import os
|
||||
import json
|
||||
@@ -143,6 +144,10 @@ class KeyStore:
|
||||
async def get_all(self):
|
||||
return []
|
||||
|
||||
async def delete_all(self):
|
||||
all_keys = await self.get_all()
|
||||
await asyncio.gather(*(self.delete(name) for (name, _) in all_keys))
|
||||
|
||||
async def get_resolving_keys(self):
|
||||
all_keys = await self.get_all()
|
||||
resolving_keys = []
|
||||
@@ -259,6 +264,13 @@ class JsonKeyStore(KeyStore):
|
||||
|
||||
return [(name, PairingKeys.from_dict(keys)) for (name, keys) in namespace.items()]
|
||||
|
||||
async def delete_all(self):
|
||||
db = await self.load()
|
||||
|
||||
db.pop(self.namespace, None)
|
||||
|
||||
await self.save(db)
|
||||
|
||||
async def get(self, name):
|
||||
db = await self.load()
|
||||
|
||||
|
||||
@@ -1224,7 +1224,7 @@ class ChannelManager:
|
||||
self._host.remove_listener('disconnection', self.on_disconnection)
|
||||
self._host = host
|
||||
if host is not None:
|
||||
host.add_listener('disconnection', self.on_disconnection)
|
||||
host.on('disconnection', self.on_disconnection)
|
||||
|
||||
def find_channel(self, connection_handle, cid):
|
||||
if connection_channels := self.channels.get(connection_handle):
|
||||
|
||||
141
bumble/profiles/asha_service.py
Normal file
141
bumble/profiles/asha_service.py
Normal file
@@ -0,0 +1,141 @@
|
||||
# Copyright 2021-2022 Google LLC
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# https://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
# Imports
|
||||
# -----------------------------------------------------------------------------
|
||||
import struct
|
||||
import logging
|
||||
from ..core import AdvertisingData
|
||||
from ..gatt import (
|
||||
GATT_ASHA_SERVICE,
|
||||
GATT_ASHA_READ_ONLY_PROPERTIES_CHARACTERISTIC,
|
||||
GATT_ASHA_AUDIO_CONTROL_POINT_CHARACTERISTIC,
|
||||
GATT_ASHA_AUDIO_STATUS_CHARACTERISTIC,
|
||||
GATT_ASHA_VOLUME_CHARACTERISTIC,
|
||||
GATT_ASHA_LE_PSM_OUT_CHARACTERISTIC,
|
||||
TemplateService,
|
||||
Characteristic,
|
||||
CharacteristicValue,
|
||||
PackedCharacteristicAdapter
|
||||
)
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
# Logging
|
||||
# -----------------------------------------------------------------------------
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
class AshaService(TemplateService):
|
||||
UUID = GATT_ASHA_SERVICE
|
||||
OPCODE_START = 1
|
||||
OPCODE_STOP = 2
|
||||
OPCODE_STATUS = 3
|
||||
PROTOCOL_VERSION = 0x01
|
||||
RESERVED_FOR_FUTURE_USE = [00, 00]
|
||||
FEATURE_MAP = [0x01] # [LE CoC audio output streaming supported]
|
||||
SUPPORTED_CODEC_ID = [0x02, 0x01] # Codec IDs [G.722 at 16 kHz]
|
||||
RENDER_DELAY = [00, 00]
|
||||
|
||||
def __init__(self, capability: int, hisyncid: [int]):
|
||||
self.hisyncid = hisyncid
|
||||
self.capability = capability # Device Capabilities [Left, Monaural]
|
||||
|
||||
# Handler for volume control
|
||||
def on_volume_write(connection, value):
|
||||
logger.info(f'--- VOLUME Write:{value[0]}')
|
||||
|
||||
# Handler for audio control commands
|
||||
def on_audio_control_point_write(connection, value):
|
||||
logger.info(f'--- AUDIO CONTROL POINT Write:{value.hex()}')
|
||||
opcode = value[0]
|
||||
if opcode == AshaService.OPCODE_START:
|
||||
# Start
|
||||
audio_type = ('Unknown', 'Ringtone', 'Phone Call', 'Media')[value[2]]
|
||||
logger.info(
|
||||
f'### START: codec={value[1]}, audio_type={audio_type}, volume={value[3]}, otherstate={value[4]}')
|
||||
elif opcode == AshaService.OPCODE_STOP:
|
||||
logger.info('### STOP')
|
||||
elif opcode == AshaService.OPCODE_STATUS:
|
||||
logger.info(f'### STATUS: connected={value[1]}')
|
||||
|
||||
# TODO Respond with a status
|
||||
# asyncio.create_task(device.notify_subscribers(audio_status_characteristic, force=True))
|
||||
|
||||
self.read_only_properties_characteristic = Characteristic(
|
||||
GATT_ASHA_READ_ONLY_PROPERTIES_CHARACTERISTIC,
|
||||
Characteristic.READ,
|
||||
Characteristic.READABLE,
|
||||
bytes([
|
||||
AshaService.PROTOCOL_VERSION, # Version
|
||||
self.capability,
|
||||
]) +
|
||||
bytes(self.hisyncid) +
|
||||
bytes(AshaService.FEATURE_MAP) +
|
||||
bytes(AshaService.RENDER_DELAY) +
|
||||
bytes(AshaService.RESERVED_FOR_FUTURE_USE) +
|
||||
bytes(AshaService.SUPPORTED_CODEC_ID)
|
||||
)
|
||||
|
||||
self.audio_control_point_characteristic = Characteristic(
|
||||
GATT_ASHA_AUDIO_CONTROL_POINT_CHARACTERISTIC,
|
||||
Characteristic.WRITE | Characteristic.WRITE_WITHOUT_RESPONSE,
|
||||
Characteristic.WRITEABLE,
|
||||
CharacteristicValue(write=on_audio_control_point_write)
|
||||
)
|
||||
self.audio_status_characteristic = Characteristic(
|
||||
GATT_ASHA_AUDIO_STATUS_CHARACTERISTIC,
|
||||
Characteristic.READ | Characteristic.NOTIFY,
|
||||
Characteristic.READABLE,
|
||||
bytes([0])
|
||||
)
|
||||
self.volume_characteristic = Characteristic(
|
||||
GATT_ASHA_VOLUME_CHARACTERISTIC,
|
||||
Characteristic.WRITE_WITHOUT_RESPONSE,
|
||||
Characteristic.WRITEABLE,
|
||||
CharacteristicValue(write=on_volume_write)
|
||||
)
|
||||
|
||||
# TODO add real psm value
|
||||
self.psm = 0x0080
|
||||
# self.psm = device.register_l2cap_channel_server(0, on_coc, 8)
|
||||
self.le_psm_out_characteristic = Characteristic(
|
||||
GATT_ASHA_LE_PSM_OUT_CHARACTERISTIC,
|
||||
Characteristic.READ,
|
||||
Characteristic.READABLE,
|
||||
struct.pack('<H', self.psm)
|
||||
)
|
||||
|
||||
characteristics = [self.read_only_properties_characteristic,
|
||||
self.audio_control_point_characteristic,
|
||||
self.audio_status_characteristic,
|
||||
self.volume_characteristic,
|
||||
self.le_psm_out_characteristic]
|
||||
|
||||
super().__init__(characteristics)
|
||||
|
||||
def get_advertising_data(self):
|
||||
# Advertisement only uses 4 least significant bytes of the HiSyncId.
|
||||
return bytes(
|
||||
AdvertisingData([
|
||||
(AdvertisingData.INCOMPLETE_LIST_OF_16_BIT_SERVICE_CLASS_UUIDS, bytes(GATT_ASHA_SERVICE)),
|
||||
(AdvertisingData.SERVICE_DATA_16_BIT_UUID, bytes(GATT_ASHA_SERVICE) + bytes([
|
||||
AshaService.PROTOCOL_VERSION,
|
||||
self.capability,
|
||||
]) + bytes(self.hisyncid[:4]))
|
||||
])
|
||||
)
|
||||
|
||||
@@ -636,8 +636,8 @@ class Multiplexer(EventEmitter):
|
||||
if self.open_result:
|
||||
self.open_result.set_exception(ConnectionError(
|
||||
ConnectionError.CONNECTION_REFUSED,
|
||||
self.l2cap_channel.connection.peer_address,
|
||||
BT_BR_EDR_TRANSPORT,
|
||||
self.l2cap_channel.connection.peer_address,
|
||||
'rfcomm'
|
||||
))
|
||||
else:
|
||||
|
||||
@@ -477,6 +477,9 @@ class PairingDelegate:
|
||||
async def accept(self):
|
||||
return True
|
||||
|
||||
async def confirm(self):
|
||||
return True
|
||||
|
||||
async def compare_numbers(self, number, digits=6):
|
||||
return True
|
||||
|
||||
@@ -637,15 +640,16 @@ class Session:
|
||||
self.oob = False
|
||||
|
||||
# Set up addresses
|
||||
self_address = connection.self_address
|
||||
peer_address = connection.peer_resolvable_address or connection.peer_address
|
||||
if self.is_initiator:
|
||||
self.ia = bytes(manager.address)
|
||||
self.iat = 1 if manager.address.is_random else 0
|
||||
self.ia = bytes(self_address)
|
||||
self.iat = 1 if self_address.is_random else 0
|
||||
self.ra = bytes(peer_address)
|
||||
self.rat = 1 if peer_address.is_random else 0
|
||||
else:
|
||||
self.ra = bytes(manager.address)
|
||||
self.rat = 1 if manager.address.is_random else 0
|
||||
self.ra = bytes(self_address)
|
||||
self.rat = 1 if self_address.is_random else 0
|
||||
self.ia = bytes(peer_address)
|
||||
self.iat = 1 if peer_address.is_random else 0
|
||||
|
||||
@@ -715,6 +719,21 @@ class Session:
|
||||
return False
|
||||
return True
|
||||
|
||||
def prompt_user_for_confirmation(self, next_steps):
|
||||
async def prompt():
|
||||
logger.debug('ask for confirmation')
|
||||
try:
|
||||
response = await self.pairing_config.delegate.confirm()
|
||||
if response:
|
||||
next_steps()
|
||||
return
|
||||
except Exception as error:
|
||||
logger.warn(f'exception while confirm: {error}')
|
||||
|
||||
self.send_pairing_failed(SMP_CONFIRM_VALUE_FAILED_ERROR)
|
||||
|
||||
asyncio.create_task(prompt())
|
||||
|
||||
def prompt_user_for_numeric_comparison(self, code, next_steps):
|
||||
async def prompt():
|
||||
logger.debug(f'verification code: {code}')
|
||||
@@ -907,8 +926,8 @@ class Session:
|
||||
SMP_Identity_Information_Command(identity_resolving_key=self.manager.device.irk)
|
||||
)
|
||||
self.send_command(SMP_Identity_Address_Information_Command(
|
||||
addr_type = self.manager.address.address_type,
|
||||
bd_addr = self.manager.address
|
||||
addr_type = self.connection.self_address.address_type,
|
||||
bd_addr = self.connection.self_address
|
||||
))
|
||||
|
||||
# Distribute CSRK
|
||||
@@ -939,8 +958,8 @@ class Session:
|
||||
SMP_Identity_Information_Command(identity_resolving_key=self.manager.device.irk)
|
||||
)
|
||||
self.send_command(SMP_Identity_Address_Information_Command(
|
||||
addr_type = self.manager.address.address_type,
|
||||
bd_addr = self.manager.address
|
||||
addr_type = self.connection.self_address.address_type,
|
||||
bd_addr = self.connection.self_address
|
||||
))
|
||||
|
||||
# Distribute CSRK
|
||||
@@ -1091,7 +1110,7 @@ class Session:
|
||||
self.manager.on_pairing(self, peer_address, keys)
|
||||
|
||||
def on_pairing_failure(self, reason):
|
||||
logger.warn(f'pairing failure ({error_name(reason)})')
|
||||
logger.warning(f'pairing failure ({error_name(reason)})')
|
||||
|
||||
if self.completed:
|
||||
return
|
||||
@@ -1387,12 +1406,12 @@ class Session:
|
||||
# Compute the 6-digit code
|
||||
code = crypto.g2(self.pka, self.pkb, self.na, self.nb) % 1000000
|
||||
|
||||
if self.pairing_method == self.NUMERIC_COMPARISON:
|
||||
# Ask for user confirmation
|
||||
self.wait_before_continuing = asyncio.get_running_loop().create_future()
|
||||
self.prompt_user_for_numeric_comparison(code, next_steps)
|
||||
# Ask for user confirmation
|
||||
self.wait_before_continuing = asyncio.get_running_loop().create_future()
|
||||
if self.pairing_method == self.JUST_WORKS:
|
||||
self.prompt_user_for_confirmation(next_steps)
|
||||
else:
|
||||
next_steps()
|
||||
self.prompt_user_for_numeric_comparison(code, next_steps)
|
||||
else:
|
||||
next_steps()
|
||||
|
||||
@@ -1486,10 +1505,9 @@ class Manager(EventEmitter):
|
||||
Implements the Initiator and Responder roles of the Security Manager Protocol
|
||||
'''
|
||||
|
||||
def __init__(self, device, address):
|
||||
def __init__(self, device):
|
||||
super().__init__()
|
||||
self.device = device
|
||||
self.address = address
|
||||
self.sessions = {}
|
||||
self._ecc_key = None
|
||||
self.pairing_config_factory = lambda connection: PairingConfig()
|
||||
@@ -1565,7 +1583,7 @@ class Manager(EventEmitter):
|
||||
asyncio.create_task(store_keys())
|
||||
|
||||
# Notify the device
|
||||
self.device.on_pairing(session.connection.handle, keys)
|
||||
self.device.on_pairing(session.connection.handle, keys, session.sc)
|
||||
|
||||
def on_pairing_failure(self, session, reason):
|
||||
self.device.on_pairing_failure(session.connection.handle, reason)
|
||||
|
||||
@@ -414,14 +414,13 @@ async def open_usb_transport(spec):
|
||||
|
||||
device = found.open()
|
||||
|
||||
# Detach the kernel driver if supported and needed
|
||||
# Auto-detach the kernel driver if supported
|
||||
if usb1.hasCapability(usb1.CAP_SUPPORTS_DETACH_KERNEL_DRIVER):
|
||||
try:
|
||||
if device.kernelDriverActive(interface):
|
||||
logger.debug("detaching kernel driver")
|
||||
device.detachKernelDriver(interface)
|
||||
except usb1.USBError:
|
||||
pass
|
||||
logger.debug('auto-detaching kernel driver')
|
||||
device.setAutoDetachKernelDriver(True)
|
||||
except usb1.USBError as error:
|
||||
logger.warning(f'unable to auto-detach kernel driver: {error}')
|
||||
|
||||
# Set the configuration if needed
|
||||
try:
|
||||
|
||||
@@ -65,6 +65,7 @@ build =
|
||||
test =
|
||||
pytest >= 6.2
|
||||
pytest-asyncio >= 0.17
|
||||
pytest-html >= 3.2.0
|
||||
coverage >= 6.4
|
||||
development =
|
||||
invoke >= 1.4
|
||||
|
||||
13
tasks.py
13
tasks.py
@@ -52,8 +52,9 @@ build_tasks.add_task(mkdocs, name="mkdocs")
|
||||
test_tasks = Collection()
|
||||
ns.add_collection(test_tasks, name="test")
|
||||
|
||||
@task
|
||||
def test(ctx, filter=None, junit=False, install=False):
|
||||
|
||||
@task(incrementable=["verbose"])
|
||||
def test(ctx, filter=None, junit=False, install=False, html=False, verbose=0):
|
||||
# Install the package before running the tests
|
||||
if install:
|
||||
ctx.run("python -m pip install .[test]")
|
||||
@@ -62,8 +63,12 @@ def test(ctx, filter=None, junit=False, install=False):
|
||||
if junit:
|
||||
args += "--junit-xml test-results.xml"
|
||||
if filter is not None:
|
||||
args += " -k '{}'".format(filter)
|
||||
ctx.run("python -m pytest {} {}".format(os.path.join(ROOT_DIR, "tests"), args))
|
||||
args += f" -k '{filter}'"
|
||||
if html:
|
||||
args += " --html results.html"
|
||||
if verbose > 0:
|
||||
args += f" -{'v' * verbose}"
|
||||
ctx.run(f"python -m pytest {os.path.join(ROOT_DIR, 'tests')} {args}")
|
||||
|
||||
test_tasks.add_task(test, default=True)
|
||||
|
||||
|
||||
@@ -15,8 +15,7 @@
|
||||
# -----------------------------------------------------------------------------
|
||||
# Imports
|
||||
# -----------------------------------------------------------------------------
|
||||
from bumble.core import AdvertisingData
|
||||
|
||||
from bumble.core import AdvertisingData, get_dict_key_by_value
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
def test_ad_data():
|
||||
@@ -39,6 +38,16 @@ def test_ad_data():
|
||||
assert(ad.get(AdvertisingData.TX_POWER_LEVEL, return_all=True, raw=True) == [bytes([123]), bytes([234])])
|
||||
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
def test_get_dict_key_by_value():
|
||||
dictionary = {
|
||||
"A": 1,
|
||||
"B": 2
|
||||
}
|
||||
assert get_dict_key_by_value(dictionary, 1) == "A"
|
||||
assert get_dict_key_by_value(dictionary, 2) == "B"
|
||||
assert get_dict_key_by_value(dictionary, 3) is None
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
if __name__ == '__main__':
|
||||
test_ad_data()
|
||||
@@ -28,7 +28,7 @@ from bumble.hci import (
|
||||
HCI_ACCEPT_CONNECTION_REQUEST_COMMAND, HCI_COMMAND_STATUS_PENDING, HCI_CREATE_CONNECTION_COMMAND, HCI_SUCCESS,
|
||||
Address, HCI_Command_Complete_Event, HCI_Command_Status_Event, HCI_Connection_Complete_Event, HCI_Connection_Request_Event, HCI_Packet
|
||||
)
|
||||
|
||||
from bumble.gatt import GATT_GENERIC_ACCESS_SERVICE, GATT_CHARACTERISTIC_ATTRIBUTE_TYPE, GATT_DEVICE_NAME_CHARACTERISTIC, GATT_APPEARANCE_CHARACTERISTIC
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
# Logging
|
||||
@@ -182,6 +182,27 @@ async def run_test_device():
|
||||
await test_device_connect_parallel()
|
||||
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
def test_gatt_services_with_gas():
|
||||
device = Device(host=Host(None, None))
|
||||
|
||||
# there should be one service and two chars, therefore 5 attributes
|
||||
assert len(device.gatt_server.attributes) == 5
|
||||
assert device.gatt_server.attributes[0].uuid == GATT_GENERIC_ACCESS_SERVICE
|
||||
assert device.gatt_server.attributes[1].type == GATT_CHARACTERISTIC_ATTRIBUTE_TYPE
|
||||
assert device.gatt_server.attributes[2].uuid == GATT_DEVICE_NAME_CHARACTERISTIC
|
||||
assert device.gatt_server.attributes[3].type == GATT_CHARACTERISTIC_ATTRIBUTE_TYPE
|
||||
assert device.gatt_server.attributes[4].uuid == GATT_APPEARANCE_CHARACTERISTIC
|
||||
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
def test_gatt_services_without_gas():
|
||||
device = Device(host=Host(None, None), generic_access_service=False)
|
||||
|
||||
# there should be no services
|
||||
assert len(device.gatt_server.attributes) == 0
|
||||
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
if __name__ == '__main__':
|
||||
logging.basicConfig(level = os.environ.get('BUMBLE_LOGLEVEL', 'INFO').upper())
|
||||
|
||||
@@ -28,6 +28,7 @@ from bumble.device import Device, Peer
|
||||
from bumble.host import Host
|
||||
from bumble.gatt import (
|
||||
GATT_BATTERY_LEVEL_CHARACTERISTIC,
|
||||
GATT_CLIENT_CHARACTERISTIC_CONFIGURATION_DESCRIPTOR,
|
||||
CharacteristicAdapter,
|
||||
DelegatedCharacteristicAdapter,
|
||||
PackedCharacteristicAdapter,
|
||||
@@ -226,6 +227,37 @@ async def test_characteristic_encoding():
|
||||
assert last_change is None
|
||||
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
@pytest.mark.asyncio
|
||||
async def test_attribute_getters():
|
||||
[client, server] = LinkedDevices().devices[:2]
|
||||
|
||||
characteristic_uuid = UUID('FDB159DB-036C-49E3-B3DB-6325AC750806')
|
||||
characteristic = Characteristic(
|
||||
characteristic_uuid,
|
||||
Characteristic.READ | Characteristic.WRITE | Characteristic.NOTIFY,
|
||||
Characteristic.READABLE | Characteristic.WRITEABLE,
|
||||
bytes([123])
|
||||
)
|
||||
|
||||
service_uuid = UUID('3A657F47-D34F-46B3-B1EC-698E29B6B829')
|
||||
service = Service(service_uuid, [characteristic])
|
||||
server.add_service(service)
|
||||
|
||||
service_attr = server.gatt_server.get_service_attribute(service_uuid)
|
||||
assert service_attr
|
||||
|
||||
(char_decl_attr, char_value_attr) = server.gatt_server.get_characteristic_attributes(service_uuid, characteristic_uuid)
|
||||
assert char_decl_attr and char_value_attr
|
||||
|
||||
desc_attr = server.gatt_server.get_descriptor_attribute(service_uuid, characteristic_uuid, GATT_CLIENT_CHARACTERISTIC_CONFIGURATION_DESCRIPTOR)
|
||||
assert desc_attr
|
||||
|
||||
# assert all handles are in expected order
|
||||
assert service_attr.handle < char_decl_attr.handle < char_value_attr.handle < desc_attr.handle == service_attr.end_group_handle
|
||||
# assert characteristic declarations attribute is followed by characteristic value attribute
|
||||
assert char_decl_attr.handle + 1 == char_value_attr.handle
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
def test_CharacteristicAdapter():
|
||||
# Check that the CharacteristicAdapter base class is transparent
|
||||
@@ -612,14 +644,25 @@ async def test_subscribe_notify():
|
||||
await async_barrier()
|
||||
assert not c2._called
|
||||
|
||||
c3._called = False
|
||||
c3._called_2 = False
|
||||
c3._called_3 = False
|
||||
c3._last_update = None
|
||||
c3._last_update_2 = None
|
||||
c3._last_update_3 = None
|
||||
|
||||
def on_c3_update(value):
|
||||
c3._called = True
|
||||
c3._last_update = value
|
||||
|
||||
def on_c3_update_2(value):
|
||||
def on_c3_update_2(value): # for notify
|
||||
c3._called_2 = True
|
||||
c3._last_update_2 = value
|
||||
|
||||
def on_c3_update_3(value): # for indicate
|
||||
c3._called_3 = True
|
||||
c3._last_update_3 = value
|
||||
|
||||
c3.on('update', on_c3_update)
|
||||
await peer.subscribe(c3, on_c3_update_2)
|
||||
await async_barrier()
|
||||
@@ -629,22 +672,33 @@ async def test_subscribe_notify():
|
||||
assert c3._last_update == characteristic3.value
|
||||
assert c3._called_2
|
||||
assert c3._last_update_2 == characteristic3.value
|
||||
assert not c3._called_3
|
||||
|
||||
c3._called = False
|
||||
c3._called_2 = False
|
||||
c3._called_3 = False
|
||||
await peer.unsubscribe(c3)
|
||||
await peer.subscribe(c3, on_c3_update_3, prefer_notify=False)
|
||||
await async_barrier()
|
||||
characteristic3.value = bytes([1, 2, 3])
|
||||
await server.indicate_subscriber(characteristic3._last_subscription[0], characteristic3)
|
||||
await async_barrier()
|
||||
assert c3._called
|
||||
assert c3._last_update == characteristic3.value
|
||||
assert c3._called_2
|
||||
assert c3._last_update_2 == characteristic3.value
|
||||
assert not c3._called_2
|
||||
assert c3._called_3
|
||||
assert c3._last_update_3 == characteristic3.value
|
||||
|
||||
c3._called = False
|
||||
c3._called_2 = False
|
||||
c3._called_3 = False
|
||||
await peer.unsubscribe(c3)
|
||||
await server.notify_subscriber(characteristic3._last_subscription[0], characteristic3)
|
||||
await server.indicate_subscriber(characteristic3._last_subscription[0], characteristic3)
|
||||
await async_barrier()
|
||||
assert not c3._called
|
||||
assert not c3._called_2
|
||||
assert not c3._called_3
|
||||
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
@@ -683,6 +737,55 @@ async def test_mtu_exchange():
|
||||
assert d2_connection.att_mtu == 50
|
||||
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
def test_char_property_to_string():
|
||||
# single
|
||||
assert Characteristic.property_name(0x01) == "BROADCAST"
|
||||
assert Characteristic.property_name(Characteristic.BROADCAST) == "BROADCAST"
|
||||
|
||||
# double
|
||||
assert Characteristic.properties_as_string(0x03) == "BROADCAST,READ"
|
||||
assert Characteristic.properties_as_string(Characteristic.BROADCAST | Characteristic.READ) == "BROADCAST,READ"
|
||||
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
def test_char_property_string_to_type():
|
||||
# single
|
||||
assert Characteristic.string_to_properties("BROADCAST") == Characteristic.BROADCAST
|
||||
|
||||
# double
|
||||
assert Characteristic.string_to_properties("BROADCAST,READ") == Characteristic.BROADCAST | Characteristic.READ
|
||||
assert Characteristic.string_to_properties("READ,BROADCAST") == Characteristic.BROADCAST | Characteristic.READ
|
||||
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
@pytest.mark.asyncio
|
||||
async def test_server_string():
|
||||
[_, server] = LinkedDevices().devices[:2]
|
||||
|
||||
characteristic = Characteristic(
|
||||
'FDB159DB-036C-49E3-B3DB-6325AC750806',
|
||||
Characteristic.READ | Characteristic.WRITE | Characteristic.NOTIFY,
|
||||
Characteristic.READABLE | Characteristic.WRITEABLE,
|
||||
bytes([123])
|
||||
)
|
||||
|
||||
service = Service(
|
||||
'3A657F47-D34F-46B3-B1EC-698E29B6B829',
|
||||
[characteristic]
|
||||
)
|
||||
server.add_service(service)
|
||||
|
||||
assert str(server.gatt_server) == """Service(handle=0x0001, end=0x0005, uuid=UUID-16:1800 (Generic Access))
|
||||
CharacteristicDeclaration(handle=0x0002, value_handle=0x0003, uuid=UUID-16:2A00 (Device Name), properties=READ)
|
||||
Characteristic(handle=0x0003, end=0x0003, uuid=UUID-16:2A00 (Device Name), properties=READ)
|
||||
CharacteristicDeclaration(handle=0x0004, value_handle=0x0005, uuid=UUID-16:2A01 (Appearance), properties=READ)
|
||||
Characteristic(handle=0x0005, end=0x0005, uuid=UUID-16:2A01 (Appearance), properties=READ)
|
||||
Service(handle=0x0006, end=0x0009, uuid=3A657F47-D34F-46B3-B1EC-698E29B6B829)
|
||||
CharacteristicDeclaration(handle=0x0007, value_handle=0x0008, uuid=FDB159DB-036C-49E3-B3DB-6325AC750806, properties=READ,WRITE,NOTIFY)
|
||||
Characteristic(handle=0x0008, end=0x0009, uuid=FDB159DB-036C-49E3-B3DB-6325AC750806, properties=READ,WRITE,NOTIFY)
|
||||
Descriptor(handle=0x0009, type=UUID-16:2902 (Client Characteristic Configuration), value=0000)"""
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
async def async_main():
|
||||
await test_read_write()
|
||||
|
||||
@@ -59,7 +59,7 @@ def test_HCI_LE_Connection_Complete_Event():
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
def test_HCI_LE_Advertising_Report_Event():
|
||||
address = Address('00:11:22:33:44:55')
|
||||
address = Address('00:11:22:33:44:55/P')
|
||||
report = HCI_LE_Advertising_Report_Event.Report(
|
||||
HCI_LE_Advertising_Report_Event.Report.FIELDS,
|
||||
event_type = HCI_LE_Advertising_Report_Event.ADV_IND,
|
||||
|
||||
@@ -64,37 +64,37 @@ def test_import():
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
def test_app_imports():
|
||||
from bumble.apps.console import main
|
||||
from apps.console import main
|
||||
assert main
|
||||
|
||||
from bumble.apps.controller_info import main
|
||||
from apps.controller_info import main
|
||||
assert main
|
||||
|
||||
from bumble.apps.controllers import main
|
||||
from apps.controllers import main
|
||||
assert main
|
||||
|
||||
from bumble.apps.gatt_dump import main
|
||||
from apps.gatt_dump import main
|
||||
assert main
|
||||
|
||||
from bumble.apps.gg_bridge import main
|
||||
from apps.gg_bridge import main
|
||||
assert main
|
||||
|
||||
from bumble.apps.hci_bridge import main
|
||||
from apps.hci_bridge import main
|
||||
assert main
|
||||
|
||||
from bumble.apps.pair import main
|
||||
from apps.pair import main
|
||||
assert main
|
||||
|
||||
from bumble.apps.scan import main
|
||||
from apps.scan import main
|
||||
assert main
|
||||
|
||||
from bumble.apps.show import main
|
||||
from apps.show import main
|
||||
assert main
|
||||
|
||||
from bumble.apps.unbond import main
|
||||
from apps.unbond import main
|
||||
assert main
|
||||
|
||||
from bumble.apps.usb_probe import main
|
||||
from apps.usb_probe import main
|
||||
assert main
|
||||
|
||||
|
||||
|
||||
@@ -85,8 +85,8 @@ async def setup_connection():
|
||||
await two_devices.devices[0].connect(two_devices.devices[1].random_address)
|
||||
|
||||
# Check the post conditions
|
||||
assert(two_devices.connections[0] is not None)
|
||||
assert(two_devices.connections[1] is not None)
|
||||
assert two_devices.connections[0] is not None
|
||||
assert two_devices.connections[1] is not None
|
||||
|
||||
return two_devices
|
||||
|
||||
@@ -94,31 +94,31 @@ async def setup_connection():
|
||||
# -----------------------------------------------------------------------------
|
||||
def test_helpers():
|
||||
psm = L2CAP_Connection_Request.serialize_psm(0x01)
|
||||
assert(psm == bytes([0x01, 0x00]))
|
||||
assert psm == bytes([0x01, 0x00])
|
||||
|
||||
psm = L2CAP_Connection_Request.serialize_psm(0x1023)
|
||||
assert(psm == bytes([0x23, 0x10]))
|
||||
assert psm == bytes([0x23, 0x10])
|
||||
|
||||
psm = L2CAP_Connection_Request.serialize_psm(0x242311)
|
||||
assert(psm == bytes([0x11, 0x23, 0x24]))
|
||||
assert psm == bytes([0x11, 0x23, 0x24])
|
||||
|
||||
(offset, psm) = L2CAP_Connection_Request.parse_psm(bytes([0x00, 0x01, 0x00, 0x44]), 1)
|
||||
assert(offset == 3)
|
||||
assert(psm == 0x01)
|
||||
assert offset == 3
|
||||
assert psm == 0x01
|
||||
|
||||
(offset, psm) = L2CAP_Connection_Request.parse_psm(bytes([0x00, 0x23, 0x10, 0x44]), 1)
|
||||
assert(offset == 3)
|
||||
assert(psm == 0x1023)
|
||||
assert offset == 3
|
||||
assert psm == 0x1023
|
||||
|
||||
(offset, psm) = L2CAP_Connection_Request.parse_psm(bytes([0x00, 0x11, 0x23, 0x24, 0x44]), 1)
|
||||
assert(offset == 4)
|
||||
assert(psm == 0x242311)
|
||||
assert offset == 4
|
||||
assert psm == 0x242311
|
||||
|
||||
rq = L2CAP_Connection_Request(psm = 0x01, source_cid = 0x44)
|
||||
brq = bytes(rq)
|
||||
srq = L2CAP_Connection_Request.from_bytes(brq)
|
||||
assert(srq.psm == rq.psm)
|
||||
assert(srq.source_cid == rq.source_cid)
|
||||
assert srq.psm == rq.psm
|
||||
assert srq.source_cid == rq.source_cid
|
||||
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
@@ -170,12 +170,12 @@ async def test_basic_connection():
|
||||
l2cap_channel.on('close', lambda: on_close(0, None))
|
||||
incoming_channel.on('close', lambda: on_close(1, closed_event))
|
||||
await l2cap_channel.disconnect()
|
||||
assert(closed == [True, True])
|
||||
assert closed == [True, True]
|
||||
await closed_event.wait()
|
||||
|
||||
sent_bytes = b''.join(messages)
|
||||
received_bytes = b''.join(received)
|
||||
assert(sent_bytes == received_bytes)
|
||||
assert sent_bytes == received_bytes
|
||||
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
@@ -201,7 +201,7 @@ async def transfer_payload(max_credits, mtu, mps):
|
||||
|
||||
messages = [
|
||||
bytes([1, 2, 3, 4, 5, 6, 7]) * x
|
||||
for x in (3, 10, 100, 500, 789)
|
||||
for x in (3, 10, 100, 789)
|
||||
]
|
||||
for message in messages:
|
||||
l2cap_channel.write(message)
|
||||
@@ -214,14 +214,14 @@ async def transfer_payload(max_credits, mtu, mps):
|
||||
|
||||
sent_bytes = b''.join(messages)
|
||||
received_bytes = b''.join(received)
|
||||
assert(sent_bytes == received_bytes)
|
||||
assert sent_bytes == received_bytes
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_transfer():
|
||||
for max_credits in (1, 10, 100, 10000):
|
||||
for mtu in (23, 24, 25, 26, 50, 200, 255, 256, 1000):
|
||||
for mps in (23, 24, 25, 26, 50, 200, 255, 256, 1000):
|
||||
for mtu in (50, 255, 256, 1000):
|
||||
for mps in (50, 255, 256, 1000):
|
||||
# print(max_credits, mtu, mps)
|
||||
await transfer_payload(max_credits, mtu, mps)
|
||||
|
||||
@@ -267,8 +267,8 @@ async def test_bidirectional_transfer():
|
||||
message_bytes = b''.join(messages)
|
||||
client_received_bytes = b''.join(client_received)
|
||||
server_received_bytes = b''.join(server_received)
|
||||
assert(client_received_bytes == message_bytes)
|
||||
assert(server_received_bytes == message_bytes)
|
||||
assert client_received_bytes == message_bytes
|
||||
assert server_received_bytes == message_bytes
|
||||
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
|
||||
Reference in New Issue
Block a user