Merge pull request #696 from zxzxwu/att

Replace Optional[Connection] AttributeValue parameter type
This commit is contained in:
zxzxwu
2025-05-19 21:17:11 +08:00
committed by GitHub
13 changed files with 59 additions and 99 deletions

View File

@@ -770,27 +770,25 @@ class AttributeValue(Generic[_T]):
def __init__(
self,
read: Union[
Callable[[Optional[Connection]], _T],
Callable[[Optional[Connection]], Awaitable[_T]],
Callable[[Connection], _T],
Callable[[Connection], Awaitable[_T]],
None,
] = None,
write: Union[
Callable[[Optional[Connection], _T], None],
Callable[[Optional[Connection], _T], Awaitable[None]],
Callable[[Connection, _T], None],
Callable[[Connection, _T], Awaitable[None]],
None,
] = None,
):
self._read = read
self._write = write
def read(self, connection: Optional[Connection]) -> Union[_T, Awaitable[_T]]:
def read(self, connection: Connection) -> Union[_T, Awaitable[_T]]:
if self._read is None:
raise InvalidOperationError('AttributeValue has no read function')
return self._read(connection)
def write(
self, connection: Optional[Connection], value: _T
) -> Union[Awaitable[None], None]:
def write(self, connection: Connection, value: _T) -> Union[Awaitable[None], None]:
if self._write is None:
raise InvalidOperationError('AttributeValue has no write function')
return self._write(connection, value)
@@ -871,7 +869,7 @@ class Attribute(utils.EventEmitter, Generic[_T]):
def decode_value(self, value: bytes) -> _T:
return value # type: ignore
async def read_value(self, connection: Optional[Connection]) -> bytes:
async def read_value(self, connection: Connection) -> bytes:
if (
(self.permissions & self.READ_REQUIRES_ENCRYPTION)
and connection is not None
@@ -913,7 +911,7 @@ class Attribute(utils.EventEmitter, Generic[_T]):
return b'' if value is None else self.encode_value(value)
async def write_value(self, connection: Optional[Connection], value: bytes) -> None:
async def write_value(self, connection: Connection, value: bytes) -> None:
if (
(self.permissions & self.WRITE_REQUIRES_ENCRYPTION)
and connection is not None

View File

@@ -579,11 +579,7 @@ class Descriptor(Attribute):
if isinstance(self.value, bytes):
value_str = self.value.hex()
elif isinstance(self.value, CharacteristicValue):
value = self.value.read(None)
if isinstance(value, bytes):
value_str = value.hex()
else:
value_str = '<async>'
value_str = '<dynamic>'
else:
value_str = '<...>'
return (

View File

@@ -315,11 +315,8 @@ class Server(utils.EventEmitter):
self.add_service(service)
def read_cccd(
self, connection: Optional[Connection], characteristic: Characteristic
self, connection: Connection, characteristic: Characteristic
) -> bytes:
if connection is None:
return bytes([0, 0])
subscribers = self.subscribers.get(connection.handle)
cccd = None
if subscribers:

View File

@@ -198,8 +198,7 @@ class AudioInputControlPoint:
audio_input_state: AudioInputState
gain_settings_properties: GainSettingsProperties
async def on_write(self, connection: Optional[Connection], value: bytes) -> None:
assert connection
async def on_write(self, connection: Connection, value: bytes) -> None:
opcode = AudioInputControlPointOpCode(value[0])
@@ -320,11 +319,10 @@ class AudioInputDescription:
audio_input_description: str = "Bluetooth"
attribute: Optional[Attribute] = None
def on_read(self, _connection: Optional[Connection]) -> str:
def on_read(self, _connection: Connection) -> str:
return self.audio_input_description
async def on_write(self, connection: Optional[Connection], value: str) -> None:
assert connection
async def on_write(self, connection: Connection, value: str) -> None:
assert self.attribute
self.audio_input_description = value

View File

@@ -590,7 +590,7 @@ class AseStateMachine(gatt.Characteristic):
# Readonly. Do nothing in the setter.
pass
def on_read(self, _: Optional[device.Connection]) -> bytes:
def on_read(self, _: device.Connection) -> bytes:
return self.value
def __str__(self) -> str:

View File

@@ -200,7 +200,7 @@ class AshaService(gatt.TemplateService):
# Handler for audio control commands
async def _on_audio_control_point_write(
self, connection: Optional[Connection], value: bytes
self, connection: Connection, value: bytes
) -> None:
_logger.debug(f'--- AUDIO CONTROL POINT Write:{value.hex()}')
opcode = value[0]
@@ -247,7 +247,7 @@ class AshaService(gatt.TemplateService):
)
# Handler for volume control
def _on_volume_write(self, connection: Optional[Connection], value: bytes) -> None:
def _on_volume_write(self, connection: Connection, value: bytes) -> None:
_logger.debug(f'--- VOLUME Write:{value[0]}')
self.volume = value[0]
self.emit(self.EVENT_VOLUME_CHANGED)

View File

@@ -164,12 +164,10 @@ class CoordinatedSetIdentificationService(gatt.TemplateService):
super().__init__(characteristics)
async def on_sirk_read(self, connection: Optional[device.Connection]) -> bytes:
async def on_sirk_read(self, connection: device.Connection) -> bytes:
if self.set_identity_resolving_key_type == SirkType.PLAINTEXT:
sirk_bytes = self.set_identity_resolving_key
else:
assert connection
if connection.transport == core.PhysicalTransport.LE:
key = await connection.device.get_long_term_key(
connection_handle=connection.handle, rand=b'', ediv=0

View File

@@ -127,9 +127,7 @@ class GenericAttributeProfileService(gatt.TemplateService):
return b''
def get_database_hash(self, connection: device.Connection | None) -> bytes:
assert connection
def get_database_hash(self, connection: device.Connection) -> bytes:
m = b''.join(
[
self.get_attribute_data(attribute)

View File

@@ -335,9 +335,8 @@ class HearingAccessService(gatt.TemplateService):
utils.cancel_on_event(connection, 'disconnection', on_connection_async())
def _on_read_active_preset_index(
self, __connection__: Optional[Connection]
) -> bytes:
def _on_read_active_preset_index(self, connection: Connection) -> bytes:
del connection # Unused
return bytes([self.active_preset_index])
# TODO this need to be triggered when device is unbonded
@@ -345,18 +344,13 @@ class HearingAccessService(gatt.TemplateService):
self.preset_changed_operations_history_per_device.pop(addr)
async def _on_write_hearing_aid_preset_control_point(
self, connection: Optional[Connection], value: bytes
self, connection: Connection, value: bytes
):
assert connection
opcode = HearingAidPresetControlPointOpcode(value[0])
handler = getattr(self, '_on_' + opcode.name.lower())
await handler(connection, value)
async def _on_read_presets_request(
self, connection: Optional[Connection], value: bytes
):
assert connection
async def _on_read_presets_request(self, connection: Connection, value: bytes):
if connection.att_mtu < 49: # 2.5. GATT sub-procedure requirements
logging.warning(f'HAS require MTU >= 49: {connection}')
@@ -471,10 +465,7 @@ class HearingAccessService(gatt.TemplateService):
for connection in self.currently_connected_clients:
await self._preset_changed_operation(connection)
async def _on_write_preset_name(
self, connection: Optional[Connection], value: bytes
):
assert connection
async def _on_write_preset_name(self, connection: Connection, value: bytes):
if self.read_presets_request_in_progress:
raise att.ATT_Error(att.ErrorCode.PROCEDURE_ALREADY_IN_PROGRESS)
@@ -522,10 +513,7 @@ class HearingAccessService(gatt.TemplateService):
for connection in self.currently_connected_clients:
await self.notify_active_preset_for_connection(connection)
async def set_active_preset(
self, connection: Optional[Connection], value: bytes
) -> None:
assert connection
async def set_active_preset(self, connection: Connection, value: bytes) -> None:
index = value[1]
preset = self.preset_records.get(index, None)
if (
@@ -542,16 +530,11 @@ class HearingAccessService(gatt.TemplateService):
self.active_preset_index = index
await self.notify_active_preset()
async def _on_set_active_preset(
self, connection: Optional[Connection], value: bytes
):
async def _on_set_active_preset(self, connection: Connection, value: bytes):
await self.set_active_preset(connection, value)
async def set_next_or_previous_preset(
self, connection: Optional[Connection], is_previous
):
async def set_next_or_previous_preset(self, connection: Connection, is_previous):
'''Set the next or the previous preset as active'''
assert connection
if self.active_preset_index == 0x00:
raise att.ATT_Error(ErrorCode.PRESET_OPERATION_NOT_POSSIBLE)
@@ -581,17 +564,17 @@ class HearingAccessService(gatt.TemplateService):
await self.notify_active_preset()
async def _on_set_next_preset(
self, connection: Optional[Connection], __value__: bytes
self, connection: Connection, __value__: bytes
) -> None:
await self.set_next_or_previous_preset(connection, False)
async def _on_set_previous_preset(
self, connection: Optional[Connection], __value__: bytes
self, connection: Connection, __value__: bytes
) -> None:
await self.set_next_or_previous_preset(connection, True)
async def _on_set_active_preset_synchronized_locally(
self, connection: Optional[Connection], value: bytes
self, connection: Connection, value: bytes
):
if (
self.server_features.preset_synchronization_support
@@ -602,7 +585,7 @@ class HearingAccessService(gatt.TemplateService):
# TODO (low priority) inform other server of the change
async def _on_set_next_preset_synchronized_locally(
self, connection: Optional[Connection], __value__: bytes
self, connection: Connection, __value__: bytes
):
if (
self.server_features.preset_synchronization_support
@@ -613,7 +596,7 @@ class HearingAccessService(gatt.TemplateService):
# TODO (low priority) inform other server of the change
async def _on_set_previous_preset_synchronized_locally(
self, connection: Optional[Connection], __value__: bytes
self, connection: Connection, __value__: bytes
):
if (
self.server_features.preset_synchronization_support

View File

@@ -287,11 +287,8 @@ class MediaControlService(gatt.TemplateService):
)
async def on_media_control_point(
self, connection: Optional[device.Connection], data: bytes
self, connection: device.Connection, data: bytes
) -> None:
if not connection:
raise core.InvalidStateError()
opcode = MediaControlPointOpcode(data[0])
await connection.device.notify_subscriber(

View File

@@ -146,14 +146,12 @@ class VolumeControlService(gatt.TemplateService):
included_services=list(included_services),
)
def _on_read_volume_state(self, _connection: Optional[device.Connection]) -> bytes:
def _on_read_volume_state(self, _connection: device.Connection) -> bytes:
return bytes(VolumeState(self.volume_setting, self.muted, self.change_counter))
def _on_write_volume_control_point(
self, connection: Optional[device.Connection], value: bytes
self, connection: device.Connection, value: bytes
) -> None:
assert connection
opcode = VolumeControlPointOpcode(value[0])
change_counter = value[1]

View File

@@ -86,7 +86,7 @@ class VolumeOffsetState:
assert self.attribute is not None
await connection.device.notify_subscribers(attribute=self.attribute)
def on_read(self, _connection: Optional[Connection]) -> bytes:
def on_read(self, _connection: Connection) -> bytes:
return bytes(self)
@@ -103,11 +103,10 @@ class VocsAudioLocation:
audio_location = AudioLocation(struct.unpack('<I', data)[0])
return cls(audio_location)
def on_read(self, _connection: Optional[Connection]) -> bytes:
def on_read(self, _connection: Connection) -> bytes:
return bytes(self)
async def on_write(self, connection: Optional[Connection], value: bytes) -> None:
assert connection
async def on_write(self, connection: Connection, value: bytes) -> None:
assert self.attribute
self.audio_location = AudioLocation(int.from_bytes(value, 'little'))
@@ -118,8 +117,7 @@ class VocsAudioLocation:
class VolumeOffsetControlPoint:
volume_offset_state: VolumeOffsetState
async def on_write(self, connection: Optional[Connection], value: bytes) -> None:
assert connection
async def on_write(self, connection: Connection, value: bytes) -> None:
opcode = value[0]
if opcode != SetVolumeOffsetOpCode.SET_VOLUME_OFFSET:
@@ -159,11 +157,10 @@ class AudioOutputDescription:
def __bytes__(self) -> bytes:
return self.audio_output_description.encode('utf-8')
def on_read(self, _connection: Optional[Connection]) -> bytes:
def on_read(self, _connection: Connection) -> bytes:
return bytes(self)
async def on_write(self, connection: Optional[Connection], value: bytes) -> None:
assert connection
async def on_write(self, connection: Connection, value: bytes) -> None:
assert self.attribute
self.audio_output_description = value.decode('utf-8')

View File

@@ -136,9 +136,9 @@ async def test_characteristic_encoding():
Characteristic.READABLE,
123,
)
x = await c.read_value(None)
x = await c.read_value(Mock())
assert x == bytes([123])
await c.write_value(None, bytes([122]))
await c.write_value(Mock(), bytes([122]))
assert c.value == 122
class FooProxy(CharacteristicProxy):
@@ -334,7 +334,7 @@ async def test_CharacteristicAdapter() -> None:
)
v = bytes([3, 4, 5])
await c.write_value(None, v)
await c.write_value(Mock(), v)
assert c.value == v
# Simple delegated adapter
@@ -342,11 +342,11 @@ async def test_CharacteristicAdapter() -> None:
c, lambda x: bytes(reversed(x)), lambda x: bytes(reversed(x))
)
delegated_value = await delegated.read_value(None)
delegated_value = await delegated.read_value(Mock())
assert delegated_value == bytes(reversed(v))
delegated_value2 = bytes([3, 4, 5])
await delegated.write_value(None, delegated_value2)
await delegated.write_value(Mock(), delegated_value2)
assert delegated.value == bytes(reversed(delegated_value2))
# Packed adapter with single element format
@@ -355,10 +355,10 @@ async def test_CharacteristicAdapter() -> None:
c.value = packed_value_ref
packed = PackedCharacteristicAdapter(c, '>H')
packed_value_read = await packed.read_value(None)
packed_value_read = await packed.read_value(Mock())
assert packed_value_read == packed_value_bytes
c.value = b''
await packed.write_value(None, packed_value_bytes)
await packed.write_value(Mock(), packed_value_bytes)
assert packed.value == packed_value_ref
# Packed adapter with multi-element format
@@ -368,10 +368,10 @@ async def test_CharacteristicAdapter() -> None:
c.value = (v1, v2)
packed_multi = PackedCharacteristicAdapter(c, '>HH')
packed_multi_read_value = await packed_multi.read_value(None)
packed_multi_read_value = await packed_multi.read_value(Mock())
assert packed_multi_read_value == packed_multi_value_bytes
packed_multi.value = b''
await packed_multi.write_value(None, packed_multi_value_bytes)
await packed_multi.write_value(Mock(), packed_multi_value_bytes)
assert packed_multi.value == (v1, v2)
# Mapped adapter
@@ -382,10 +382,10 @@ async def test_CharacteristicAdapter() -> None:
c.value = mapped
packed_mapped = MappedCharacteristicAdapter(c, '>HH', ('v1', 'v2'))
packed_mapped_read_value = await packed_mapped.read_value(None)
packed_mapped_read_value = await packed_mapped.read_value(Mock())
assert packed_mapped_read_value == packed_mapped_value_bytes
c.value = b''
await packed_mapped.write_value(None, packed_mapped_value_bytes)
await packed_mapped.write_value(Mock(), packed_mapped_value_bytes)
assert packed_mapped.value == mapped
# UTF-8 adapter
@@ -394,10 +394,10 @@ async def test_CharacteristicAdapter() -> None:
c.value = string_value
string_c = UTF8CharacteristicAdapter(c)
string_read_value = await string_c.read_value(None)
string_read_value = await string_c.read_value(Mock())
assert string_read_value == string_value_bytes
c.value = b''
await string_c.write_value(None, string_value_bytes)
await string_c.write_value(Mock(), string_value_bytes)
assert string_c.value == string_value
# Class adapter
@@ -419,10 +419,10 @@ async def test_CharacteristicAdapter() -> None:
c.value = class_value
class_c = SerializableCharacteristicAdapter(c, BlaBla)
class_read_value = await class_c.read_value(None)
class_read_value = await class_c.read_value(Mock())
assert class_read_value == class_value_bytes
class_c.value = b''
await class_c.write_value(None, class_value_bytes)
await class_c.write_value(Mock(), class_value_bytes)
assert isinstance(class_c.value, BlaBla)
assert class_c.value.a == class_value.a
assert class_c.value.b == class_value.b
@@ -436,10 +436,10 @@ async def test_CharacteristicAdapter() -> None:
enum_value_bytes = int(enum_value).to_bytes(3, 'big')
c.value = enum_value
enum_c = EnumCharacteristicAdapter(c, MyEnum, 3, 'big')
enum_read_value = await enum_c.read_value(None)
enum_read_value = await enum_c.read_value(Mock())
assert enum_read_value == enum_value_bytes
enum_c.value = b''
await enum_c.write_value(None, enum_value_bytes)
await enum_c.write_value(Mock(), enum_value_bytes)
assert isinstance(enum_c.value, MyEnum)
assert enum_c.value == enum_value
@@ -1254,7 +1254,7 @@ Characteristic(handle=0x0005, end=0x0005, uuid=UUID-16:2A01 (Appearance), READ)
Service(handle=0x0006, end=0x000D, uuid=UUID-16:1801 (Generic Attribute))
CharacteristicDeclaration(handle=0x0007, value_handle=0x0008, uuid=UUID-16:2A05 (Service Changed), INDICATE)
Characteristic(handle=0x0008, end=0x0009, uuid=UUID-16:2A05 (Service Changed), INDICATE)
Descriptor(handle=0x0009, type=UUID-16:2902 (Client Characteristic Configuration), value=0000)
Descriptor(handle=0x0009, type=UUID-16:2902 (Client Characteristic Configuration), value=<dynamic>)
CharacteristicDeclaration(handle=0x000A, value_handle=0x000B, uuid=UUID-16:2B29 (Client Supported Features), READ|WRITE)
Characteristic(handle=0x000B, end=0x000B, uuid=UUID-16:2B29 (Client Supported Features), READ|WRITE)
CharacteristicDeclaration(handle=0x000C, value_handle=0x000D, uuid=UUID-16:2B2A (Database Hash), READ)
@@ -1262,7 +1262,7 @@ Characteristic(handle=0x000D, end=0x000D, uuid=UUID-16:2B2A (Database Hash), REA
Service(handle=0x000E, end=0x0011, uuid=3A657F47-D34F-46B3-B1EC-698E29B6B829)
CharacteristicDeclaration(handle=0x000F, value_handle=0x0010, uuid=FDB159DB-036C-49E3-B3DB-6325AC750806, READ|WRITE|NOTIFY)
Characteristic(handle=0x0010, end=0x0011, uuid=FDB159DB-036C-49E3-B3DB-6325AC750806, READ|WRITE|NOTIFY)
Descriptor(handle=0x0011, type=UUID-16:2902 (Client Characteristic Configuration), value=0000)"""
Descriptor(handle=0x0011, type=UUID-16:2902 (Client Characteristic Configuration), value=<dynamic>)"""
)