Replace Optional[Connection] att parameter type

This commit is contained in:
zxzxwu
2025-05-18 07:49:11 +00:00
parent 7b7afc7179
commit 84f7cad678
13 changed files with 59 additions and 99 deletions

View File

@@ -198,8 +198,7 @@ class AudioInputControlPoint:
audio_input_state: AudioInputState
gain_settings_properties: GainSettingsProperties
async def on_write(self, connection: Optional[Connection], value: bytes) -> None:
assert connection
async def on_write(self, connection: Connection, value: bytes) -> None:
opcode = AudioInputControlPointOpCode(value[0])
@@ -320,11 +319,10 @@ class AudioInputDescription:
audio_input_description: str = "Bluetooth"
attribute: Optional[Attribute] = None
def on_read(self, _connection: Optional[Connection]) -> str:
def on_read(self, _connection: Connection) -> str:
return self.audio_input_description
async def on_write(self, connection: Optional[Connection], value: str) -> None:
assert connection
async def on_write(self, connection: Connection, value: str) -> None:
assert self.attribute
self.audio_input_description = value

View File

@@ -590,7 +590,7 @@ class AseStateMachine(gatt.Characteristic):
# Readonly. Do nothing in the setter.
pass
def on_read(self, _: Optional[device.Connection]) -> bytes:
def on_read(self, _: device.Connection) -> bytes:
return self.value
def __str__(self) -> str:

View File

@@ -200,7 +200,7 @@ class AshaService(gatt.TemplateService):
# Handler for audio control commands
async def _on_audio_control_point_write(
self, connection: Optional[Connection], value: bytes
self, connection: Connection, value: bytes
) -> None:
_logger.debug(f'--- AUDIO CONTROL POINT Write:{value.hex()}')
opcode = value[0]
@@ -247,7 +247,7 @@ class AshaService(gatt.TemplateService):
)
# Handler for volume control
def _on_volume_write(self, connection: Optional[Connection], value: bytes) -> None:
def _on_volume_write(self, connection: Connection, value: bytes) -> None:
_logger.debug(f'--- VOLUME Write:{value[0]}')
self.volume = value[0]
self.emit(self.EVENT_VOLUME_CHANGED)

View File

@@ -164,12 +164,10 @@ class CoordinatedSetIdentificationService(gatt.TemplateService):
super().__init__(characteristics)
async def on_sirk_read(self, connection: Optional[device.Connection]) -> bytes:
async def on_sirk_read(self, connection: device.Connection) -> bytes:
if self.set_identity_resolving_key_type == SirkType.PLAINTEXT:
sirk_bytes = self.set_identity_resolving_key
else:
assert connection
if connection.transport == core.PhysicalTransport.LE:
key = await connection.device.get_long_term_key(
connection_handle=connection.handle, rand=b'', ediv=0

View File

@@ -127,9 +127,7 @@ class GenericAttributeProfileService(gatt.TemplateService):
return b''
def get_database_hash(self, connection: device.Connection | None) -> bytes:
assert connection
def get_database_hash(self, connection: device.Connection) -> bytes:
m = b''.join(
[
self.get_attribute_data(attribute)

View File

@@ -335,9 +335,8 @@ class HearingAccessService(gatt.TemplateService):
utils.cancel_on_event(connection, 'disconnection', on_connection_async())
def _on_read_active_preset_index(
self, __connection__: Optional[Connection]
) -> bytes:
def _on_read_active_preset_index(self, connection: Connection) -> bytes:
del connection # Unused
return bytes([self.active_preset_index])
# TODO this need to be triggered when device is unbonded
@@ -345,18 +344,13 @@ class HearingAccessService(gatt.TemplateService):
self.preset_changed_operations_history_per_device.pop(addr)
async def _on_write_hearing_aid_preset_control_point(
self, connection: Optional[Connection], value: bytes
self, connection: Connection, value: bytes
):
assert connection
opcode = HearingAidPresetControlPointOpcode(value[0])
handler = getattr(self, '_on_' + opcode.name.lower())
await handler(connection, value)
async def _on_read_presets_request(
self, connection: Optional[Connection], value: bytes
):
assert connection
async def _on_read_presets_request(self, connection: Connection, value: bytes):
if connection.att_mtu < 49: # 2.5. GATT sub-procedure requirements
logging.warning(f'HAS require MTU >= 49: {connection}')
@@ -471,10 +465,7 @@ class HearingAccessService(gatt.TemplateService):
for connection in self.currently_connected_clients:
await self._preset_changed_operation(connection)
async def _on_write_preset_name(
self, connection: Optional[Connection], value: bytes
):
assert connection
async def _on_write_preset_name(self, connection: Connection, value: bytes):
if self.read_presets_request_in_progress:
raise att.ATT_Error(att.ErrorCode.PROCEDURE_ALREADY_IN_PROGRESS)
@@ -522,10 +513,7 @@ class HearingAccessService(gatt.TemplateService):
for connection in self.currently_connected_clients:
await self.notify_active_preset_for_connection(connection)
async def set_active_preset(
self, connection: Optional[Connection], value: bytes
) -> None:
assert connection
async def set_active_preset(self, connection: Connection, value: bytes) -> None:
index = value[1]
preset = self.preset_records.get(index, None)
if (
@@ -542,16 +530,11 @@ class HearingAccessService(gatt.TemplateService):
self.active_preset_index = index
await self.notify_active_preset()
async def _on_set_active_preset(
self, connection: Optional[Connection], value: bytes
):
async def _on_set_active_preset(self, connection: Connection, value: bytes):
await self.set_active_preset(connection, value)
async def set_next_or_previous_preset(
self, connection: Optional[Connection], is_previous
):
async def set_next_or_previous_preset(self, connection: Connection, is_previous):
'''Set the next or the previous preset as active'''
assert connection
if self.active_preset_index == 0x00:
raise att.ATT_Error(ErrorCode.PRESET_OPERATION_NOT_POSSIBLE)
@@ -581,17 +564,17 @@ class HearingAccessService(gatt.TemplateService):
await self.notify_active_preset()
async def _on_set_next_preset(
self, connection: Optional[Connection], __value__: bytes
self, connection: Connection, __value__: bytes
) -> None:
await self.set_next_or_previous_preset(connection, False)
async def _on_set_previous_preset(
self, connection: Optional[Connection], __value__: bytes
self, connection: Connection, __value__: bytes
) -> None:
await self.set_next_or_previous_preset(connection, True)
async def _on_set_active_preset_synchronized_locally(
self, connection: Optional[Connection], value: bytes
self, connection: Connection, value: bytes
):
if (
self.server_features.preset_synchronization_support
@@ -602,7 +585,7 @@ class HearingAccessService(gatt.TemplateService):
# TODO (low priority) inform other server of the change
async def _on_set_next_preset_synchronized_locally(
self, connection: Optional[Connection], __value__: bytes
self, connection: Connection, __value__: bytes
):
if (
self.server_features.preset_synchronization_support
@@ -613,7 +596,7 @@ class HearingAccessService(gatt.TemplateService):
# TODO (low priority) inform other server of the change
async def _on_set_previous_preset_synchronized_locally(
self, connection: Optional[Connection], __value__: bytes
self, connection: Connection, __value__: bytes
):
if (
self.server_features.preset_synchronization_support

View File

@@ -287,11 +287,8 @@ class MediaControlService(gatt.TemplateService):
)
async def on_media_control_point(
self, connection: Optional[device.Connection], data: bytes
self, connection: device.Connection, data: bytes
) -> None:
if not connection:
raise core.InvalidStateError()
opcode = MediaControlPointOpcode(data[0])
await connection.device.notify_subscriber(

View File

@@ -146,14 +146,12 @@ class VolumeControlService(gatt.TemplateService):
included_services=list(included_services),
)
def _on_read_volume_state(self, _connection: Optional[device.Connection]) -> bytes:
def _on_read_volume_state(self, _connection: device.Connection) -> bytes:
return bytes(VolumeState(self.volume_setting, self.muted, self.change_counter))
def _on_write_volume_control_point(
self, connection: Optional[device.Connection], value: bytes
self, connection: device.Connection, value: bytes
) -> None:
assert connection
opcode = VolumeControlPointOpcode(value[0])
change_counter = value[1]

View File

@@ -86,7 +86,7 @@ class VolumeOffsetState:
assert self.attribute is not None
await connection.device.notify_subscribers(attribute=self.attribute)
def on_read(self, _connection: Optional[Connection]) -> bytes:
def on_read(self, _connection: Connection) -> bytes:
return bytes(self)
@@ -103,11 +103,10 @@ class VocsAudioLocation:
audio_location = AudioLocation(struct.unpack('<I', data)[0])
return cls(audio_location)
def on_read(self, _connection: Optional[Connection]) -> bytes:
def on_read(self, _connection: Connection) -> bytes:
return bytes(self)
async def on_write(self, connection: Optional[Connection], value: bytes) -> None:
assert connection
async def on_write(self, connection: Connection, value: bytes) -> None:
assert self.attribute
self.audio_location = AudioLocation(int.from_bytes(value, 'little'))
@@ -118,8 +117,7 @@ class VocsAudioLocation:
class VolumeOffsetControlPoint:
volume_offset_state: VolumeOffsetState
async def on_write(self, connection: Optional[Connection], value: bytes) -> None:
assert connection
async def on_write(self, connection: Connection, value: bytes) -> None:
opcode = value[0]
if opcode != SetVolumeOffsetOpCode.SET_VOLUME_OFFSET:
@@ -159,11 +157,10 @@ class AudioOutputDescription:
def __bytes__(self) -> bytes:
return self.audio_output_description.encode('utf-8')
def on_read(self, _connection: Optional[Connection]) -> bytes:
def on_read(self, _connection: Connection) -> bytes:
return bytes(self)
async def on_write(self, connection: Optional[Connection], value: bytes) -> None:
assert connection
async def on_write(self, connection: Connection, value: bytes) -> None:
assert self.attribute
self.audio_output_description = value.decode('utf-8')