Compare commits

..

1 Commits

Author SHA1 Message Date
dependabot[bot] 834c8acd85 Bump rand in /rust in the cargo group across 1 directory
Bumps the cargo group with 1 update in the /rust directory: [rand](https://github.com/rust-random/rand).


Updates `rand` from 0.8.5 to 0.9.3
- [Release notes](https://github.com/rust-random/rand/releases)
- [Changelog](https://github.com/rust-random/rand/blob/0.9.3/CHANGELOG.md)
- [Commits](https://github.com/rust-random/rand/compare/0.8.5...0.9.3)

---
updated-dependencies:
- dependency-name: rand
  dependency-version: 0.9.3
  dependency-type: direct:production
  dependency-group: cargo
...

Signed-off-by: dependabot[bot] <support@github.com>
2026-04-14 02:05:17 +00:00
13 changed files with 427 additions and 627 deletions
+2 -4
View File
@@ -42,7 +42,7 @@ from typing_extensions import TypeIs
from bumble import hci, l2cap, utils
from bumble.colors import color
from bumble.core import UUID, InvalidOperationError, InvalidPacketError, ProtocolError
from bumble.core import UUID, InvalidOperationError, ProtocolError
from bumble.hci import HCI_Object
# -----------------------------------------------------------------------------
@@ -249,8 +249,6 @@ class ATT_PDU:
@classmethod
def from_bytes(cls, pdu: bytes) -> ATT_PDU:
if not pdu:
raise InvalidPacketError("Empty ATT PDU")
op_code = pdu[0]
subclass = ATT_PDU.pdu_classes.get(op_code)
@@ -1083,7 +1081,7 @@ class Attribute(utils.EventEmitter, Generic[_T]):
else:
value_str = str(self.value)
if value_str:
value_string = f', value={value_str}'
value_string = f', value={self.value.hex()}'
else:
value_string = ''
return (
+77 -140
View File
@@ -17,7 +17,6 @@
# -----------------------------------------------------------------------------
from __future__ import annotations
import abc
import asyncio
import enum
import logging
@@ -312,13 +311,6 @@ class MessageAssembler:
def on_pdu(self, pdu: bytes) -> None:
self.packet_count += 1
# Drop empty PDUs sent by remote — accessing pdu[0] below would
# raise IndexError, propagating up to the L2CAP read loop and
# tearing down the channel. Same class as #912 (ATT empty PDU).
if not pdu:
logger.warning('AVDTP message assembler: empty PDU dropped')
return
transaction_label = pdu[0] >> 4
packet_type = Protocol.PacketType((pdu[0] >> 2) & 3)
message_type = Message.MessageType(pdu[0] & 3)
@@ -332,23 +324,6 @@ class MessageAssembler:
Protocol.PacketType.SINGLE_PACKET,
Protocol.PacketType.START_PACKET,
):
# Both single and start packets carry the signal identifier in
# pdu[1]; start packets additionally carry the packet count in
# pdu[2]. Guard each access so a malformed remote frame can't
# crash the message assembler.
if len(pdu) < 2:
logger.warning(
'AVDTP %s packet too short (%d bytes); dropped',
packet_type.name,
len(pdu),
)
return
if packet_type == Protocol.PacketType.START_PACKET and len(pdu) < 3:
logger.warning(
'AVDTP START packet missing signal-packet count; dropped'
)
return
if self.message is not None:
# The previous message has not been terminated
logger.warning(
@@ -1478,23 +1453,8 @@ class Protocol(utils.EventEmitter):
handler = getattr(self, handler_name, None)
if handler:
try:
result = handler(message)
if asyncio.iscoroutine(result):
async def wait_and_send() -> None:
try:
response = await result
if response:
self.send_message(transaction_label, response)
except Exception:
logger.exception(
color("!!! Exception in handler:", "red")
)
utils.cancel_on_event(self, self.EVENT_CLOSE, wait_and_send())
else:
if result:
self.send_message(transaction_label, result)
response = handler(message)
self.send_message(transaction_label, response)
except Exception:
logger.exception(color("!!! Exception in handler:", "red"))
else:
@@ -1575,7 +1535,7 @@ class Protocol(utils.EventEmitter):
async def send_command(self, command: Message):
# TODO: support timeouts
# Send the command
transaction_label, transaction_result = await self.start_transaction()
(transaction_label, transaction_result) = await self.start_transaction()
self.send_message(transaction_label, command)
# Wait for the response
@@ -1640,14 +1600,14 @@ class Protocol(utils.EventEmitter):
async def abort(self, seid: int) -> Abort_Response:
return await self.send_command(Abort_Command(seid))
async def on_discover_command(self, command: Discover_Command) -> Message | None:
def on_discover_command(self, command: Discover_Command) -> Message | None:
endpoint_infos = [
EndPointInfo(endpoint.seid, 0, endpoint.media_type, endpoint.tsep)
for endpoint in self.local_endpoints
]
return Discover_Response(endpoint_infos)
async def on_get_capabilities_command(
def on_get_capabilities_command(
self, command: Get_Capabilities_Command
) -> Message | None:
endpoint = self.get_local_endpoint_by_seid(command.acp_seid)
@@ -1656,7 +1616,7 @@ class Protocol(utils.EventEmitter):
return Get_Capabilities_Response(endpoint.capabilities)
async def on_get_all_capabilities_command(
def on_get_all_capabilities_command(
self, command: Get_All_Capabilities_Command
) -> Message | None:
endpoint = self.get_local_endpoint_by_seid(command.acp_seid)
@@ -1665,7 +1625,7 @@ class Protocol(utils.EventEmitter):
return Get_All_Capabilities_Response(endpoint.capabilities)
async def on_set_configuration_command(
def on_set_configuration_command(
self, command: Set_Configuration_Command
) -> Message | None:
endpoint = self.get_local_endpoint_by_seid(command.acp_seid)
@@ -1680,10 +1640,10 @@ class Protocol(utils.EventEmitter):
stream = Stream(self, endpoint, StreamEndPointProxy(self, command.int_seid))
self.streams[command.acp_seid] = stream
result = await stream.on_set_configuration_command(command.capabilities)
result = stream.on_set_configuration_command(command.capabilities)
return result or Set_Configuration_Response()
async def on_get_configuration_command(
def on_get_configuration_command(
self, command: Get_Configuration_Command
) -> Message | None:
endpoint = self.get_local_endpoint_by_seid(command.acp_seid)
@@ -1692,31 +1652,29 @@ class Protocol(utils.EventEmitter):
if endpoint.stream is None:
return Get_Configuration_Reject(AVDTP_BAD_STATE_ERROR)
return await endpoint.stream.on_get_configuration_command()
return endpoint.stream.on_get_configuration_command()
async def on_reconfigure_command(
self, command: Reconfigure_Command
) -> Message | None:
def on_reconfigure_command(self, command: Reconfigure_Command) -> Message | None:
endpoint = self.get_local_endpoint_by_seid(command.acp_seid)
if endpoint is None:
return Reconfigure_Reject(error_code=AVDTP_BAD_ACP_SEID_ERROR)
if endpoint.stream is None:
return Reconfigure_Reject(error_code=AVDTP_BAD_STATE_ERROR)
result = await endpoint.stream.on_reconfigure_command(command.capabilities)
result = endpoint.stream.on_reconfigure_command(command.capabilities)
return result or Reconfigure_Response()
async def on_open_command(self, command: Open_Command) -> Message | None:
def on_open_command(self, command: Open_Command) -> Message | None:
endpoint = self.get_local_endpoint_by_seid(command.acp_seid)
if endpoint is None:
return Open_Reject(AVDTP_BAD_ACP_SEID_ERROR)
if endpoint.stream is None:
return Open_Reject(AVDTP_BAD_STATE_ERROR)
result = await endpoint.stream.on_open_command()
result = endpoint.stream.on_open_command()
return result or Open_Response()
async def on_start_command(self, command: Start_Command) -> Message | None:
def on_start_command(self, command: Start_Command) -> Message | None:
for seid in command.acp_seids:
endpoint = self.get_local_endpoint_by_seid(seid)
if endpoint is None:
@@ -1730,12 +1688,12 @@ class Protocol(utils.EventEmitter):
endpoint = self.get_local_endpoint_by_seid(seid)
if not endpoint or not endpoint.stream:
raise InvalidStateError("Should already be checked!")
if (result := await endpoint.stream.on_start_command()) is not None:
if (result := endpoint.stream.on_start_command()) is not None:
return result
return Start_Response()
async def on_suspend_command(self, command: Suspend_Command) -> Message | None:
def on_suspend_command(self, command: Suspend_Command) -> Message | None:
for seid in command.acp_seids:
endpoint = self.get_local_endpoint_by_seid(seid)
if endpoint is None:
@@ -1749,47 +1707,45 @@ class Protocol(utils.EventEmitter):
endpoint = self.get_local_endpoint_by_seid(seid)
if not endpoint or not endpoint.stream:
raise InvalidStateError("Should already be checked!")
if (result := await endpoint.stream.on_suspend_command()) is not None:
if (result := endpoint.stream.on_suspend_command()) is not None:
return result
return Suspend_Response()
async def on_close_command(self, command: Close_Command) -> Message | None:
def on_close_command(self, command: Close_Command) -> Message | None:
endpoint = self.get_local_endpoint_by_seid(command.acp_seid)
if endpoint is None:
return Close_Reject(AVDTP_BAD_ACP_SEID_ERROR)
if endpoint.stream is None:
return Close_Reject(AVDTP_BAD_STATE_ERROR)
result = await endpoint.stream.on_close_command()
result = endpoint.stream.on_close_command()
return result or Close_Response()
async def on_abort_command(self, command: Abort_Command) -> Message | None:
def on_abort_command(self, command: Abort_Command) -> Message | None:
endpoint = self.get_local_endpoint_by_seid(command.acp_seid)
if endpoint is None or endpoint.stream is None:
return Abort_Response()
await endpoint.stream.on_abort_command()
endpoint.stream.on_abort_command()
return Abort_Response()
async def on_security_control_command(
def on_security_control_command(
self, command: Security_Control_Command
) -> Message | None:
endpoint = self.get_local_endpoint_by_seid(command.acp_seid)
if endpoint is None:
return Security_Control_Reject(AVDTP_BAD_ACP_SEID_ERROR)
result = await endpoint.on_security_control_command(command.data)
result = endpoint.on_security_control_command(command.data)
return result or Security_Control_Response()
async def on_delayreport_command(
self, command: DelayReport_Command
) -> Message | None:
def on_delayreport_command(self, command: DelayReport_Command) -> Message | None:
endpoint = self.get_local_endpoint_by_seid(command.acp_seid)
if endpoint is None:
return DelayReport_Reject(AVDTP_BAD_ACP_SEID_ERROR)
result = await endpoint.on_delayreport_command(command.delay)
result = endpoint.on_delayreport_command(command.delay)
return result or DelayReport_Response()
@@ -1947,22 +1903,25 @@ class Stream:
await self.rtp_channel.disconnect()
self.rtp_channel = None
# Release the endpoint
self.local_endpoint.in_use = 0
self.change_state(State.IDLE)
async def on_set_configuration_command(
def on_set_configuration_command(
self, configuration: Iterable[ServiceCapabilities]
) -> Message | None:
if self.state != State.IDLE:
return Set_Configuration_Reject(error_code=AVDTP_BAD_STATE_ERROR)
result = await self.local_endpoint.on_set_configuration_command(configuration)
result = self.local_endpoint.on_set_configuration_command(configuration)
if result is not None:
return result
self.change_state(State.CONFIGURED)
return None
async def on_get_configuration_command(self) -> Message | None:
def on_get_configuration_command(self) -> Message | None:
if self.state not in (
State.CONFIGURED,
State.OPEN,
@@ -1970,25 +1929,25 @@ class Stream:
):
return Get_Configuration_Reject(error_code=AVDTP_BAD_STATE_ERROR)
return await self.local_endpoint.on_get_configuration_command()
return self.local_endpoint.on_get_configuration_command()
async def on_reconfigure_command(
def on_reconfigure_command(
self, configuration: Iterable[ServiceCapabilities]
) -> Message | None:
if self.state != State.OPEN:
return Reconfigure_Reject(error_code=AVDTP_BAD_STATE_ERROR)
result = await self.local_endpoint.on_reconfigure_command(configuration)
result = self.local_endpoint.on_reconfigure_command(configuration)
if result is not None:
return result
return None
async def on_open_command(self) -> Message | None:
def on_open_command(self) -> Message | None:
if self.state != State.CONFIGURED:
return Open_Reject(AVDTP_BAD_STATE_ERROR)
result = await self.local_endpoint.on_open_command()
result = self.local_endpoint.on_open_command()
if result is not None:
return result
@@ -1998,7 +1957,7 @@ class Stream:
self.change_state(State.OPEN)
return None
async def on_start_command(self) -> Message | None:
def on_start_command(self) -> Message | None:
if self.state != State.OPEN:
return Open_Reject(AVDTP_BAD_STATE_ERROR)
@@ -2007,29 +1966,29 @@ class Stream:
logger.warning('received start command before RTP channel establishment')
return Open_Reject(AVDTP_BAD_STATE_ERROR)
result = await self.local_endpoint.on_start_command()
result = self.local_endpoint.on_start_command()
if result is not None:
return result
self.change_state(State.STREAMING)
return None
async def on_suspend_command(self) -> Message | None:
def on_suspend_command(self) -> Message | None:
if self.state != State.STREAMING:
return Open_Reject(AVDTP_BAD_STATE_ERROR)
result = await self.local_endpoint.on_suspend_command()
result = self.local_endpoint.on_suspend_command()
if result is not None:
return result
self.change_state(State.OPEN)
return None
async def on_close_command(self) -> Message | None:
def on_close_command(self) -> Message | None:
if self.state not in (State.OPEN, State.STREAMING):
return Open_Reject(AVDTP_BAD_STATE_ERROR)
result = await self.local_endpoint.on_close_command()
result = self.local_endpoint.on_close_command()
if result is not None:
return result
@@ -2044,8 +2003,7 @@ class Stream:
return None
async def on_abort_command(self) -> Message | None:
await self.local_endpoint.on_abort_command()
def on_abort_command(self) -> Message | None:
if self.rtp_channel is None:
# No need to wait
self.change_state(State.IDLE)
@@ -2070,6 +2028,7 @@ class Stream:
def on_l2cap_channel_close(self) -> None:
logger.debug(color('<<< stream channel closed', 'magenta'))
self.local_endpoint.on_rtp_channel_close()
self.local_endpoint.in_use = 0
self.rtp_channel = None
if self.state in (State.CLOSING, State.ABORTING):
@@ -2094,6 +2053,7 @@ class Stream:
self.state = State.IDLE
local_endpoint.stream = self
local_endpoint.in_use = 1
def __str__(self) -> str:
return (
@@ -2103,16 +2063,14 @@ class Stream:
# -----------------------------------------------------------------------------
class StreamEndPoint(abc.ABC):
@dataclass
class StreamEndPoint:
seid: int
media_type: MediaType
tsep: StreamEndPointType
in_use: int
capabilities: Iterable[ServiceCapabilities]
@property
def in_use(self) -> int:
raise NotImplementedError()
# -----------------------------------------------------------------------------
class StreamEndPointProxy:
@@ -2152,30 +2110,14 @@ class DiscoveredStreamEndPoint(StreamEndPoint, StreamEndPointProxy):
in_use: int,
capabilities: Iterable[ServiceCapabilities],
) -> None:
# StreamEndPoint attributes
self.seid = seid
self.media_type = media_type
self.tsep = tsep
self._in_use = in_use
self.capabilities = capabilities
StreamEndPointProxy.__init__(self, protocol=protocol, seid=seid)
@property
def in_use(self) -> int:
return self._in_use
StreamEndPoint.__init__(self, seid, media_type, tsep, in_use, capabilities)
StreamEndPointProxy.__init__(self, protocol, seid)
# -----------------------------------------------------------------------------
class LocalStreamEndPoint(StreamEndPoint, utils.EventEmitter):
stream: Stream | None
@property
def in_use(self) -> int:
if self.stream and self.stream.state != State.IDLE:
return 1
return 0
EVENT_CONFIGURATION = "configuration"
EVENT_OPEN = "open"
EVENT_START = "start"
@@ -2198,13 +2140,8 @@ class LocalStreamEndPoint(StreamEndPoint, utils.EventEmitter):
capabilities: Iterable[ServiceCapabilities],
configuration: Iterable[ServiceCapabilities] | None = None,
):
StreamEndPoint.__init__(self, seid, media_type, tsep, 0, capabilities)
utils.EventEmitter.__init__(self)
# StreamEndPoint attributes
self.seid = seid
self.media_type = media_type
self.tsep = tsep
self.capabilities = capabilities
self.protocol = protocol
self.configuration = configuration if configuration is not None else []
self.stream = None
@@ -2218,13 +2155,13 @@ class LocalStreamEndPoint(StreamEndPoint, utils.EventEmitter):
async def close(self) -> None:
"""[Source Only] Handles when receiving close command."""
async def on_reconfigure_command(
def on_reconfigure_command(
self, command: Iterable[ServiceCapabilities]
) -> Message | None:
del command # unused.
return None
async def on_set_configuration_command(
def on_set_configuration_command(
self, configuration: Iterable[ServiceCapabilities]
) -> Message | None:
logger.debug(
@@ -2235,34 +2172,34 @@ class LocalStreamEndPoint(StreamEndPoint, utils.EventEmitter):
self.emit(self.EVENT_CONFIGURATION)
return None
async def on_get_configuration_command(self) -> Message | None:
def on_get_configuration_command(self) -> Message | None:
return Get_Configuration_Response(self.configuration)
async def on_open_command(self) -> Message | None:
def on_open_command(self) -> Message | None:
self.emit(self.EVENT_OPEN)
return None
async def on_start_command(self) -> Message | None:
def on_start_command(self) -> Message | None:
self.emit(self.EVENT_START)
return None
async def on_suspend_command(self) -> Message | None:
def on_suspend_command(self) -> Message | None:
self.emit(self.EVENT_SUSPEND)
return None
async def on_close_command(self) -> Message | None:
def on_close_command(self) -> Message | None:
self.emit(self.EVENT_CLOSE)
return None
async def on_abort_command(self) -> Message | None:
def on_abort_command(self) -> Message | None:
self.emit(self.EVENT_ABORT)
return None
async def on_delayreport_command(self, delay: int) -> Message | None:
def on_delayreport_command(self, delay: int) -> Message | None:
self.emit(self.EVENT_DELAY_REPORT, delay)
return None
async def on_security_control_command(self, data: bytes) -> Message | None:
def on_security_control_command(self, data: bytes) -> Message | None:
self.emit(self.EVENT_SECURITY_CONTROL, data)
return None
@@ -2290,12 +2227,12 @@ class LocalSource(LocalStreamEndPoint):
codec_capabilities,
] + list(other_capabilities)
super().__init__(
protocol=protocol,
seid=seid,
media_type=codec_capabilities.media_type,
tsep=AVDTP_TSEP_SRC,
capabilities=capabilities,
configuration=capabilities,
protocol,
seid,
codec_capabilities.media_type,
AVDTP_TSEP_SRC,
capabilities,
capabilities,
)
self.packet_pump = packet_pump
@@ -2314,13 +2251,13 @@ class LocalSource(LocalStreamEndPoint):
self.emit(self.EVENT_STOP)
@override
async def on_start_command(self) -> Message | None:
await self.start()
def on_start_command(self) -> Message | None:
asyncio.create_task(self.start())
return None
@override
async def on_suspend_command(self) -> Message | None:
await self.stop()
def on_suspend_command(self) -> Message | None:
asyncio.create_task(self.stop())
return None
@@ -2334,11 +2271,11 @@ class LocalSink(LocalStreamEndPoint):
codec_capabilities,
]
super().__init__(
protocol=protocol,
seid=seid,
media_type=codec_capabilities.media_type,
tsep=AVDTP_TSEP_SNK,
capabilities=capabilities,
protocol,
seid,
codec_capabilities.media_type,
AVDTP_TSEP_SNK,
capabilities,
)
def on_rtp_channel_open(self) -> None:
+11 -9
View File
@@ -2343,9 +2343,6 @@ class Device(utils.CompositeEventEmitter):
_pending_cis: dict[int, tuple[int, int]]
gatt_service: gatt_service.GenericAttributeProfileService | None = None
keystore: KeyStore | None = None
inquiry_response: bytes | None = None
address_resolver: smp.AddressResolver | None = None
connect_own_address_type: hci.OwnAddressType | None = None
EVENT_ADVERTISEMENT = "advertisement"
EVENT_PERIODIC_ADVERTISING_SYNC_TRANSFER = "periodic_advertising_sync_transfer"
@@ -2464,12 +2461,17 @@ class Device(utils.CompositeEventEmitter):
self.bis_links = {}
self.big_syncs = {}
self.classic_enabled = False
self.inquiry_response = None
self.address_resolver = None
self.classic_pending_accepts = {
hci.Address.ANY: []
} # Futures, by BD address OR [Futures] for hci.Address.ANY
self._cis_lock = asyncio.Lock()
# Own address type cache
self.connect_own_address_type = None
self.name = config.name
self.public_address = hci.Address.ANY
self.random_address = config.address
@@ -5616,8 +5618,8 @@ class Device(utils.CompositeEventEmitter):
async def notify_subscriber(
self,
connection: Connection,
attribute: Attribute[_T],
value: _T | None = None,
attribute: Attribute,
value: Any | None = None,
force: bool = False,
) -> None:
"""
@@ -5636,7 +5638,7 @@ class Device(utils.CompositeEventEmitter):
await self.gatt_server.notify_subscriber(connection, attribute, value, force)
async def notify_subscribers(
self, attribute: Attribute[_T], value: _T | None = None, force: bool = False
self, attribute: Attribute, value: Any | None = None, force: bool = False
) -> None:
"""
Send a notification to all the subscribers of an attribute.
@@ -5655,8 +5657,8 @@ class Device(utils.CompositeEventEmitter):
async def indicate_subscriber(
self,
connection: Connection,
attribute: Attribute[_T],
value: _T | None = None,
attribute: Attribute,
value: Any | None = None,
force: bool = False,
):
"""
@@ -5677,7 +5679,7 @@ class Device(utils.CompositeEventEmitter):
await self.gatt_server.indicate_subscriber(connection, attribute, value, force)
async def indicate_subscribers(
self, attribute: Attribute[_T], value: _T | None = None, force: bool = False
self, attribute: Attribute, value: Any | None = None, force: bool = False
):
"""
Send an indication to all the subscribers of an attribute.
+22 -24
View File
@@ -67,8 +67,6 @@ GATT_SERVER_DEFAULT_MAX_MTU = 517
# Helpers
# -----------------------------------------------------------------------------
_T = TypeVar('_T')
def _bearer_id(bearer: att.Bearer) -> str:
if att.is_enhanced_bearer(bearer):
@@ -371,8 +369,8 @@ class Server(utils.EventEmitter):
async def notify_subscriber(
self,
bearer: att.Bearer,
attribute: att.Attribute[_T],
value: _T | None = None,
attribute: att.Attribute,
value: bytes | None = None,
force: bool = False,
) -> None:
if att.is_enhanced_bearer(bearer) or force:
@@ -392,8 +390,8 @@ class Server(utils.EventEmitter):
async def _notify_single_subscriber(
self,
bearer: att.Bearer,
attribute: att.Attribute[_T],
value: _T | None,
attribute: att.Attribute,
value: bytes | None,
force: bool,
) -> None:
# Check if there's a subscriber
@@ -413,19 +411,19 @@ class Server(utils.EventEmitter):
return
# Get or encode the value
value_as_bytes = (
value = (
await attribute.read_value(bearer)
if value is None
else attribute.encode_value(value)
)
# Truncate if needed
if len(value_as_bytes) > bearer.att_mtu - 3:
value_as_bytes = value_as_bytes[: bearer.att_mtu - 3]
if len(value) > bearer.att_mtu - 3:
value = value[: bearer.att_mtu - 3]
# Notify
notification = att.ATT_Handle_Value_Notification(
attribute_handle=attribute.handle, attribute_value=value_as_bytes
attribute_handle=attribute.handle, attribute_value=value
)
logger.debug(f'GATT Notify from server: {_bearer_id(bearer)} {notification}')
self.send_gatt_pdu(bearer, bytes(notification))
@@ -433,8 +431,8 @@ class Server(utils.EventEmitter):
async def indicate_subscriber(
self,
bearer: att.Bearer,
attribute: att.Attribute[_T],
value: _T | None = None,
attribute: att.Attribute,
value: bytes | None = None,
force: bool = False,
) -> None:
if att.is_enhanced_bearer(bearer) or force:
@@ -454,8 +452,8 @@ class Server(utils.EventEmitter):
async def _indicate_single_bearer(
self,
bearer: att.Bearer,
attribute: att.Attribute[_T],
value: _T | None,
attribute: att.Attribute,
value: bytes | None,
force: bool,
) -> None:
# Check if there's a subscriber
@@ -475,19 +473,19 @@ class Server(utils.EventEmitter):
return
# Get or encode the value
value_as_bytes = (
value = (
await attribute.read_value(bearer)
if value is None
else attribute.encode_value(value)
)
# Truncate if needed
if len(value_as_bytes) > bearer.att_mtu - 3:
value_as_bytes = value_as_bytes[: bearer.att_mtu - 3]
if len(value) > bearer.att_mtu - 3:
value = value[: bearer.att_mtu - 3]
# Indicate
indication = att.ATT_Handle_Value_Indication(
attribute_handle=attribute.handle, attribute_value=value_as_bytes
attribute_handle=attribute.handle, attribute_value=value
)
logger.debug(f'GATT Indicate from server: {_bearer_id(bearer)} {indication}')
@@ -512,8 +510,8 @@ class Server(utils.EventEmitter):
async def _notify_or_indicate_subscribers(
self,
indicate: bool,
attribute: att.Attribute[_T],
value: _T | None = None,
attribute: att.Attribute,
value: bytes | None = None,
force: bool = False,
) -> None:
# Get all the bearers for which there's at least one subscription
@@ -539,8 +537,8 @@ class Server(utils.EventEmitter):
async def notify_subscribers(
self,
attribute: att.Attribute[_T],
value: _T | None = None,
attribute: att.Attribute,
value: bytes | None = None,
force: bool = False,
):
return await self._notify_or_indicate_subscribers(
@@ -549,8 +547,8 @@ class Server(utils.EventEmitter):
async def indicate_subscribers(
self,
attribute: att.Attribute[_T],
value: _T | None = None,
attribute: att.Attribute,
value: bytes | None = None,
force: bool = False,
):
return await self._notify_or_indicate_subscribers(True, attribute, value, force)
+1 -12
View File
@@ -68,8 +68,6 @@ class HfpProtocolError(ProtocolError):
# -----------------------------------------------------------------------------
class HfpProtocol:
MAX_BUFFER_SIZE: ClassVar[int] = 65536
dlc: rfcomm.DLC
buffer: str
lines: collections.deque
@@ -86,19 +84,10 @@ class HfpProtocol:
def feed(self, data: bytes | str) -> None:
# Convert the data to a string if needed
if isinstance(data, bytes):
data = data.decode('utf-8', errors='replace')
data = data.decode('utf-8')
logger.debug(f'<<< Data received: {data}')
# Drop incoming data if it would overflow the buffer; keep existing
# partial packet state intact so a future clean packet can still parse.
if len(self.buffer) + len(data) > self.MAX_BUFFER_SIZE:
logger.warning(
'HFP buffer overflow (>%d bytes), dropping incoming data',
self.MAX_BUFFER_SIZE,
)
return
# Add to the buffer and look for lines
self.buffer += data
while (separator := self.buffer.find('\r')) >= 0:
+4 -2
View File
@@ -692,8 +692,10 @@ class Host(utils.EventEmitter):
finally:
self.pending_command = None
self.pending_response = None
if response is None or (
response.num_hci_command_packets and self.command_semaphore.locked()
if (
response is not None
and response.num_hci_command_packets
and self.command_semaphore.locked()
):
self.command_semaphore.release()
+264 -324
View File
@@ -44,12 +44,6 @@ if TYPE_CHECKING:
logger = logging.getLogger(__name__)
# SDP data elements are nested (SEQUENCE, ALTERNATIVE). Cap parse recursion to
# prevent a malicious peer from crashing the process via a deeply nested PDU.
# 32 levels is well beyond anything a legitimate service record uses.
_MAX_DATA_ELEMENT_NESTING = 32
# -----------------------------------------------------------------------------
# Constants
# -----------------------------------------------------------------------------
@@ -154,6 +148,32 @@ class DataElement:
ALTERNATIVE = Type.ALTERNATIVE
URL = Type.URL
TYPE_CONSTRUCTORS = {
NIL: lambda x: DataElement(DataElement.NIL, None),
UNSIGNED_INTEGER: lambda x, y: DataElement(
DataElement.UNSIGNED_INTEGER,
DataElement.unsigned_integer_from_bytes(x),
value_size=y,
),
SIGNED_INTEGER: lambda x, y: DataElement(
DataElement.SIGNED_INTEGER,
DataElement.signed_integer_from_bytes(x),
value_size=y,
),
UUID: lambda x: DataElement(
DataElement.UUID, core.UUID.from_bytes(bytes(reversed(x)))
),
TEXT_STRING: lambda x: DataElement(DataElement.TEXT_STRING, x),
BOOLEAN: lambda x: DataElement(DataElement.BOOLEAN, x[0] == 1),
SEQUENCE: lambda x: DataElement(
DataElement.SEQUENCE, DataElement.list_from_bytes(x)
),
ALTERNATIVE: lambda x: DataElement(
DataElement.ALTERNATIVE, DataElement.list_from_bytes(x)
),
URL: lambda x: DataElement(DataElement.URL, x.decode('utf8')),
}
type: Type
value: Any
value_size: int | None = None
@@ -170,354 +190,279 @@ class DataElement:
'integer types must have a value size specified'
)
@classmethod
def nil(cls) -> DataElement:
return cls(cls.NIL, None)
@staticmethod
def nil() -> DataElement:
return DataElement(DataElement.NIL, None)
@classmethod
def unsigned_integer(cls, value: int, value_size: int) -> DataElement:
return cls(cls.UNSIGNED_INTEGER, value, value_size)
@staticmethod
def unsigned_integer(value: int, value_size: int) -> DataElement:
return DataElement(DataElement.UNSIGNED_INTEGER, value, value_size)
@classmethod
def unsigned_integer_8(cls, value: int) -> DataElement:
return cls(cls.UNSIGNED_INTEGER, value, value_size=1)
@staticmethod
def unsigned_integer_8(value: int) -> DataElement:
return DataElement(DataElement.UNSIGNED_INTEGER, value, value_size=1)
@classmethod
def unsigned_integer_16(cls, value: int) -> DataElement:
return cls(cls.UNSIGNED_INTEGER, value, value_size=2)
@staticmethod
def unsigned_integer_16(value: int) -> DataElement:
return DataElement(DataElement.UNSIGNED_INTEGER, value, value_size=2)
@classmethod
def unsigned_integer_32(cls, value: int) -> DataElement:
return cls(cls.UNSIGNED_INTEGER, value, value_size=4)
@staticmethod
def unsigned_integer_32(value: int) -> DataElement:
return DataElement(DataElement.UNSIGNED_INTEGER, value, value_size=4)
@classmethod
def signed_integer(cls, value: int, value_size: int) -> DataElement:
return cls(cls.SIGNED_INTEGER, value, value_size)
@staticmethod
def signed_integer(value: int, value_size: int) -> DataElement:
return DataElement(DataElement.SIGNED_INTEGER, value, value_size)
@classmethod
def signed_integer_8(cls, value: int) -> DataElement:
return cls(cls.SIGNED_INTEGER, value, value_size=1)
@staticmethod
def signed_integer_8(value: int) -> DataElement:
return DataElement(DataElement.SIGNED_INTEGER, value, value_size=1)
@classmethod
def signed_integer_16(cls, value: int) -> DataElement:
return cls(cls.SIGNED_INTEGER, value, value_size=2)
@staticmethod
def signed_integer_16(value: int) -> DataElement:
return DataElement(DataElement.SIGNED_INTEGER, value, value_size=2)
@classmethod
def signed_integer_32(cls, value: int) -> DataElement:
return cls(cls.SIGNED_INTEGER, value, value_size=4)
@staticmethod
def signed_integer_32(value: int) -> DataElement:
return DataElement(DataElement.SIGNED_INTEGER, value, value_size=4)
@classmethod
def uuid(cls, value: core.UUID) -> DataElement:
return cls(cls.UUID, value)
@staticmethod
def uuid(value: core.UUID) -> DataElement:
return DataElement(DataElement.UUID, value)
@classmethod
def text_string(cls, value: bytes) -> DataElement:
return cls(cls.TEXT_STRING, value)
@staticmethod
def text_string(value: bytes) -> DataElement:
return DataElement(DataElement.TEXT_STRING, value)
@classmethod
def boolean(cls, value: bool) -> DataElement:
return cls(cls.BOOLEAN, value)
@staticmethod
def boolean(value: bool) -> DataElement:
return DataElement(DataElement.BOOLEAN, value)
@classmethod
def sequence(cls, value: Iterable[DataElement]) -> DataElement:
return cls(cls.SEQUENCE, value)
@staticmethod
def sequence(value: Iterable[DataElement]) -> DataElement:
return DataElement(DataElement.SEQUENCE, value)
@classmethod
def alternative(cls, value: Iterable[DataElement]) -> DataElement:
return cls(cls.ALTERNATIVE, value)
@staticmethod
def alternative(value: Iterable[DataElement]) -> DataElement:
return DataElement(DataElement.ALTERNATIVE, value)
@classmethod
def url(cls, value: str) -> DataElement:
return cls(cls.URL, value)
@staticmethod
def url(value: str) -> DataElement:
return DataElement(DataElement.URL, value)
@classmethod
def unsigned_integer_from_bytes(cls, data: bytes, offset: int, length: int) -> int:
match length:
case 1:
return data[offset]
case 2:
return struct.unpack_from('>H', data, offset)[0]
case 4:
return struct.unpack_from('>I', data, offset)[0]
case 8:
return struct.unpack_from('>Q', data, offset)[0]
case invalid_length:
raise InvalidPacketError(f'invalid integer length {invalid_length}')
@staticmethod
def unsigned_integer_from_bytes(data):
if len(data) == 1:
return data[0]
@classmethod
def signed_integer_from_bytes(cls, data: bytes, offset: int, length: int) -> int:
match length:
case 1:
return struct.unpack_from('b', data, offset)[0]
case 2:
return struct.unpack_from('>h', data, offset)[0]
case 4:
return struct.unpack_from('>i', data, offset)[0]
case 8:
return struct.unpack_from('>q', data, offset)[0]
case invalid_length:
raise InvalidPacketError(f'invalid integer length {invalid_length}')
if len(data) == 2:
return struct.unpack('>H', data)[0]
@classmethod
def parse_from_bytes(cls, data: bytes, offset: int) -> tuple[int, DataElement]:
parser = DataElementParser(data, offset)
element = parser.parse_next()
return parser.offset, element
if len(data) == 4:
return struct.unpack('>I', data)[0]
@classmethod
def from_bytes(cls, data: bytes) -> DataElement:
return DataElementParser(data).parse_next()
if len(data) == 8:
return struct.unpack('>Q', data)[0]
def __bytes__(self) -> bytes:
raise InvalidPacketError(f'invalid integer length {len(data)}')
@staticmethod
def signed_integer_from_bytes(data):
if len(data) == 1:
return struct.unpack('b', data)[0]
if len(data) == 2:
return struct.unpack('>h', data)[0]
if len(data) == 4:
return struct.unpack('>i', data)[0]
if len(data) == 8:
return struct.unpack('>q', data)[0]
raise InvalidPacketError(f'invalid integer length {len(data)}')
@staticmethod
def list_from_bytes(data):
elements = []
while data:
element = DataElement.from_bytes(data)
elements.append(element)
data = data[len(bytes(element)) :]
return elements
@staticmethod
def parse_from_bytes(data, offset):
element = DataElement.from_bytes(data[offset:])
return offset + len(bytes(element)), element
@staticmethod
def from_bytes(data):
element_type = data[0] >> 3
size_index = data[0] & 7
value_offset = 0
if size_index == 0:
if element_type == DataElement.NIL:
value_size = 0
else:
value_size = 1
elif size_index == 1:
value_size = 2
elif size_index == 2:
value_size = 4
elif size_index == 3:
value_size = 8
elif size_index == 4:
value_size = 16
elif size_index == 5:
value_size = data[1]
value_offset = 1
elif size_index == 6:
value_size = struct.unpack('>H', data[1:3])[0]
value_offset = 2
else: # size_index == 7
value_size = struct.unpack('>I', data[1:5])[0]
value_offset = 4
value_data = data[1 + value_offset : 1 + value_offset + value_size]
constructor = DataElement.TYPE_CONSTRUCTORS.get(element_type)
if constructor:
if element_type in (
DataElement.UNSIGNED_INTEGER,
DataElement.SIGNED_INTEGER,
):
result = constructor(value_data, value_size)
else:
result = constructor(value_data)
else:
result = DataElement(element_type, value_data)
result._bytes = data[
: 1 + value_offset + value_size
] # Keep a copy so we can re-serialize to an exact replica
return result
def __bytes__(self):
# Return early if we have a cache
if self._bytes:
return self._bytes
match self.type:
case DataElement.NIL:
data = b''
case DataElement.UNSIGNED_INTEGER:
if self.value < 0:
raise InvalidArgumentError('UNSIGNED_INTEGER cannot be negative')
if self.type == DataElement.NIL:
data = b''
elif self.type == DataElement.UNSIGNED_INTEGER:
if self.value < 0:
raise InvalidArgumentError('UNSIGNED_INTEGER cannot be negative')
match self.value_size:
case 1:
data = struct.pack('B', self.value)
case 2:
data = struct.pack('>H', self.value)
case 4:
data = struct.pack('>I', self.value)
case 8:
data = struct.pack('>Q', self.value)
case invalid_length:
raise InvalidArgumentError(
f'invalid value_size of {invalid_length}'
)
case DataElement.SIGNED_INTEGER:
match self.value_size:
case 1:
data = struct.pack('b', self.value)
case 2:
data = struct.pack('>h', self.value)
case 4:
data = struct.pack('>i', self.value)
case 8:
data = struct.pack('>q', self.value)
case invalid_length:
raise InvalidArgumentError(
f'invalid value_size of {invalid_length}'
)
case DataElement.UUID:
data = bytes(self.value)[::-1]
case DataElement.URL:
data = self.value.encode('utf8')
case DataElement.BOOLEAN:
data = bytes([1 if self.value else 0])
case DataElement.SEQUENCE | DataElement.ALTERNATIVE:
data = b''.join([bytes(element) for element in self.value])
case _:
data = self.value
if self.value_size == 1:
data = struct.pack('B', self.value)
elif self.value_size == 2:
data = struct.pack('>H', self.value)
elif self.value_size == 4:
data = struct.pack('>I', self.value)
elif self.value_size == 8:
data = struct.pack('>Q', self.value)
else:
raise InvalidArgumentError('invalid value_size')
elif self.type == DataElement.SIGNED_INTEGER:
if self.value_size == 1:
data = struct.pack('b', self.value)
elif self.value_size == 2:
data = struct.pack('>h', self.value)
elif self.value_size == 4:
data = struct.pack('>i', self.value)
elif self.value_size == 8:
data = struct.pack('>q', self.value)
else:
raise InvalidArgumentError('invalid value_size')
elif self.type == DataElement.UUID:
data = bytes(reversed(bytes(self.value)))
elif self.type == DataElement.URL:
data = self.value.encode('utf8')
elif self.type == DataElement.BOOLEAN:
data = bytes([1 if self.value else 0])
elif self.type in (DataElement.SEQUENCE, DataElement.ALTERNATIVE):
data = b''.join([bytes(element) for element in self.value])
else:
data = self.value
size = len(data)
size_bytes = b''
match self.type:
case DataElement.NIL:
if size != 0:
raise InvalidArgumentError('NIL must be empty')
if self.type == DataElement.NIL:
if size != 0:
raise InvalidArgumentError('NIL must be empty')
size_index = 0
elif self.type in (
DataElement.UNSIGNED_INTEGER,
DataElement.SIGNED_INTEGER,
DataElement.UUID,
):
if size <= 1:
size_index = 0
case (
DataElement.UNSIGNED_INTEGER
| DataElement.SIGNED_INTEGER
| DataElement.UUID
):
if size <= 1:
size_index = 0
elif size == 2:
size_index = 1
elif size == 4:
size_index = 2
elif size == 8:
size_index = 3
elif size == 16:
size_index = 4
else:
raise InvalidArgumentError('invalid data size')
case (
DataElement.TEXT_STRING
| DataElement.SEQUENCE
| DataElement.ALTERNATIVE
| DataElement.URL
):
if size <= 0xFF:
size_index = 5
size_bytes = bytes([size])
elif size <= 0xFFFF:
size_index = 6
size_bytes = struct.pack('>H', size)
elif size <= 0xFFFFFFFF:
size_index = 7
size_bytes = struct.pack('>I', size)
else:
raise InvalidArgumentError('invalid data size')
case DataElement.BOOLEAN:
if size != 1:
raise InvalidArgumentError('boolean must be 1 byte')
size_index = 0
case unsupported_type:
raise core.InvalidPacketError(
f"internal error - {unsupported_type} not supported"
)
elif size == 2:
size_index = 1
elif size == 4:
size_index = 2
elif size == 8:
size_index = 3
elif size == 16:
size_index = 4
else:
raise InvalidArgumentError('invalid data size')
elif self.type in (
DataElement.TEXT_STRING,
DataElement.SEQUENCE,
DataElement.ALTERNATIVE,
DataElement.URL,
):
if size <= 0xFF:
size_index = 5
size_bytes = bytes([size])
elif size <= 0xFFFF:
size_index = 6
size_bytes = struct.pack('>H', size)
elif size <= 0xFFFFFFFF:
size_index = 7
size_bytes = struct.pack('>I', size)
else:
raise InvalidArgumentError('invalid data size')
elif self.type == DataElement.BOOLEAN:
if size != 1:
raise InvalidArgumentError('boolean must be 1 byte')
size_index = 0
else:
raise RuntimeError("internal error - self.type not supported")
self._bytes = bytes([self.type << 3 | size_index]) + size_bytes + data
return self._bytes
def to_string(self, pretty: bool = False, indentation: int = 0) -> str:
def to_string(self, pretty=False, indentation=0):
prefix = ' ' * indentation
type_name = self.type.name
match self.type:
case DataElement.NIL:
value_string = ''
case DataElement.SEQUENCE | DataElement.ALTERNATIVE:
container_separator = '\n' if pretty else ''
element_separator = '\n' if pretty else ','
elements = [
element.to_string(pretty, indentation + 1 if pretty else 0)
for element in self.value
]
value_string = (
f'[{container_separator}'
f'{element_separator.join(elements)}'
f'{container_separator}{prefix}]'
)
case DataElement.UNSIGNED_INTEGER | DataElement.SIGNED_INTEGER:
value_string = f'{self.value}#{self.value_size}'
case _:
if isinstance(self.value, DataElement):
value_string = self.value.to_string(pretty, indentation)
else:
value_string = str(self.value)
if self.type == DataElement.NIL:
value_string = ''
elif self.type in (DataElement.SEQUENCE, DataElement.ALTERNATIVE):
container_separator = '\n' if pretty else ''
element_separator = '\n' if pretty else ','
elements = [
element.to_string(pretty, indentation + 1 if pretty else 0)
for element in self.value
]
value_string = (
f'[{container_separator}'
f'{element_separator.join(elements)}'
f'{container_separator}{prefix}]'
)
elif self.type in (DataElement.UNSIGNED_INTEGER, DataElement.SIGNED_INTEGER):
value_string = f'{self.value}#{self.value_size}'
elif isinstance(self.value, DataElement):
value_string = self.value.to_string(pretty, indentation)
else:
value_string = str(self.value)
return f'{prefix}{type_name}({value_string})'
def __str__(self) -> str:
def __str__(self):
return self.to_string()
class DataElementParser:
def __init__(
self, data: bytes, offset: int = 0, max_depth: int = _MAX_DATA_ELEMENT_NESTING
) -> None:
self.data = data
self.offset = offset
self.depth = 0
self.max_depth = max_depth
def parse_next(self) -> DataElement:
if self.offset >= len(self.data):
raise core.InvalidStateError(
f"offset {self.offset} exceeds len(data) {len(self.data)}"
)
start_offset = self.offset
element_type = DataElement.Type(self.data[self.offset] >> 3)
size_index = self.data[self.offset] & 7
self.offset += 1
value_size: int
match size_index:
case 0:
if element_type == DataElement.NIL:
value_size = 0
else:
value_size = 1
case 1:
value_size = 2
case 2:
value_size = 4
case 3:
value_size = 8
case 4:
value_size = 16
case 5:
value_size = self.data[self.offset]
self.offset += 1
case 6:
value_size = struct.unpack_from('>H', self.data, self.offset)[0]
self.offset += 2
case 7:
value_size = struct.unpack_from('>I', self.data, self.offset)[0]
self.offset += 4
case _:
raise core.UnreachableError()
value_start = self.offset
value_end = self.offset + value_size
match element_type:
case DataElement.NIL:
result = DataElement(DataElement.NIL, None)
case DataElement.UNSIGNED_INTEGER:
result = DataElement(
DataElement.UNSIGNED_INTEGER,
DataElement.unsigned_integer_from_bytes(
self.data, value_start, value_size
),
value_size=value_size,
)
case DataElement.SIGNED_INTEGER:
result = DataElement(
DataElement.SIGNED_INTEGER,
DataElement.signed_integer_from_bytes(
self.data, value_start, value_size
),
value_size=value_size,
)
case DataElement.UUID:
result = DataElement(
DataElement.UUID,
core.UUID.from_bytes(self.data[value_start:value_end][::-1]),
)
case DataElement.TEXT_STRING:
result = DataElement(
DataElement.TEXT_STRING, self.data[value_start:value_end]
)
case DataElement.BOOLEAN:
result = DataElement(DataElement.BOOLEAN, self.data[value_start] == 1)
case DataElement.SEQUENCE | DataElement.ALTERNATIVE:
self.offset = value_start
result = DataElement(
element_type,
self._list_from_bytes(value_end),
)
if self.offset != value_end:
logger.warning(
"Expect parsing until offset %d, but ends at %d",
value_end,
self.offset,
)
case DataElement.URL:
result = DataElement(
DataElement.URL, self.data[value_start:value_end].decode('utf8')
)
case other_type:
result = DataElement(other_type, self.data[value_start:value_end])
self.offset = value_end
result._bytes = self.data[start_offset:value_end]
return result
def _list_from_bytes(self, end_offset: int) -> list[DataElement]:
if self.depth >= self.max_depth:
raise InvalidPacketError(
f"SDP data element nesting exceeds max depth " f"({self.max_depth})"
)
self.depth += 1
elements = []
while self.offset < end_offset:
elements.append(self.parse_next())
self.depth -= 1
return elements
# -----------------------------------------------------------------------------
@dataclass
class ServiceAttribute:
@@ -649,10 +594,7 @@ class SDP_PDU:
@classmethod
def from_bytes(cls, pdu: bytes) -> SDP_PDU:
pdu_id, transaction_id, parameters_length = struct.unpack_from('>BHH', pdu, 0)
if len(pdu) != 5 + parameters_length:
logger.warning("Expect %d bytes, got %d", 5 + parameters_length, len(pdu))
pdu_id, transaction_id, _parameters_length = struct.unpack_from('>BHH', pdu, 0)
subclass = cls.subclasses.get(pdu_id)
if not (subclass := cls.subclasses.get(pdu_id)):
@@ -674,11 +616,9 @@ class SDP_PDU:
def __bytes__(self):
if self._payload is None:
parameters = hci.HCI_Object.dict_to_bytes(self.__dict__, self.fields)
self._payload = (
struct.pack('>BHH', self.pdu_id, self.transaction_id, len(parameters))
+ parameters
)
self._payload = struct.pack(
'>BHH', self.pdu_id, self.transaction_id, 0
) + hci.HCI_Object.dict_to_bytes(self.__dict__, self.fields)
return self._payload
@property
-3
View File
@@ -36,7 +36,6 @@ from bumble.colors import color
from bumble.core import (
AdvertisingData,
InvalidArgumentError,
InvalidPacketError,
PhysicalTransport,
ProtocolError,
)
@@ -216,8 +215,6 @@ class SMP_Command:
@classmethod
def from_bytes(cls, pdu: bytes) -> SMP_Command:
if not pdu:
raise InvalidPacketError("Empty SMP PDU")
code = CommandCode(pdu[0])
subclass = SMP_Command.smp_classes.get(code)
+42 -10
View File
@@ -657,6 +657,18 @@ dependencies = [
"wasi",
]
[[package]]
name = "getrandom"
version = "0.3.4"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "899def5c37c4fd7b2664648c28120ecec138e4d395b459e5ca34f9cce2dd77fd"
dependencies = [
"cfg-if",
"libc",
"r-efi",
"wasip2",
]
[[package]]
name = "gimli"
version = "0.28.0"
@@ -1402,21 +1414,26 @@ dependencies = [
]
[[package]]
name = "rand"
version = "0.8.5"
name = "r-efi"
version = "5.3.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "34af8d1a0e25924bc5b7c43c079c942339d8f0a8b57c39049bef581b46327404"
checksum = "69cdb34c158ceb288df11e18b4bd39de994f6657d83847bdffdbd7f346754b0f"
[[package]]
name = "rand"
version = "0.9.3"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "7ec095654a25171c2124e9e3393a930bddbffdc939556c914957a4c3e0a87166"
dependencies = [
"libc",
"rand_chacha",
"rand_core",
]
[[package]]
name = "rand_chacha"
version = "0.3.1"
version = "0.9.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "e6c10a63a0fa32252be49d21e7709d4d4baf8d231c2dbce1eaa8141b9b127d88"
checksum = "d3022b5f1df60f26e1ffddd6c66e8aa15de382ae63b3a0c1bfc0e4d3e3f325cb"
dependencies = [
"ppv-lite86",
"rand_core",
@@ -1424,11 +1441,11 @@ dependencies = [
[[package]]
name = "rand_core"
version = "0.6.4"
version = "0.9.5"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "ec0be4795e2f6a28069bec0b5ff3e2ac9bafc99e6a9a7dc3547996c5c816922c"
checksum = "76afc826de14238e6e8c374ddcc1fa19e374fd8dd986b0d2af0d02377261d83c"
dependencies = [
"getrandom",
"getrandom 0.3.4",
]
[[package]]
@@ -1455,7 +1472,7 @@ version = "0.4.3"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "b033d837a7cf162d7993aded9304e30a83213c648b6e389db233191f891e5c2b"
dependencies = [
"getrandom",
"getrandom 0.2.10",
"redox_syscall 0.2.16",
"thiserror",
]
@@ -2028,6 +2045,15 @@ version = "0.11.0+wasi-snapshot-preview1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "9c8d87e72b64a3b4db28d11ce29237c246188f4f51057d65a7eab63b7987e423"
[[package]]
name = "wasip2"
version = "1.0.2+wasi-0.2.9"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "9517f9239f02c069db75e65f174b3da828fe5f5b945c4dd26bd25d89c03ebcf5"
dependencies = [
"wit-bindgen",
]
[[package]]
name = "wasm-bindgen"
version = "0.2.87"
@@ -2283,3 +2309,9 @@ dependencies = [
"cfg-if",
"windows-sys 0.48.0",
]
[[package]]
name = "wit-bindgen"
version = "0.51.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "d7249219f66ced02969388cf2bb044a09756a083d0fab1e566056b04d9fbcaa5"
+1 -1
View File
@@ -57,7 +57,7 @@ anyhow = "1.0.71"
pyo3 = { version = "0.18.3", features = ["macros", "anyhow"] }
pyo3-asyncio = { version = "0.18.0", features = ["tokio-runtime", "attributes", "testing"] }
rusb = "0.9.2"
rand = "0.8.5"
rand = "0.9.3"
clap = { version = "4.3.3", features = ["derive"] }
owo-colors = "3.5.0"
log = "0.4.19"
-25
View File
@@ -120,31 +120,6 @@ def test_messages(message: avdtp.Message):
assert message.payload == parsed.payload
# -----------------------------------------------------------------------------
@pytest.mark.parametrize(
'pdu',
(
b'', # empty PDU — would IndexError on pdu[0]
b'\x00', # 1-byte SINGLE_PACKET — would IndexError on pdu[1]
b'\x04', # 1-byte START_PACKET — would IndexError on pdu[1]
b'\x44\x10', # 2-byte START_PACKET — would IndexError on pdu[2]
),
)
def test_message_assembler_truncated_pdu(pdu: bytes):
"""Truncated AVDTP PDUs from a remote peer must NOT raise IndexError —
same DoS class as #912 (ATT empty PDU). The assembler is required to
log + drop and stay alive so the L2CAP channel survives."""
completed = []
def callback(transaction_label, message):
completed.append((transaction_label, message))
assembler = avdtp.MessageAssembler(callback)
# Must not raise; nothing should be delivered to callback either.
assembler.on_pdu(pdu)
assert not completed
# -----------------------------------------------------------------------------
def test_rtp():
packet = bytes.fromhex(
+3 -21
View File
@@ -171,15 +171,14 @@ class Source:
class Sink:
response: HCI_Event | None
response: HCI_Event
def __init__(self, source: Source, response: HCI_Event | None) -> None:
def __init__(self, source: Source, response: HCI_Event) -> None:
self.source = source
self.response = response
def on_packet(self, packet: bytes) -> None:
if self.response is not None:
self.source.sink.on_packet(bytes(self.response))
self.source.sink.on_packet(bytes(self.response))
@pytest.mark.asyncio
@@ -229,23 +228,6 @@ async def test_send_sync_command() -> None:
assert isinstance(response3.return_parameters, HCI_GenericReturnParameters)
@pytest.mark.asyncio
async def test_send_sync_command_timeout() -> None:
source = Source()
sink = Sink(source, None)
host = Host(source, sink)
host.ready = True
with pytest.raises(asyncio.TimeoutError):
await host.send_sync_command(HCI_Reset_Command(), response_timeout=0.01)
# The sending semaphore should have been released, so this should not block
# indefinitely
with pytest.raises(asyncio.TimeoutError):
await host.send_sync_command(hci.HCI_Reset_Command(), response_timeout=0.01)
@pytest.mark.asyncio
async def test_send_async_command() -> None:
source = Source()
-52
View File
@@ -18,11 +18,9 @@
import asyncio
import logging
import os
import re
import pytest
from bumble import sdp
from bumble.core import BT_L2CAP_PROTOCOL_ID, UUID
from bumble.sdp import (
SDP_BROWSE_GROUP_LIST_ATTRIBUTE_ID,
@@ -208,16 +206,6 @@ def sdp_records(record_count=1):
}
# -----------------------------------------------------------------------------
def test_pdu_parameter_length(caplog) -> None:
caplog.set_level(logging.WARNING)
pdu = sdp.SDP_ErrorResponse(
transaction_id=0, error_code=sdp.ErrorCode.INVALID_SDP_VERSION
)
assert sdp.SDP_PDU.from_bytes(bytes(pdu)) == pdu
assert not re.search(r"Expect \d+ bytes, got \d+", caplog.text)
# -----------------------------------------------------------------------------
@pytest.mark.asyncio
async def test_service_search():
@@ -440,43 +428,3 @@ async def run():
if __name__ == '__main__':
logging.basicConfig(level=os.environ.get('BUMBLE_LOGLEVEL', 'INFO').upper())
asyncio.run(run())
# -----------------------------------------------------------------------------
def test_nested_sequence_recursion_guard():
"""Regression test: deeply-nested SDP SEQUENCE/ALTERNATIVE must not crash
the parser with RecursionError. Instead a ValueError is raised once the
configured nesting limit is exceeded.
Root cause: DataElement.from_bytes -> list_from_bytes -> (constructor
dispatching back to list_from_bytes for SEQUENCE/ALTERNATIVE) recursed
without a depth limit. A malicious SDP peer could craft a PDU exceeding
Pythons default recursion limit (~1000 frames) and crash the host.
"""
# Build nested SEQUENCE payload with tag 0x36 (SEQUENCE, 2-byte length).
inner = b"\x35\x00" # empty SEQUENCE terminator
for _ in range(1500):
size = len(inner)
if size >= 65535:
break
inner = bytes([0x36, (size >> 8) & 0xFF, size & 0xFF]) + inner
with pytest.raises(ValueError, match="nesting exceeds max depth"):
DataElement.from_bytes(inner)
def test_nested_sequence_within_limit_still_works():
"""Nested-but-reasonable SDP SEQUENCEs must still parse correctly."""
leaf = DataElement.unsigned_integer(1, value_size=2)
payload = leaf
for _ in range(16): # under the 32-depth limit
payload = DataElement.sequence([payload])
raw = bytes(payload)
parsed = DataElement.from_bytes(raw)
# Walk back down to confirm structural integrity preserved
cur = parsed
for _ in range(16):
assert cur.type == DataElement.SEQUENCE
cur = cur.value[0]
assert cur.type == DataElement.UNSIGNED_INTEGER
assert cur.value == 1