From 84f7cad67899fd2d70a64b75d5bf8c0397c2a780 Mon Sep 17 00:00:00 2001 From: zxzxwu <92432172+zxzxwu@users.noreply.github.com> Date: Sun, 18 May 2025 07:49:11 +0000 Subject: [PATCH] Replace Optional[Connection] att parameter type --- bumble/att.py | 18 ++++++-------- bumble/gatt.py | 6 +---- bumble/gatt_server.py | 5 +--- bumble/profiles/aics.py | 8 +++--- bumble/profiles/ascs.py | 2 +- bumble/profiles/asha.py | 4 +-- bumble/profiles/csip.py | 4 +-- bumble/profiles/gatt_service.py | 4 +-- bumble/profiles/hap.py | 43 ++++++++++----------------------- bumble/profiles/mcp.py | 5 +--- bumble/profiles/vcs.py | 6 ++--- bumble/profiles/vocs.py | 15 +++++------- tests/gatt_test.py | 38 ++++++++++++++--------------- 13 files changed, 59 insertions(+), 99 deletions(-) diff --git a/bumble/att.py b/bumble/att.py index 6bac58d9..98a82d98 100644 --- a/bumble/att.py +++ b/bumble/att.py @@ -770,27 +770,25 @@ class AttributeValue(Generic[_T]): def __init__( self, read: Union[ - Callable[[Optional[Connection]], _T], - Callable[[Optional[Connection]], Awaitable[_T]], + Callable[[Connection], _T], + Callable[[Connection], Awaitable[_T]], None, ] = None, write: Union[ - Callable[[Optional[Connection], _T], None], - Callable[[Optional[Connection], _T], Awaitable[None]], + Callable[[Connection, _T], None], + Callable[[Connection, _T], Awaitable[None]], None, ] = None, ): self._read = read self._write = write - def read(self, connection: Optional[Connection]) -> Union[_T, Awaitable[_T]]: + def read(self, connection: Connection) -> Union[_T, Awaitable[_T]]: if self._read is None: raise InvalidOperationError('AttributeValue has no read function') return self._read(connection) - def write( - self, connection: Optional[Connection], value: _T - ) -> Union[Awaitable[None], None]: + def write(self, connection: Connection, value: _T) -> Union[Awaitable[None], None]: if self._write is None: raise InvalidOperationError('AttributeValue has no write function') return self._write(connection, value) @@ -871,7 +869,7 @@ class Attribute(utils.EventEmitter, Generic[_T]): def decode_value(self, value: bytes) -> _T: return value # type: ignore - async def read_value(self, connection: Optional[Connection]) -> bytes: + async def read_value(self, connection: Connection) -> bytes: if ( (self.permissions & self.READ_REQUIRES_ENCRYPTION) and connection is not None @@ -913,7 +911,7 @@ class Attribute(utils.EventEmitter, Generic[_T]): return b'' if value is None else self.encode_value(value) - async def write_value(self, connection: Optional[Connection], value: bytes) -> None: + async def write_value(self, connection: Connection, value: bytes) -> None: if ( (self.permissions & self.WRITE_REQUIRES_ENCRYPTION) and connection is not None diff --git a/bumble/gatt.py b/bumble/gatt.py index be75f454..701d8d0e 100644 --- a/bumble/gatt.py +++ b/bumble/gatt.py @@ -579,11 +579,7 @@ class Descriptor(Attribute): if isinstance(self.value, bytes): value_str = self.value.hex() elif isinstance(self.value, CharacteristicValue): - value = self.value.read(None) - if isinstance(value, bytes): - value_str = value.hex() - else: - value_str = '' + value_str = '' else: value_str = '<...>' return ( diff --git a/bumble/gatt_server.py b/bumble/gatt_server.py index ad281dad..a3a4ff80 100644 --- a/bumble/gatt_server.py +++ b/bumble/gatt_server.py @@ -315,11 +315,8 @@ class Server(utils.EventEmitter): self.add_service(service) def read_cccd( - self, connection: Optional[Connection], characteristic: Characteristic + self, connection: Connection, characteristic: Characteristic ) -> bytes: - if connection is None: - return bytes([0, 0]) - subscribers = self.subscribers.get(connection.handle) cccd = None if subscribers: diff --git a/bumble/profiles/aics.py b/bumble/profiles/aics.py index ce4be266..f242435e 100644 --- a/bumble/profiles/aics.py +++ b/bumble/profiles/aics.py @@ -198,8 +198,7 @@ class AudioInputControlPoint: audio_input_state: AudioInputState gain_settings_properties: GainSettingsProperties - async def on_write(self, connection: Optional[Connection], value: bytes) -> None: - assert connection + async def on_write(self, connection: Connection, value: bytes) -> None: opcode = AudioInputControlPointOpCode(value[0]) @@ -320,11 +319,10 @@ class AudioInputDescription: audio_input_description: str = "Bluetooth" attribute: Optional[Attribute] = None - def on_read(self, _connection: Optional[Connection]) -> str: + def on_read(self, _connection: Connection) -> str: return self.audio_input_description - async def on_write(self, connection: Optional[Connection], value: str) -> None: - assert connection + async def on_write(self, connection: Connection, value: str) -> None: assert self.attribute self.audio_input_description = value diff --git a/bumble/profiles/ascs.py b/bumble/profiles/ascs.py index 2b001609..a7c2e62a 100644 --- a/bumble/profiles/ascs.py +++ b/bumble/profiles/ascs.py @@ -590,7 +590,7 @@ class AseStateMachine(gatt.Characteristic): # Readonly. Do nothing in the setter. pass - def on_read(self, _: Optional[device.Connection]) -> bytes: + def on_read(self, _: device.Connection) -> bytes: return self.value def __str__(self) -> str: diff --git a/bumble/profiles/asha.py b/bumble/profiles/asha.py index 91668023..1a8f4989 100644 --- a/bumble/profiles/asha.py +++ b/bumble/profiles/asha.py @@ -200,7 +200,7 @@ class AshaService(gatt.TemplateService): # Handler for audio control commands async def _on_audio_control_point_write( - self, connection: Optional[Connection], value: bytes + self, connection: Connection, value: bytes ) -> None: _logger.debug(f'--- AUDIO CONTROL POINT Write:{value.hex()}') opcode = value[0] @@ -247,7 +247,7 @@ class AshaService(gatt.TemplateService): ) # Handler for volume control - def _on_volume_write(self, connection: Optional[Connection], value: bytes) -> None: + def _on_volume_write(self, connection: Connection, value: bytes) -> None: _logger.debug(f'--- VOLUME Write:{value[0]}') self.volume = value[0] self.emit(self.EVENT_VOLUME_CHANGED) diff --git a/bumble/profiles/csip.py b/bumble/profiles/csip.py index c5e3f404..07eed775 100644 --- a/bumble/profiles/csip.py +++ b/bumble/profiles/csip.py @@ -164,12 +164,10 @@ class CoordinatedSetIdentificationService(gatt.TemplateService): super().__init__(characteristics) - async def on_sirk_read(self, connection: Optional[device.Connection]) -> bytes: + async def on_sirk_read(self, connection: device.Connection) -> bytes: if self.set_identity_resolving_key_type == SirkType.PLAINTEXT: sirk_bytes = self.set_identity_resolving_key else: - assert connection - if connection.transport == core.PhysicalTransport.LE: key = await connection.device.get_long_term_key( connection_handle=connection.handle, rand=b'', ediv=0 diff --git a/bumble/profiles/gatt_service.py b/bumble/profiles/gatt_service.py index 539ff952..e2d7a19f 100644 --- a/bumble/profiles/gatt_service.py +++ b/bumble/profiles/gatt_service.py @@ -127,9 +127,7 @@ class GenericAttributeProfileService(gatt.TemplateService): return b'' - def get_database_hash(self, connection: device.Connection | None) -> bytes: - assert connection - + def get_database_hash(self, connection: device.Connection) -> bytes: m = b''.join( [ self.get_attribute_data(attribute) diff --git a/bumble/profiles/hap.py b/bumble/profiles/hap.py index 00748184..8432ef10 100644 --- a/bumble/profiles/hap.py +++ b/bumble/profiles/hap.py @@ -335,9 +335,8 @@ class HearingAccessService(gatt.TemplateService): utils.cancel_on_event(connection, 'disconnection', on_connection_async()) - def _on_read_active_preset_index( - self, __connection__: Optional[Connection] - ) -> bytes: + def _on_read_active_preset_index(self, connection: Connection) -> bytes: + del connection # Unused return bytes([self.active_preset_index]) # TODO this need to be triggered when device is unbonded @@ -345,18 +344,13 @@ class HearingAccessService(gatt.TemplateService): self.preset_changed_operations_history_per_device.pop(addr) async def _on_write_hearing_aid_preset_control_point( - self, connection: Optional[Connection], value: bytes + self, connection: Connection, value: bytes ): - assert connection - opcode = HearingAidPresetControlPointOpcode(value[0]) handler = getattr(self, '_on_' + opcode.name.lower()) await handler(connection, value) - async def _on_read_presets_request( - self, connection: Optional[Connection], value: bytes - ): - assert connection + async def _on_read_presets_request(self, connection: Connection, value: bytes): if connection.att_mtu < 49: # 2.5. GATT sub-procedure requirements logging.warning(f'HAS require MTU >= 49: {connection}') @@ -471,10 +465,7 @@ class HearingAccessService(gatt.TemplateService): for connection in self.currently_connected_clients: await self._preset_changed_operation(connection) - async def _on_write_preset_name( - self, connection: Optional[Connection], value: bytes - ): - assert connection + async def _on_write_preset_name(self, connection: Connection, value: bytes): if self.read_presets_request_in_progress: raise att.ATT_Error(att.ErrorCode.PROCEDURE_ALREADY_IN_PROGRESS) @@ -522,10 +513,7 @@ class HearingAccessService(gatt.TemplateService): for connection in self.currently_connected_clients: await self.notify_active_preset_for_connection(connection) - async def set_active_preset( - self, connection: Optional[Connection], value: bytes - ) -> None: - assert connection + async def set_active_preset(self, connection: Connection, value: bytes) -> None: index = value[1] preset = self.preset_records.get(index, None) if ( @@ -542,16 +530,11 @@ class HearingAccessService(gatt.TemplateService): self.active_preset_index = index await self.notify_active_preset() - async def _on_set_active_preset( - self, connection: Optional[Connection], value: bytes - ): + async def _on_set_active_preset(self, connection: Connection, value: bytes): await self.set_active_preset(connection, value) - async def set_next_or_previous_preset( - self, connection: Optional[Connection], is_previous - ): + async def set_next_or_previous_preset(self, connection: Connection, is_previous): '''Set the next or the previous preset as active''' - assert connection if self.active_preset_index == 0x00: raise att.ATT_Error(ErrorCode.PRESET_OPERATION_NOT_POSSIBLE) @@ -581,17 +564,17 @@ class HearingAccessService(gatt.TemplateService): await self.notify_active_preset() async def _on_set_next_preset( - self, connection: Optional[Connection], __value__: bytes + self, connection: Connection, __value__: bytes ) -> None: await self.set_next_or_previous_preset(connection, False) async def _on_set_previous_preset( - self, connection: Optional[Connection], __value__: bytes + self, connection: Connection, __value__: bytes ) -> None: await self.set_next_or_previous_preset(connection, True) async def _on_set_active_preset_synchronized_locally( - self, connection: Optional[Connection], value: bytes + self, connection: Connection, value: bytes ): if ( self.server_features.preset_synchronization_support @@ -602,7 +585,7 @@ class HearingAccessService(gatt.TemplateService): # TODO (low priority) inform other server of the change async def _on_set_next_preset_synchronized_locally( - self, connection: Optional[Connection], __value__: bytes + self, connection: Connection, __value__: bytes ): if ( self.server_features.preset_synchronization_support @@ -613,7 +596,7 @@ class HearingAccessService(gatt.TemplateService): # TODO (low priority) inform other server of the change async def _on_set_previous_preset_synchronized_locally( - self, connection: Optional[Connection], __value__: bytes + self, connection: Connection, __value__: bytes ): if ( self.server_features.preset_synchronization_support diff --git a/bumble/profiles/mcp.py b/bumble/profiles/mcp.py index 68cac933..e2ca8b19 100644 --- a/bumble/profiles/mcp.py +++ b/bumble/profiles/mcp.py @@ -287,11 +287,8 @@ class MediaControlService(gatt.TemplateService): ) async def on_media_control_point( - self, connection: Optional[device.Connection], data: bytes + self, connection: device.Connection, data: bytes ) -> None: - if not connection: - raise core.InvalidStateError() - opcode = MediaControlPointOpcode(data[0]) await connection.device.notify_subscriber( diff --git a/bumble/profiles/vcs.py b/bumble/profiles/vcs.py index 54d7bbe5..2d1424e5 100644 --- a/bumble/profiles/vcs.py +++ b/bumble/profiles/vcs.py @@ -146,14 +146,12 @@ class VolumeControlService(gatt.TemplateService): included_services=list(included_services), ) - def _on_read_volume_state(self, _connection: Optional[device.Connection]) -> bytes: + def _on_read_volume_state(self, _connection: device.Connection) -> bytes: return bytes(VolumeState(self.volume_setting, self.muted, self.change_counter)) def _on_write_volume_control_point( - self, connection: Optional[device.Connection], value: bytes + self, connection: device.Connection, value: bytes ) -> None: - assert connection - opcode = VolumeControlPointOpcode(value[0]) change_counter = value[1] diff --git a/bumble/profiles/vocs.py b/bumble/profiles/vocs.py index 13ae4495..7ce5ddd7 100644 --- a/bumble/profiles/vocs.py +++ b/bumble/profiles/vocs.py @@ -86,7 +86,7 @@ class VolumeOffsetState: assert self.attribute is not None await connection.device.notify_subscribers(attribute=self.attribute) - def on_read(self, _connection: Optional[Connection]) -> bytes: + def on_read(self, _connection: Connection) -> bytes: return bytes(self) @@ -103,11 +103,10 @@ class VocsAudioLocation: audio_location = AudioLocation(struct.unpack(' bytes: + def on_read(self, _connection: Connection) -> bytes: return bytes(self) - async def on_write(self, connection: Optional[Connection], value: bytes) -> None: - assert connection + async def on_write(self, connection: Connection, value: bytes) -> None: assert self.attribute self.audio_location = AudioLocation(int.from_bytes(value, 'little')) @@ -118,8 +117,7 @@ class VocsAudioLocation: class VolumeOffsetControlPoint: volume_offset_state: VolumeOffsetState - async def on_write(self, connection: Optional[Connection], value: bytes) -> None: - assert connection + async def on_write(self, connection: Connection, value: bytes) -> None: opcode = value[0] if opcode != SetVolumeOffsetOpCode.SET_VOLUME_OFFSET: @@ -159,11 +157,10 @@ class AudioOutputDescription: def __bytes__(self) -> bytes: return self.audio_output_description.encode('utf-8') - def on_read(self, _connection: Optional[Connection]) -> bytes: + def on_read(self, _connection: Connection) -> bytes: return bytes(self) - async def on_write(self, connection: Optional[Connection], value: bytes) -> None: - assert connection + async def on_write(self, connection: Connection, value: bytes) -> None: assert self.attribute self.audio_output_description = value.decode('utf-8') diff --git a/tests/gatt_test.py b/tests/gatt_test.py index b6f46571..c27f2d35 100644 --- a/tests/gatt_test.py +++ b/tests/gatt_test.py @@ -136,9 +136,9 @@ async def test_characteristic_encoding(): Characteristic.READABLE, 123, ) - x = await c.read_value(None) + x = await c.read_value(Mock()) assert x == bytes([123]) - await c.write_value(None, bytes([122])) + await c.write_value(Mock(), bytes([122])) assert c.value == 122 class FooProxy(CharacteristicProxy): @@ -334,7 +334,7 @@ async def test_CharacteristicAdapter() -> None: ) v = bytes([3, 4, 5]) - await c.write_value(None, v) + await c.write_value(Mock(), v) assert c.value == v # Simple delegated adapter @@ -342,11 +342,11 @@ async def test_CharacteristicAdapter() -> None: c, lambda x: bytes(reversed(x)), lambda x: bytes(reversed(x)) ) - delegated_value = await delegated.read_value(None) + delegated_value = await delegated.read_value(Mock()) assert delegated_value == bytes(reversed(v)) delegated_value2 = bytes([3, 4, 5]) - await delegated.write_value(None, delegated_value2) + await delegated.write_value(Mock(), delegated_value2) assert delegated.value == bytes(reversed(delegated_value2)) # Packed adapter with single element format @@ -355,10 +355,10 @@ async def test_CharacteristicAdapter() -> None: c.value = packed_value_ref packed = PackedCharacteristicAdapter(c, '>H') - packed_value_read = await packed.read_value(None) + packed_value_read = await packed.read_value(Mock()) assert packed_value_read == packed_value_bytes c.value = b'' - await packed.write_value(None, packed_value_bytes) + await packed.write_value(Mock(), packed_value_bytes) assert packed.value == packed_value_ref # Packed adapter with multi-element format @@ -368,10 +368,10 @@ async def test_CharacteristicAdapter() -> None: c.value = (v1, v2) packed_multi = PackedCharacteristicAdapter(c, '>HH') - packed_multi_read_value = await packed_multi.read_value(None) + packed_multi_read_value = await packed_multi.read_value(Mock()) assert packed_multi_read_value == packed_multi_value_bytes packed_multi.value = b'' - await packed_multi.write_value(None, packed_multi_value_bytes) + await packed_multi.write_value(Mock(), packed_multi_value_bytes) assert packed_multi.value == (v1, v2) # Mapped adapter @@ -382,10 +382,10 @@ async def test_CharacteristicAdapter() -> None: c.value = mapped packed_mapped = MappedCharacteristicAdapter(c, '>HH', ('v1', 'v2')) - packed_mapped_read_value = await packed_mapped.read_value(None) + packed_mapped_read_value = await packed_mapped.read_value(Mock()) assert packed_mapped_read_value == packed_mapped_value_bytes c.value = b'' - await packed_mapped.write_value(None, packed_mapped_value_bytes) + await packed_mapped.write_value(Mock(), packed_mapped_value_bytes) assert packed_mapped.value == mapped # UTF-8 adapter @@ -394,10 +394,10 @@ async def test_CharacteristicAdapter() -> None: c.value = string_value string_c = UTF8CharacteristicAdapter(c) - string_read_value = await string_c.read_value(None) + string_read_value = await string_c.read_value(Mock()) assert string_read_value == string_value_bytes c.value = b'' - await string_c.write_value(None, string_value_bytes) + await string_c.write_value(Mock(), string_value_bytes) assert string_c.value == string_value # Class adapter @@ -419,10 +419,10 @@ async def test_CharacteristicAdapter() -> None: c.value = class_value class_c = SerializableCharacteristicAdapter(c, BlaBla) - class_read_value = await class_c.read_value(None) + class_read_value = await class_c.read_value(Mock()) assert class_read_value == class_value_bytes class_c.value = b'' - await class_c.write_value(None, class_value_bytes) + await class_c.write_value(Mock(), class_value_bytes) assert isinstance(class_c.value, BlaBla) assert class_c.value.a == class_value.a assert class_c.value.b == class_value.b @@ -436,10 +436,10 @@ async def test_CharacteristicAdapter() -> None: enum_value_bytes = int(enum_value).to_bytes(3, 'big') c.value = enum_value enum_c = EnumCharacteristicAdapter(c, MyEnum, 3, 'big') - enum_read_value = await enum_c.read_value(None) + enum_read_value = await enum_c.read_value(Mock()) assert enum_read_value == enum_value_bytes enum_c.value = b'' - await enum_c.write_value(None, enum_value_bytes) + await enum_c.write_value(Mock(), enum_value_bytes) assert isinstance(enum_c.value, MyEnum) assert enum_c.value == enum_value @@ -1254,7 +1254,7 @@ Characteristic(handle=0x0005, end=0x0005, uuid=UUID-16:2A01 (Appearance), READ) Service(handle=0x0006, end=0x000D, uuid=UUID-16:1801 (Generic Attribute)) CharacteristicDeclaration(handle=0x0007, value_handle=0x0008, uuid=UUID-16:2A05 (Service Changed), INDICATE) Characteristic(handle=0x0008, end=0x0009, uuid=UUID-16:2A05 (Service Changed), INDICATE) -Descriptor(handle=0x0009, type=UUID-16:2902 (Client Characteristic Configuration), value=0000) +Descriptor(handle=0x0009, type=UUID-16:2902 (Client Characteristic Configuration), value=) CharacteristicDeclaration(handle=0x000A, value_handle=0x000B, uuid=UUID-16:2B29 (Client Supported Features), READ|WRITE) Characteristic(handle=0x000B, end=0x000B, uuid=UUID-16:2B29 (Client Supported Features), READ|WRITE) CharacteristicDeclaration(handle=0x000C, value_handle=0x000D, uuid=UUID-16:2B2A (Database Hash), READ) @@ -1262,7 +1262,7 @@ Characteristic(handle=0x000D, end=0x000D, uuid=UUID-16:2B2A (Database Hash), REA Service(handle=0x000E, end=0x0011, uuid=3A657F47-D34F-46B3-B1EC-698E29B6B829) CharacteristicDeclaration(handle=0x000F, value_handle=0x0010, uuid=FDB159DB-036C-49E3-B3DB-6325AC750806, READ|WRITE|NOTIFY) Characteristic(handle=0x0010, end=0x0011, uuid=FDB159DB-036C-49E3-B3DB-6325AC750806, READ|WRITE|NOTIFY) -Descriptor(handle=0x0011, type=UUID-16:2902 (Client Characteristic Configuration), value=0000)""" +Descriptor(handle=0x0011, type=UUID-16:2902 (Client Characteristic Configuration), value=)""" )