forked from auracaster/bumble_mirror
Merge pull request #696 from zxzxwu/att
Replace Optional[Connection] AttributeValue parameter type
This commit is contained in:
@@ -770,27 +770,25 @@ class AttributeValue(Generic[_T]):
|
||||
def __init__(
|
||||
self,
|
||||
read: Union[
|
||||
Callable[[Optional[Connection]], _T],
|
||||
Callable[[Optional[Connection]], Awaitable[_T]],
|
||||
Callable[[Connection], _T],
|
||||
Callable[[Connection], Awaitable[_T]],
|
||||
None,
|
||||
] = None,
|
||||
write: Union[
|
||||
Callable[[Optional[Connection], _T], None],
|
||||
Callable[[Optional[Connection], _T], Awaitable[None]],
|
||||
Callable[[Connection, _T], None],
|
||||
Callable[[Connection, _T], Awaitable[None]],
|
||||
None,
|
||||
] = None,
|
||||
):
|
||||
self._read = read
|
||||
self._write = write
|
||||
|
||||
def read(self, connection: Optional[Connection]) -> Union[_T, Awaitable[_T]]:
|
||||
def read(self, connection: Connection) -> Union[_T, Awaitable[_T]]:
|
||||
if self._read is None:
|
||||
raise InvalidOperationError('AttributeValue has no read function')
|
||||
return self._read(connection)
|
||||
|
||||
def write(
|
||||
self, connection: Optional[Connection], value: _T
|
||||
) -> Union[Awaitable[None], None]:
|
||||
def write(self, connection: Connection, value: _T) -> Union[Awaitable[None], None]:
|
||||
if self._write is None:
|
||||
raise InvalidOperationError('AttributeValue has no write function')
|
||||
return self._write(connection, value)
|
||||
@@ -871,7 +869,7 @@ class Attribute(utils.EventEmitter, Generic[_T]):
|
||||
def decode_value(self, value: bytes) -> _T:
|
||||
return value # type: ignore
|
||||
|
||||
async def read_value(self, connection: Optional[Connection]) -> bytes:
|
||||
async def read_value(self, connection: Connection) -> bytes:
|
||||
if (
|
||||
(self.permissions & self.READ_REQUIRES_ENCRYPTION)
|
||||
and connection is not None
|
||||
@@ -913,7 +911,7 @@ class Attribute(utils.EventEmitter, Generic[_T]):
|
||||
|
||||
return b'' if value is None else self.encode_value(value)
|
||||
|
||||
async def write_value(self, connection: Optional[Connection], value: bytes) -> None:
|
||||
async def write_value(self, connection: Connection, value: bytes) -> None:
|
||||
if (
|
||||
(self.permissions & self.WRITE_REQUIRES_ENCRYPTION)
|
||||
and connection is not None
|
||||
|
||||
@@ -579,11 +579,7 @@ class Descriptor(Attribute):
|
||||
if isinstance(self.value, bytes):
|
||||
value_str = self.value.hex()
|
||||
elif isinstance(self.value, CharacteristicValue):
|
||||
value = self.value.read(None)
|
||||
if isinstance(value, bytes):
|
||||
value_str = value.hex()
|
||||
else:
|
||||
value_str = '<async>'
|
||||
value_str = '<dynamic>'
|
||||
else:
|
||||
value_str = '<...>'
|
||||
return (
|
||||
|
||||
@@ -315,11 +315,8 @@ class Server(utils.EventEmitter):
|
||||
self.add_service(service)
|
||||
|
||||
def read_cccd(
|
||||
self, connection: Optional[Connection], characteristic: Characteristic
|
||||
self, connection: Connection, characteristic: Characteristic
|
||||
) -> bytes:
|
||||
if connection is None:
|
||||
return bytes([0, 0])
|
||||
|
||||
subscribers = self.subscribers.get(connection.handle)
|
||||
cccd = None
|
||||
if subscribers:
|
||||
|
||||
@@ -198,8 +198,7 @@ class AudioInputControlPoint:
|
||||
audio_input_state: AudioInputState
|
||||
gain_settings_properties: GainSettingsProperties
|
||||
|
||||
async def on_write(self, connection: Optional[Connection], value: bytes) -> None:
|
||||
assert connection
|
||||
async def on_write(self, connection: Connection, value: bytes) -> None:
|
||||
|
||||
opcode = AudioInputControlPointOpCode(value[0])
|
||||
|
||||
@@ -320,11 +319,10 @@ class AudioInputDescription:
|
||||
audio_input_description: str = "Bluetooth"
|
||||
attribute: Optional[Attribute] = None
|
||||
|
||||
def on_read(self, _connection: Optional[Connection]) -> str:
|
||||
def on_read(self, _connection: Connection) -> str:
|
||||
return self.audio_input_description
|
||||
|
||||
async def on_write(self, connection: Optional[Connection], value: str) -> None:
|
||||
assert connection
|
||||
async def on_write(self, connection: Connection, value: str) -> None:
|
||||
assert self.attribute
|
||||
|
||||
self.audio_input_description = value
|
||||
|
||||
@@ -590,7 +590,7 @@ class AseStateMachine(gatt.Characteristic):
|
||||
# Readonly. Do nothing in the setter.
|
||||
pass
|
||||
|
||||
def on_read(self, _: Optional[device.Connection]) -> bytes:
|
||||
def on_read(self, _: device.Connection) -> bytes:
|
||||
return self.value
|
||||
|
||||
def __str__(self) -> str:
|
||||
|
||||
@@ -200,7 +200,7 @@ class AshaService(gatt.TemplateService):
|
||||
|
||||
# Handler for audio control commands
|
||||
async def _on_audio_control_point_write(
|
||||
self, connection: Optional[Connection], value: bytes
|
||||
self, connection: Connection, value: bytes
|
||||
) -> None:
|
||||
_logger.debug(f'--- AUDIO CONTROL POINT Write:{value.hex()}')
|
||||
opcode = value[0]
|
||||
@@ -247,7 +247,7 @@ class AshaService(gatt.TemplateService):
|
||||
)
|
||||
|
||||
# Handler for volume control
|
||||
def _on_volume_write(self, connection: Optional[Connection], value: bytes) -> None:
|
||||
def _on_volume_write(self, connection: Connection, value: bytes) -> None:
|
||||
_logger.debug(f'--- VOLUME Write:{value[0]}')
|
||||
self.volume = value[0]
|
||||
self.emit(self.EVENT_VOLUME_CHANGED)
|
||||
|
||||
@@ -164,12 +164,10 @@ class CoordinatedSetIdentificationService(gatt.TemplateService):
|
||||
|
||||
super().__init__(characteristics)
|
||||
|
||||
async def on_sirk_read(self, connection: Optional[device.Connection]) -> bytes:
|
||||
async def on_sirk_read(self, connection: device.Connection) -> bytes:
|
||||
if self.set_identity_resolving_key_type == SirkType.PLAINTEXT:
|
||||
sirk_bytes = self.set_identity_resolving_key
|
||||
else:
|
||||
assert connection
|
||||
|
||||
if connection.transport == core.PhysicalTransport.LE:
|
||||
key = await connection.device.get_long_term_key(
|
||||
connection_handle=connection.handle, rand=b'', ediv=0
|
||||
|
||||
@@ -127,9 +127,7 @@ class GenericAttributeProfileService(gatt.TemplateService):
|
||||
|
||||
return b''
|
||||
|
||||
def get_database_hash(self, connection: device.Connection | None) -> bytes:
|
||||
assert connection
|
||||
|
||||
def get_database_hash(self, connection: device.Connection) -> bytes:
|
||||
m = b''.join(
|
||||
[
|
||||
self.get_attribute_data(attribute)
|
||||
|
||||
@@ -335,9 +335,8 @@ class HearingAccessService(gatt.TemplateService):
|
||||
|
||||
utils.cancel_on_event(connection, 'disconnection', on_connection_async())
|
||||
|
||||
def _on_read_active_preset_index(
|
||||
self, __connection__: Optional[Connection]
|
||||
) -> bytes:
|
||||
def _on_read_active_preset_index(self, connection: Connection) -> bytes:
|
||||
del connection # Unused
|
||||
return bytes([self.active_preset_index])
|
||||
|
||||
# TODO this need to be triggered when device is unbonded
|
||||
@@ -345,18 +344,13 @@ class HearingAccessService(gatt.TemplateService):
|
||||
self.preset_changed_operations_history_per_device.pop(addr)
|
||||
|
||||
async def _on_write_hearing_aid_preset_control_point(
|
||||
self, connection: Optional[Connection], value: bytes
|
||||
self, connection: Connection, value: bytes
|
||||
):
|
||||
assert connection
|
||||
|
||||
opcode = HearingAidPresetControlPointOpcode(value[0])
|
||||
handler = getattr(self, '_on_' + opcode.name.lower())
|
||||
await handler(connection, value)
|
||||
|
||||
async def _on_read_presets_request(
|
||||
self, connection: Optional[Connection], value: bytes
|
||||
):
|
||||
assert connection
|
||||
async def _on_read_presets_request(self, connection: Connection, value: bytes):
|
||||
if connection.att_mtu < 49: # 2.5. GATT sub-procedure requirements
|
||||
logging.warning(f'HAS require MTU >= 49: {connection}')
|
||||
|
||||
@@ -471,10 +465,7 @@ class HearingAccessService(gatt.TemplateService):
|
||||
for connection in self.currently_connected_clients:
|
||||
await self._preset_changed_operation(connection)
|
||||
|
||||
async def _on_write_preset_name(
|
||||
self, connection: Optional[Connection], value: bytes
|
||||
):
|
||||
assert connection
|
||||
async def _on_write_preset_name(self, connection: Connection, value: bytes):
|
||||
|
||||
if self.read_presets_request_in_progress:
|
||||
raise att.ATT_Error(att.ErrorCode.PROCEDURE_ALREADY_IN_PROGRESS)
|
||||
@@ -522,10 +513,7 @@ class HearingAccessService(gatt.TemplateService):
|
||||
for connection in self.currently_connected_clients:
|
||||
await self.notify_active_preset_for_connection(connection)
|
||||
|
||||
async def set_active_preset(
|
||||
self, connection: Optional[Connection], value: bytes
|
||||
) -> None:
|
||||
assert connection
|
||||
async def set_active_preset(self, connection: Connection, value: bytes) -> None:
|
||||
index = value[1]
|
||||
preset = self.preset_records.get(index, None)
|
||||
if (
|
||||
@@ -542,16 +530,11 @@ class HearingAccessService(gatt.TemplateService):
|
||||
self.active_preset_index = index
|
||||
await self.notify_active_preset()
|
||||
|
||||
async def _on_set_active_preset(
|
||||
self, connection: Optional[Connection], value: bytes
|
||||
):
|
||||
async def _on_set_active_preset(self, connection: Connection, value: bytes):
|
||||
await self.set_active_preset(connection, value)
|
||||
|
||||
async def set_next_or_previous_preset(
|
||||
self, connection: Optional[Connection], is_previous
|
||||
):
|
||||
async def set_next_or_previous_preset(self, connection: Connection, is_previous):
|
||||
'''Set the next or the previous preset as active'''
|
||||
assert connection
|
||||
|
||||
if self.active_preset_index == 0x00:
|
||||
raise att.ATT_Error(ErrorCode.PRESET_OPERATION_NOT_POSSIBLE)
|
||||
@@ -581,17 +564,17 @@ class HearingAccessService(gatt.TemplateService):
|
||||
await self.notify_active_preset()
|
||||
|
||||
async def _on_set_next_preset(
|
||||
self, connection: Optional[Connection], __value__: bytes
|
||||
self, connection: Connection, __value__: bytes
|
||||
) -> None:
|
||||
await self.set_next_or_previous_preset(connection, False)
|
||||
|
||||
async def _on_set_previous_preset(
|
||||
self, connection: Optional[Connection], __value__: bytes
|
||||
self, connection: Connection, __value__: bytes
|
||||
) -> None:
|
||||
await self.set_next_or_previous_preset(connection, True)
|
||||
|
||||
async def _on_set_active_preset_synchronized_locally(
|
||||
self, connection: Optional[Connection], value: bytes
|
||||
self, connection: Connection, value: bytes
|
||||
):
|
||||
if (
|
||||
self.server_features.preset_synchronization_support
|
||||
@@ -602,7 +585,7 @@ class HearingAccessService(gatt.TemplateService):
|
||||
# TODO (low priority) inform other server of the change
|
||||
|
||||
async def _on_set_next_preset_synchronized_locally(
|
||||
self, connection: Optional[Connection], __value__: bytes
|
||||
self, connection: Connection, __value__: bytes
|
||||
):
|
||||
if (
|
||||
self.server_features.preset_synchronization_support
|
||||
@@ -613,7 +596,7 @@ class HearingAccessService(gatt.TemplateService):
|
||||
# TODO (low priority) inform other server of the change
|
||||
|
||||
async def _on_set_previous_preset_synchronized_locally(
|
||||
self, connection: Optional[Connection], __value__: bytes
|
||||
self, connection: Connection, __value__: bytes
|
||||
):
|
||||
if (
|
||||
self.server_features.preset_synchronization_support
|
||||
|
||||
@@ -287,11 +287,8 @@ class MediaControlService(gatt.TemplateService):
|
||||
)
|
||||
|
||||
async def on_media_control_point(
|
||||
self, connection: Optional[device.Connection], data: bytes
|
||||
self, connection: device.Connection, data: bytes
|
||||
) -> None:
|
||||
if not connection:
|
||||
raise core.InvalidStateError()
|
||||
|
||||
opcode = MediaControlPointOpcode(data[0])
|
||||
|
||||
await connection.device.notify_subscriber(
|
||||
|
||||
@@ -146,14 +146,12 @@ class VolumeControlService(gatt.TemplateService):
|
||||
included_services=list(included_services),
|
||||
)
|
||||
|
||||
def _on_read_volume_state(self, _connection: Optional[device.Connection]) -> bytes:
|
||||
def _on_read_volume_state(self, _connection: device.Connection) -> bytes:
|
||||
return bytes(VolumeState(self.volume_setting, self.muted, self.change_counter))
|
||||
|
||||
def _on_write_volume_control_point(
|
||||
self, connection: Optional[device.Connection], value: bytes
|
||||
self, connection: device.Connection, value: bytes
|
||||
) -> None:
|
||||
assert connection
|
||||
|
||||
opcode = VolumeControlPointOpcode(value[0])
|
||||
change_counter = value[1]
|
||||
|
||||
|
||||
@@ -86,7 +86,7 @@ class VolumeOffsetState:
|
||||
assert self.attribute is not None
|
||||
await connection.device.notify_subscribers(attribute=self.attribute)
|
||||
|
||||
def on_read(self, _connection: Optional[Connection]) -> bytes:
|
||||
def on_read(self, _connection: Connection) -> bytes:
|
||||
return bytes(self)
|
||||
|
||||
|
||||
@@ -103,11 +103,10 @@ class VocsAudioLocation:
|
||||
audio_location = AudioLocation(struct.unpack('<I', data)[0])
|
||||
return cls(audio_location)
|
||||
|
||||
def on_read(self, _connection: Optional[Connection]) -> bytes:
|
||||
def on_read(self, _connection: Connection) -> bytes:
|
||||
return bytes(self)
|
||||
|
||||
async def on_write(self, connection: Optional[Connection], value: bytes) -> None:
|
||||
assert connection
|
||||
async def on_write(self, connection: Connection, value: bytes) -> None:
|
||||
assert self.attribute
|
||||
|
||||
self.audio_location = AudioLocation(int.from_bytes(value, 'little'))
|
||||
@@ -118,8 +117,7 @@ class VocsAudioLocation:
|
||||
class VolumeOffsetControlPoint:
|
||||
volume_offset_state: VolumeOffsetState
|
||||
|
||||
async def on_write(self, connection: Optional[Connection], value: bytes) -> None:
|
||||
assert connection
|
||||
async def on_write(self, connection: Connection, value: bytes) -> None:
|
||||
|
||||
opcode = value[0]
|
||||
if opcode != SetVolumeOffsetOpCode.SET_VOLUME_OFFSET:
|
||||
@@ -159,11 +157,10 @@ class AudioOutputDescription:
|
||||
def __bytes__(self) -> bytes:
|
||||
return self.audio_output_description.encode('utf-8')
|
||||
|
||||
def on_read(self, _connection: Optional[Connection]) -> bytes:
|
||||
def on_read(self, _connection: Connection) -> bytes:
|
||||
return bytes(self)
|
||||
|
||||
async def on_write(self, connection: Optional[Connection], value: bytes) -> None:
|
||||
assert connection
|
||||
async def on_write(self, connection: Connection, value: bytes) -> None:
|
||||
assert self.attribute
|
||||
|
||||
self.audio_output_description = value.decode('utf-8')
|
||||
|
||||
@@ -136,9 +136,9 @@ async def test_characteristic_encoding():
|
||||
Characteristic.READABLE,
|
||||
123,
|
||||
)
|
||||
x = await c.read_value(None)
|
||||
x = await c.read_value(Mock())
|
||||
assert x == bytes([123])
|
||||
await c.write_value(None, bytes([122]))
|
||||
await c.write_value(Mock(), bytes([122]))
|
||||
assert c.value == 122
|
||||
|
||||
class FooProxy(CharacteristicProxy):
|
||||
@@ -334,7 +334,7 @@ async def test_CharacteristicAdapter() -> None:
|
||||
)
|
||||
|
||||
v = bytes([3, 4, 5])
|
||||
await c.write_value(None, v)
|
||||
await c.write_value(Mock(), v)
|
||||
assert c.value == v
|
||||
|
||||
# Simple delegated adapter
|
||||
@@ -342,11 +342,11 @@ async def test_CharacteristicAdapter() -> None:
|
||||
c, lambda x: bytes(reversed(x)), lambda x: bytes(reversed(x))
|
||||
)
|
||||
|
||||
delegated_value = await delegated.read_value(None)
|
||||
delegated_value = await delegated.read_value(Mock())
|
||||
assert delegated_value == bytes(reversed(v))
|
||||
|
||||
delegated_value2 = bytes([3, 4, 5])
|
||||
await delegated.write_value(None, delegated_value2)
|
||||
await delegated.write_value(Mock(), delegated_value2)
|
||||
assert delegated.value == bytes(reversed(delegated_value2))
|
||||
|
||||
# Packed adapter with single element format
|
||||
@@ -355,10 +355,10 @@ async def test_CharacteristicAdapter() -> None:
|
||||
c.value = packed_value_ref
|
||||
packed = PackedCharacteristicAdapter(c, '>H')
|
||||
|
||||
packed_value_read = await packed.read_value(None)
|
||||
packed_value_read = await packed.read_value(Mock())
|
||||
assert packed_value_read == packed_value_bytes
|
||||
c.value = b''
|
||||
await packed.write_value(None, packed_value_bytes)
|
||||
await packed.write_value(Mock(), packed_value_bytes)
|
||||
assert packed.value == packed_value_ref
|
||||
|
||||
# Packed adapter with multi-element format
|
||||
@@ -368,10 +368,10 @@ async def test_CharacteristicAdapter() -> None:
|
||||
c.value = (v1, v2)
|
||||
packed_multi = PackedCharacteristicAdapter(c, '>HH')
|
||||
|
||||
packed_multi_read_value = await packed_multi.read_value(None)
|
||||
packed_multi_read_value = await packed_multi.read_value(Mock())
|
||||
assert packed_multi_read_value == packed_multi_value_bytes
|
||||
packed_multi.value = b''
|
||||
await packed_multi.write_value(None, packed_multi_value_bytes)
|
||||
await packed_multi.write_value(Mock(), packed_multi_value_bytes)
|
||||
assert packed_multi.value == (v1, v2)
|
||||
|
||||
# Mapped adapter
|
||||
@@ -382,10 +382,10 @@ async def test_CharacteristicAdapter() -> None:
|
||||
c.value = mapped
|
||||
packed_mapped = MappedCharacteristicAdapter(c, '>HH', ('v1', 'v2'))
|
||||
|
||||
packed_mapped_read_value = await packed_mapped.read_value(None)
|
||||
packed_mapped_read_value = await packed_mapped.read_value(Mock())
|
||||
assert packed_mapped_read_value == packed_mapped_value_bytes
|
||||
c.value = b''
|
||||
await packed_mapped.write_value(None, packed_mapped_value_bytes)
|
||||
await packed_mapped.write_value(Mock(), packed_mapped_value_bytes)
|
||||
assert packed_mapped.value == mapped
|
||||
|
||||
# UTF-8 adapter
|
||||
@@ -394,10 +394,10 @@ async def test_CharacteristicAdapter() -> None:
|
||||
c.value = string_value
|
||||
string_c = UTF8CharacteristicAdapter(c)
|
||||
|
||||
string_read_value = await string_c.read_value(None)
|
||||
string_read_value = await string_c.read_value(Mock())
|
||||
assert string_read_value == string_value_bytes
|
||||
c.value = b''
|
||||
await string_c.write_value(None, string_value_bytes)
|
||||
await string_c.write_value(Mock(), string_value_bytes)
|
||||
assert string_c.value == string_value
|
||||
|
||||
# Class adapter
|
||||
@@ -419,10 +419,10 @@ async def test_CharacteristicAdapter() -> None:
|
||||
c.value = class_value
|
||||
class_c = SerializableCharacteristicAdapter(c, BlaBla)
|
||||
|
||||
class_read_value = await class_c.read_value(None)
|
||||
class_read_value = await class_c.read_value(Mock())
|
||||
assert class_read_value == class_value_bytes
|
||||
class_c.value = b''
|
||||
await class_c.write_value(None, class_value_bytes)
|
||||
await class_c.write_value(Mock(), class_value_bytes)
|
||||
assert isinstance(class_c.value, BlaBla)
|
||||
assert class_c.value.a == class_value.a
|
||||
assert class_c.value.b == class_value.b
|
||||
@@ -436,10 +436,10 @@ async def test_CharacteristicAdapter() -> None:
|
||||
enum_value_bytes = int(enum_value).to_bytes(3, 'big')
|
||||
c.value = enum_value
|
||||
enum_c = EnumCharacteristicAdapter(c, MyEnum, 3, 'big')
|
||||
enum_read_value = await enum_c.read_value(None)
|
||||
enum_read_value = await enum_c.read_value(Mock())
|
||||
assert enum_read_value == enum_value_bytes
|
||||
enum_c.value = b''
|
||||
await enum_c.write_value(None, enum_value_bytes)
|
||||
await enum_c.write_value(Mock(), enum_value_bytes)
|
||||
assert isinstance(enum_c.value, MyEnum)
|
||||
assert enum_c.value == enum_value
|
||||
|
||||
@@ -1254,7 +1254,7 @@ Characteristic(handle=0x0005, end=0x0005, uuid=UUID-16:2A01 (Appearance), READ)
|
||||
Service(handle=0x0006, end=0x000D, uuid=UUID-16:1801 (Generic Attribute))
|
||||
CharacteristicDeclaration(handle=0x0007, value_handle=0x0008, uuid=UUID-16:2A05 (Service Changed), INDICATE)
|
||||
Characteristic(handle=0x0008, end=0x0009, uuid=UUID-16:2A05 (Service Changed), INDICATE)
|
||||
Descriptor(handle=0x0009, type=UUID-16:2902 (Client Characteristic Configuration), value=0000)
|
||||
Descriptor(handle=0x0009, type=UUID-16:2902 (Client Characteristic Configuration), value=<dynamic>)
|
||||
CharacteristicDeclaration(handle=0x000A, value_handle=0x000B, uuid=UUID-16:2B29 (Client Supported Features), READ|WRITE)
|
||||
Characteristic(handle=0x000B, end=0x000B, uuid=UUID-16:2B29 (Client Supported Features), READ|WRITE)
|
||||
CharacteristicDeclaration(handle=0x000C, value_handle=0x000D, uuid=UUID-16:2B2A (Database Hash), READ)
|
||||
@@ -1262,7 +1262,7 @@ Characteristic(handle=0x000D, end=0x000D, uuid=UUID-16:2B2A (Database Hash), REA
|
||||
Service(handle=0x000E, end=0x0011, uuid=3A657F47-D34F-46B3-B1EC-698E29B6B829)
|
||||
CharacteristicDeclaration(handle=0x000F, value_handle=0x0010, uuid=FDB159DB-036C-49E3-B3DB-6325AC750806, READ|WRITE|NOTIFY)
|
||||
Characteristic(handle=0x0010, end=0x0011, uuid=FDB159DB-036C-49E3-B3DB-6325AC750806, READ|WRITE|NOTIFY)
|
||||
Descriptor(handle=0x0011, type=UUID-16:2902 (Client Characteristic Configuration), value=0000)"""
|
||||
Descriptor(handle=0x0011, type=UUID-16:2902 (Client Characteristic Configuration), value=<dynamic>)"""
|
||||
)
|
||||
|
||||
|
||||
|
||||
Reference in New Issue
Block a user