address PR comments

This commit is contained in:
Gilles Boccon-Gibod
2025-01-21 12:18:06 -05:00
parent 55eb7eb237
commit 931e2de854

View File

@@ -19,7 +19,7 @@ from __future__ import annotations
import asyncio
import logging
import struct
from typing import Iterable, NewType, Optional, Union, Type, TYPE_CHECKING
from typing import Iterable, NewType, Optional, Union, Sequence, Type, TYPE_CHECKING
from typing_extensions import Self
from bumble import core, l2cap
@@ -248,11 +248,11 @@ class DataElement:
return DataElement(DataElement.BOOLEAN, value)
@staticmethod
def sequence(value: list[DataElement]) -> DataElement:
def sequence(value: Iterable[DataElement]) -> DataElement:
return DataElement(DataElement.SEQUENCE, value)
@staticmethod
def alternative(value: list[DataElement]) -> DataElement:
def alternative(value: Iterable[DataElement]) -> DataElement:
return DataElement(DataElement.ALTERNATIVE, value)
@staticmethod
@@ -479,7 +479,9 @@ class ServiceAttribute:
self.value = value
@staticmethod
def list_from_data_elements(elements: list[DataElement]) -> list[ServiceAttribute]:
def list_from_data_elements(
elements: Sequence[DataElement],
) -> list[ServiceAttribute]:
attribute_list = []
for i in range(0, len(elements) // 2):
attribute_id, attribute_value = elements[2 * i : 2 * (i + 1)]
@@ -492,7 +494,7 @@ class ServiceAttribute:
@staticmethod
def find_attribute_in_list(
attribute_list: list[ServiceAttribute], attribute_id: int
attribute_list: Iterable[ServiceAttribute], attribute_id: int
) -> Optional[DataElement]:
return next(
(
@@ -798,7 +800,7 @@ class Client:
def make_transaction_id(self) -> int:
transaction_id = self.next_transaction_id
self.next_transaction_id = self.next_transaction_id & 0xFFFF
self.next_transaction_id = (self.next_transaction_id + 1) & 0xFFFF
return transaction_id
def on_pdu(self, pdu: bytes) -> None:
@@ -893,7 +895,7 @@ class Client:
async def search_attributes(
self,
uuids: list[core.UUID],
uuids: Iterable[core.UUID],
attribute_ids: Iterable[Union[int, tuple[int, int]]],
) -> list[list[ServiceAttribute]]:
"""
@@ -966,7 +968,7 @@ class Client:
async def get_attributes(
self,
service_record_handle: int,
attribute_ids: list[Union[int, tuple[int, int]]],
attribute_ids: Iterable[Union[int, tuple[int, int]]],
) -> list[ServiceAttribute]:
"""
Get attributes for a service.
@@ -1154,7 +1156,7 @@ class Server:
@staticmethod
def get_service_attributes(
service: Service, attribute_ids: list[DataElement]
service: Service, attribute_ids: Iterable[DataElement]
) -> DataElement:
attributes = []
for attribute_id in attribute_ids: