mirror of
https://github.com/google/bumble.git
synced 2026-05-07 03:48:01 +00:00
Compare commits
22 Commits
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
8988a85245 | ||
|
|
0813da2278 | ||
|
|
a1ff183d44 | ||
|
|
7adf44eddf | ||
|
|
05accbf805 | ||
|
|
80f54f2a09 | ||
|
|
07b5e33e09 | ||
|
|
b874e26a4f | ||
|
|
baa5257780 | ||
|
|
a91ea9110c | ||
|
|
1686c5b11b | ||
|
|
d9481992bb | ||
|
|
16d0ed56cf | ||
|
|
c55eb156b8 | ||
|
|
8614881fb3 | ||
|
|
27d02ef18d | ||
|
|
c0725e2a4a | ||
|
|
bf0784dde4 | ||
|
|
444f43f6a3 | ||
|
|
2420c47cf1 | ||
|
|
0a78e7506b | ||
|
|
f7cc6f6657 |
@@ -42,7 +42,7 @@ from typing_extensions import TypeIs
|
||||
|
||||
from bumble import hci, l2cap, utils
|
||||
from bumble.colors import color
|
||||
from bumble.core import UUID, InvalidOperationError, ProtocolError
|
||||
from bumble.core import UUID, InvalidOperationError, InvalidPacketError, ProtocolError
|
||||
from bumble.hci import HCI_Object
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
@@ -249,6 +249,8 @@ class ATT_PDU:
|
||||
|
||||
@classmethod
|
||||
def from_bytes(cls, pdu: bytes) -> ATT_PDU:
|
||||
if not pdu:
|
||||
raise InvalidPacketError("Empty ATT PDU")
|
||||
op_code = pdu[0]
|
||||
|
||||
subclass = ATT_PDU.pdu_classes.get(op_code)
|
||||
@@ -1081,7 +1083,7 @@ class Attribute(utils.EventEmitter, Generic[_T]):
|
||||
else:
|
||||
value_str = str(self.value)
|
||||
if value_str:
|
||||
value_string = f', value={self.value.hex()}'
|
||||
value_string = f', value={value_str}'
|
||||
else:
|
||||
value_string = ''
|
||||
return (
|
||||
|
||||
@@ -311,6 +311,13 @@ class MessageAssembler:
|
||||
def on_pdu(self, pdu: bytes) -> None:
|
||||
self.packet_count += 1
|
||||
|
||||
# Drop empty PDUs sent by remote — accessing pdu[0] below would
|
||||
# raise IndexError, propagating up to the L2CAP read loop and
|
||||
# tearing down the channel. Same class as #912 (ATT empty PDU).
|
||||
if not pdu:
|
||||
logger.warning('AVDTP message assembler: empty PDU dropped')
|
||||
return
|
||||
|
||||
transaction_label = pdu[0] >> 4
|
||||
packet_type = Protocol.PacketType((pdu[0] >> 2) & 3)
|
||||
message_type = Message.MessageType(pdu[0] & 3)
|
||||
@@ -324,6 +331,23 @@ class MessageAssembler:
|
||||
Protocol.PacketType.SINGLE_PACKET,
|
||||
Protocol.PacketType.START_PACKET,
|
||||
):
|
||||
# Both single and start packets carry the signal identifier in
|
||||
# pdu[1]; start packets additionally carry the packet count in
|
||||
# pdu[2]. Guard each access so a malformed remote frame can't
|
||||
# crash the message assembler.
|
||||
if len(pdu) < 2:
|
||||
logger.warning(
|
||||
'AVDTP %s packet too short (%d bytes); dropped',
|
||||
packet_type.name,
|
||||
len(pdu),
|
||||
)
|
||||
return
|
||||
if packet_type == Protocol.PacketType.START_PACKET and len(pdu) < 3:
|
||||
logger.warning(
|
||||
'AVDTP START packet missing signal-packet count; dropped'
|
||||
)
|
||||
return
|
||||
|
||||
if self.message is not None:
|
||||
# The previous message has not been terminated
|
||||
logger.warning(
|
||||
|
||||
@@ -5618,8 +5618,8 @@ class Device(utils.CompositeEventEmitter):
|
||||
async def notify_subscriber(
|
||||
self,
|
||||
connection: Connection,
|
||||
attribute: Attribute,
|
||||
value: Any | None = None,
|
||||
attribute: Attribute[_T],
|
||||
value: _T | None = None,
|
||||
force: bool = False,
|
||||
) -> None:
|
||||
"""
|
||||
@@ -5638,7 +5638,7 @@ class Device(utils.CompositeEventEmitter):
|
||||
await self.gatt_server.notify_subscriber(connection, attribute, value, force)
|
||||
|
||||
async def notify_subscribers(
|
||||
self, attribute: Attribute, value: Any | None = None, force: bool = False
|
||||
self, attribute: Attribute[_T], value: _T | None = None, force: bool = False
|
||||
) -> None:
|
||||
"""
|
||||
Send a notification to all the subscribers of an attribute.
|
||||
@@ -5657,8 +5657,8 @@ class Device(utils.CompositeEventEmitter):
|
||||
async def indicate_subscriber(
|
||||
self,
|
||||
connection: Connection,
|
||||
attribute: Attribute,
|
||||
value: Any | None = None,
|
||||
attribute: Attribute[_T],
|
||||
value: _T | None = None,
|
||||
force: bool = False,
|
||||
):
|
||||
"""
|
||||
@@ -5679,7 +5679,7 @@ class Device(utils.CompositeEventEmitter):
|
||||
await self.gatt_server.indicate_subscriber(connection, attribute, value, force)
|
||||
|
||||
async def indicate_subscribers(
|
||||
self, attribute: Attribute, value: Any | None = None, force: bool = False
|
||||
self, attribute: Attribute[_T], value: _T | None = None, force: bool = False
|
||||
):
|
||||
"""
|
||||
Send an indication to all the subscribers of an attribute.
|
||||
|
||||
@@ -67,6 +67,8 @@ GATT_SERVER_DEFAULT_MAX_MTU = 517
|
||||
# Helpers
|
||||
# -----------------------------------------------------------------------------
|
||||
|
||||
_T = TypeVar('_T')
|
||||
|
||||
|
||||
def _bearer_id(bearer: att.Bearer) -> str:
|
||||
if att.is_enhanced_bearer(bearer):
|
||||
@@ -369,8 +371,8 @@ class Server(utils.EventEmitter):
|
||||
async def notify_subscriber(
|
||||
self,
|
||||
bearer: att.Bearer,
|
||||
attribute: att.Attribute,
|
||||
value: bytes | None = None,
|
||||
attribute: att.Attribute[_T],
|
||||
value: _T | None = None,
|
||||
force: bool = False,
|
||||
) -> None:
|
||||
if att.is_enhanced_bearer(bearer) or force:
|
||||
@@ -390,8 +392,8 @@ class Server(utils.EventEmitter):
|
||||
async def _notify_single_subscriber(
|
||||
self,
|
||||
bearer: att.Bearer,
|
||||
attribute: att.Attribute,
|
||||
value: bytes | None,
|
||||
attribute: att.Attribute[_T],
|
||||
value: _T | None,
|
||||
force: bool,
|
||||
) -> None:
|
||||
# Check if there's a subscriber
|
||||
@@ -411,19 +413,19 @@ class Server(utils.EventEmitter):
|
||||
return
|
||||
|
||||
# Get or encode the value
|
||||
value = (
|
||||
value_as_bytes = (
|
||||
await attribute.read_value(bearer)
|
||||
if value is None
|
||||
else attribute.encode_value(value)
|
||||
)
|
||||
|
||||
# Truncate if needed
|
||||
if len(value) > bearer.att_mtu - 3:
|
||||
value = value[: bearer.att_mtu - 3]
|
||||
if len(value_as_bytes) > bearer.att_mtu - 3:
|
||||
value_as_bytes = value_as_bytes[: bearer.att_mtu - 3]
|
||||
|
||||
# Notify
|
||||
notification = att.ATT_Handle_Value_Notification(
|
||||
attribute_handle=attribute.handle, attribute_value=value
|
||||
attribute_handle=attribute.handle, attribute_value=value_as_bytes
|
||||
)
|
||||
logger.debug(f'GATT Notify from server: {_bearer_id(bearer)} {notification}')
|
||||
self.send_gatt_pdu(bearer, bytes(notification))
|
||||
@@ -431,8 +433,8 @@ class Server(utils.EventEmitter):
|
||||
async def indicate_subscriber(
|
||||
self,
|
||||
bearer: att.Bearer,
|
||||
attribute: att.Attribute,
|
||||
value: bytes | None = None,
|
||||
attribute: att.Attribute[_T],
|
||||
value: _T | None = None,
|
||||
force: bool = False,
|
||||
) -> None:
|
||||
if att.is_enhanced_bearer(bearer) or force:
|
||||
@@ -452,8 +454,8 @@ class Server(utils.EventEmitter):
|
||||
async def _indicate_single_bearer(
|
||||
self,
|
||||
bearer: att.Bearer,
|
||||
attribute: att.Attribute,
|
||||
value: bytes | None,
|
||||
attribute: att.Attribute[_T],
|
||||
value: _T | None,
|
||||
force: bool,
|
||||
) -> None:
|
||||
# Check if there's a subscriber
|
||||
@@ -473,19 +475,19 @@ class Server(utils.EventEmitter):
|
||||
return
|
||||
|
||||
# Get or encode the value
|
||||
value = (
|
||||
value_as_bytes = (
|
||||
await attribute.read_value(bearer)
|
||||
if value is None
|
||||
else attribute.encode_value(value)
|
||||
)
|
||||
|
||||
# Truncate if needed
|
||||
if len(value) > bearer.att_mtu - 3:
|
||||
value = value[: bearer.att_mtu - 3]
|
||||
if len(value_as_bytes) > bearer.att_mtu - 3:
|
||||
value_as_bytes = value_as_bytes[: bearer.att_mtu - 3]
|
||||
|
||||
# Indicate
|
||||
indication = att.ATT_Handle_Value_Indication(
|
||||
attribute_handle=attribute.handle, attribute_value=value
|
||||
attribute_handle=attribute.handle, attribute_value=value_as_bytes
|
||||
)
|
||||
logger.debug(f'GATT Indicate from server: {_bearer_id(bearer)} {indication}')
|
||||
|
||||
@@ -510,8 +512,8 @@ class Server(utils.EventEmitter):
|
||||
async def _notify_or_indicate_subscribers(
|
||||
self,
|
||||
indicate: bool,
|
||||
attribute: att.Attribute,
|
||||
value: bytes | None = None,
|
||||
attribute: att.Attribute[_T],
|
||||
value: _T | None = None,
|
||||
force: bool = False,
|
||||
) -> None:
|
||||
# Get all the bearers for which there's at least one subscription
|
||||
@@ -537,8 +539,8 @@ class Server(utils.EventEmitter):
|
||||
|
||||
async def notify_subscribers(
|
||||
self,
|
||||
attribute: att.Attribute,
|
||||
value: bytes | None = None,
|
||||
attribute: att.Attribute[_T],
|
||||
value: _T | None = None,
|
||||
force: bool = False,
|
||||
):
|
||||
return await self._notify_or_indicate_subscribers(
|
||||
@@ -547,8 +549,8 @@ class Server(utils.EventEmitter):
|
||||
|
||||
async def indicate_subscribers(
|
||||
self,
|
||||
attribute: att.Attribute,
|
||||
value: bytes | None = None,
|
||||
attribute: att.Attribute[_T],
|
||||
value: _T | None = None,
|
||||
force: bool = False,
|
||||
):
|
||||
return await self._notify_or_indicate_subscribers(True, attribute, value, force)
|
||||
|
||||
@@ -68,6 +68,8 @@ class HfpProtocolError(ProtocolError):
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
class HfpProtocol:
|
||||
MAX_BUFFER_SIZE: ClassVar[int] = 65536
|
||||
|
||||
dlc: rfcomm.DLC
|
||||
buffer: str
|
||||
lines: collections.deque
|
||||
@@ -84,10 +86,19 @@ class HfpProtocol:
|
||||
def feed(self, data: bytes | str) -> None:
|
||||
# Convert the data to a string if needed
|
||||
if isinstance(data, bytes):
|
||||
data = data.decode('utf-8')
|
||||
data = data.decode('utf-8', errors='replace')
|
||||
|
||||
logger.debug(f'<<< Data received: {data}')
|
||||
|
||||
# Drop incoming data if it would overflow the buffer; keep existing
|
||||
# partial packet state intact so a future clean packet can still parse.
|
||||
if len(self.buffer) + len(data) > self.MAX_BUFFER_SIZE:
|
||||
logger.warning(
|
||||
'HFP buffer overflow (>%d bytes), dropping incoming data',
|
||||
self.MAX_BUFFER_SIZE,
|
||||
)
|
||||
return
|
||||
|
||||
# Add to the buffer and look for lines
|
||||
self.buffer += data
|
||||
while (separator := self.buffer.find('\r')) >= 0:
|
||||
|
||||
@@ -692,10 +692,8 @@ class Host(utils.EventEmitter):
|
||||
finally:
|
||||
self.pending_command = None
|
||||
self.pending_response = None
|
||||
if (
|
||||
response is not None
|
||||
and response.num_hci_command_packets
|
||||
and self.command_semaphore.locked()
|
||||
if response is None or (
|
||||
response.num_hci_command_packets and self.command_semaphore.locked()
|
||||
):
|
||||
self.command_semaphore.release()
|
||||
|
||||
|
||||
588
bumble/sdp.py
588
bumble/sdp.py
@@ -44,6 +44,12 @@ if TYPE_CHECKING:
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
# SDP data elements are nested (SEQUENCE, ALTERNATIVE). Cap parse recursion to
|
||||
# prevent a malicious peer from crashing the process via a deeply nested PDU.
|
||||
# 32 levels is well beyond anything a legitimate service record uses.
|
||||
_MAX_DATA_ELEMENT_NESTING = 32
|
||||
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
# Constants
|
||||
# -----------------------------------------------------------------------------
|
||||
@@ -148,32 +154,6 @@ class DataElement:
|
||||
ALTERNATIVE = Type.ALTERNATIVE
|
||||
URL = Type.URL
|
||||
|
||||
TYPE_CONSTRUCTORS = {
|
||||
NIL: lambda x: DataElement(DataElement.NIL, None),
|
||||
UNSIGNED_INTEGER: lambda x, y: DataElement(
|
||||
DataElement.UNSIGNED_INTEGER,
|
||||
DataElement.unsigned_integer_from_bytes(x),
|
||||
value_size=y,
|
||||
),
|
||||
SIGNED_INTEGER: lambda x, y: DataElement(
|
||||
DataElement.SIGNED_INTEGER,
|
||||
DataElement.signed_integer_from_bytes(x),
|
||||
value_size=y,
|
||||
),
|
||||
UUID: lambda x: DataElement(
|
||||
DataElement.UUID, core.UUID.from_bytes(bytes(reversed(x)))
|
||||
),
|
||||
TEXT_STRING: lambda x: DataElement(DataElement.TEXT_STRING, x),
|
||||
BOOLEAN: lambda x: DataElement(DataElement.BOOLEAN, x[0] == 1),
|
||||
SEQUENCE: lambda x: DataElement(
|
||||
DataElement.SEQUENCE, DataElement.list_from_bytes(x)
|
||||
),
|
||||
ALTERNATIVE: lambda x: DataElement(
|
||||
DataElement.ALTERNATIVE, DataElement.list_from_bytes(x)
|
||||
),
|
||||
URL: lambda x: DataElement(DataElement.URL, x.decode('utf8')),
|
||||
}
|
||||
|
||||
type: Type
|
||||
value: Any
|
||||
value_size: int | None = None
|
||||
@@ -190,279 +170,354 @@ class DataElement:
|
||||
'integer types must have a value size specified'
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def nil() -> DataElement:
|
||||
return DataElement(DataElement.NIL, None)
|
||||
@classmethod
|
||||
def nil(cls) -> DataElement:
|
||||
return cls(cls.NIL, None)
|
||||
|
||||
@staticmethod
|
||||
def unsigned_integer(value: int, value_size: int) -> DataElement:
|
||||
return DataElement(DataElement.UNSIGNED_INTEGER, value, value_size)
|
||||
@classmethod
|
||||
def unsigned_integer(cls, value: int, value_size: int) -> DataElement:
|
||||
return cls(cls.UNSIGNED_INTEGER, value, value_size)
|
||||
|
||||
@staticmethod
|
||||
def unsigned_integer_8(value: int) -> DataElement:
|
||||
return DataElement(DataElement.UNSIGNED_INTEGER, value, value_size=1)
|
||||
@classmethod
|
||||
def unsigned_integer_8(cls, value: int) -> DataElement:
|
||||
return cls(cls.UNSIGNED_INTEGER, value, value_size=1)
|
||||
|
||||
@staticmethod
|
||||
def unsigned_integer_16(value: int) -> DataElement:
|
||||
return DataElement(DataElement.UNSIGNED_INTEGER, value, value_size=2)
|
||||
@classmethod
|
||||
def unsigned_integer_16(cls, value: int) -> DataElement:
|
||||
return cls(cls.UNSIGNED_INTEGER, value, value_size=2)
|
||||
|
||||
@staticmethod
|
||||
def unsigned_integer_32(value: int) -> DataElement:
|
||||
return DataElement(DataElement.UNSIGNED_INTEGER, value, value_size=4)
|
||||
@classmethod
|
||||
def unsigned_integer_32(cls, value: int) -> DataElement:
|
||||
return cls(cls.UNSIGNED_INTEGER, value, value_size=4)
|
||||
|
||||
@staticmethod
|
||||
def signed_integer(value: int, value_size: int) -> DataElement:
|
||||
return DataElement(DataElement.SIGNED_INTEGER, value, value_size)
|
||||
@classmethod
|
||||
def signed_integer(cls, value: int, value_size: int) -> DataElement:
|
||||
return cls(cls.SIGNED_INTEGER, value, value_size)
|
||||
|
||||
@staticmethod
|
||||
def signed_integer_8(value: int) -> DataElement:
|
||||
return DataElement(DataElement.SIGNED_INTEGER, value, value_size=1)
|
||||
@classmethod
|
||||
def signed_integer_8(cls, value: int) -> DataElement:
|
||||
return cls(cls.SIGNED_INTEGER, value, value_size=1)
|
||||
|
||||
@staticmethod
|
||||
def signed_integer_16(value: int) -> DataElement:
|
||||
return DataElement(DataElement.SIGNED_INTEGER, value, value_size=2)
|
||||
@classmethod
|
||||
def signed_integer_16(cls, value: int) -> DataElement:
|
||||
return cls(cls.SIGNED_INTEGER, value, value_size=2)
|
||||
|
||||
@staticmethod
|
||||
def signed_integer_32(value: int) -> DataElement:
|
||||
return DataElement(DataElement.SIGNED_INTEGER, value, value_size=4)
|
||||
@classmethod
|
||||
def signed_integer_32(cls, value: int) -> DataElement:
|
||||
return cls(cls.SIGNED_INTEGER, value, value_size=4)
|
||||
|
||||
@staticmethod
|
||||
def uuid(value: core.UUID) -> DataElement:
|
||||
return DataElement(DataElement.UUID, value)
|
||||
@classmethod
|
||||
def uuid(cls, value: core.UUID) -> DataElement:
|
||||
return cls(cls.UUID, value)
|
||||
|
||||
@staticmethod
|
||||
def text_string(value: bytes) -> DataElement:
|
||||
return DataElement(DataElement.TEXT_STRING, value)
|
||||
@classmethod
|
||||
def text_string(cls, value: bytes) -> DataElement:
|
||||
return cls(cls.TEXT_STRING, value)
|
||||
|
||||
@staticmethod
|
||||
def boolean(value: bool) -> DataElement:
|
||||
return DataElement(DataElement.BOOLEAN, value)
|
||||
@classmethod
|
||||
def boolean(cls, value: bool) -> DataElement:
|
||||
return cls(cls.BOOLEAN, value)
|
||||
|
||||
@staticmethod
|
||||
def sequence(value: Iterable[DataElement]) -> DataElement:
|
||||
return DataElement(DataElement.SEQUENCE, value)
|
||||
@classmethod
|
||||
def sequence(cls, value: Iterable[DataElement]) -> DataElement:
|
||||
return cls(cls.SEQUENCE, value)
|
||||
|
||||
@staticmethod
|
||||
def alternative(value: Iterable[DataElement]) -> DataElement:
|
||||
return DataElement(DataElement.ALTERNATIVE, value)
|
||||
@classmethod
|
||||
def alternative(cls, value: Iterable[DataElement]) -> DataElement:
|
||||
return cls(cls.ALTERNATIVE, value)
|
||||
|
||||
@staticmethod
|
||||
def url(value: str) -> DataElement:
|
||||
return DataElement(DataElement.URL, value)
|
||||
@classmethod
|
||||
def url(cls, value: str) -> DataElement:
|
||||
return cls(cls.URL, value)
|
||||
|
||||
@staticmethod
|
||||
def unsigned_integer_from_bytes(data):
|
||||
if len(data) == 1:
|
||||
return data[0]
|
||||
@classmethod
|
||||
def unsigned_integer_from_bytes(cls, data: bytes, offset: int, length: int) -> int:
|
||||
match length:
|
||||
case 1:
|
||||
return data[offset]
|
||||
case 2:
|
||||
return struct.unpack_from('>H', data, offset)[0]
|
||||
case 4:
|
||||
return struct.unpack_from('>I', data, offset)[0]
|
||||
case 8:
|
||||
return struct.unpack_from('>Q', data, offset)[0]
|
||||
case invalid_length:
|
||||
raise InvalidPacketError(f'invalid integer length {invalid_length}')
|
||||
|
||||
if len(data) == 2:
|
||||
return struct.unpack('>H', data)[0]
|
||||
@classmethod
|
||||
def signed_integer_from_bytes(cls, data: bytes, offset: int, length: int) -> int:
|
||||
match length:
|
||||
case 1:
|
||||
return struct.unpack_from('b', data, offset)[0]
|
||||
case 2:
|
||||
return struct.unpack_from('>h', data, offset)[0]
|
||||
case 4:
|
||||
return struct.unpack_from('>i', data, offset)[0]
|
||||
case 8:
|
||||
return struct.unpack_from('>q', data, offset)[0]
|
||||
case invalid_length:
|
||||
raise InvalidPacketError(f'invalid integer length {invalid_length}')
|
||||
|
||||
if len(data) == 4:
|
||||
return struct.unpack('>I', data)[0]
|
||||
@classmethod
|
||||
def parse_from_bytes(cls, data: bytes, offset: int) -> tuple[int, DataElement]:
|
||||
parser = DataElementParser(data, offset)
|
||||
element = parser.parse_next()
|
||||
return parser.offset, element
|
||||
|
||||
if len(data) == 8:
|
||||
return struct.unpack('>Q', data)[0]
|
||||
@classmethod
|
||||
def from_bytes(cls, data: bytes) -> DataElement:
|
||||
return DataElementParser(data).parse_next()
|
||||
|
||||
raise InvalidPacketError(f'invalid integer length {len(data)}')
|
||||
|
||||
@staticmethod
|
||||
def signed_integer_from_bytes(data):
|
||||
if len(data) == 1:
|
||||
return struct.unpack('b', data)[0]
|
||||
|
||||
if len(data) == 2:
|
||||
return struct.unpack('>h', data)[0]
|
||||
|
||||
if len(data) == 4:
|
||||
return struct.unpack('>i', data)[0]
|
||||
|
||||
if len(data) == 8:
|
||||
return struct.unpack('>q', data)[0]
|
||||
|
||||
raise InvalidPacketError(f'invalid integer length {len(data)}')
|
||||
|
||||
@staticmethod
|
||||
def list_from_bytes(data):
|
||||
elements = []
|
||||
while data:
|
||||
element = DataElement.from_bytes(data)
|
||||
elements.append(element)
|
||||
data = data[len(bytes(element)) :]
|
||||
return elements
|
||||
|
||||
@staticmethod
|
||||
def parse_from_bytes(data, offset):
|
||||
element = DataElement.from_bytes(data[offset:])
|
||||
return offset + len(bytes(element)), element
|
||||
|
||||
@staticmethod
|
||||
def from_bytes(data):
|
||||
element_type = data[0] >> 3
|
||||
size_index = data[0] & 7
|
||||
value_offset = 0
|
||||
if size_index == 0:
|
||||
if element_type == DataElement.NIL:
|
||||
value_size = 0
|
||||
else:
|
||||
value_size = 1
|
||||
elif size_index == 1:
|
||||
value_size = 2
|
||||
elif size_index == 2:
|
||||
value_size = 4
|
||||
elif size_index == 3:
|
||||
value_size = 8
|
||||
elif size_index == 4:
|
||||
value_size = 16
|
||||
elif size_index == 5:
|
||||
value_size = data[1]
|
||||
value_offset = 1
|
||||
elif size_index == 6:
|
||||
value_size = struct.unpack('>H', data[1:3])[0]
|
||||
value_offset = 2
|
||||
else: # size_index == 7
|
||||
value_size = struct.unpack('>I', data[1:5])[0]
|
||||
value_offset = 4
|
||||
|
||||
value_data = data[1 + value_offset : 1 + value_offset + value_size]
|
||||
constructor = DataElement.TYPE_CONSTRUCTORS.get(element_type)
|
||||
if constructor:
|
||||
if element_type in (
|
||||
DataElement.UNSIGNED_INTEGER,
|
||||
DataElement.SIGNED_INTEGER,
|
||||
):
|
||||
result = constructor(value_data, value_size)
|
||||
else:
|
||||
result = constructor(value_data)
|
||||
else:
|
||||
result = DataElement(element_type, value_data)
|
||||
result._bytes = data[
|
||||
: 1 + value_offset + value_size
|
||||
] # Keep a copy so we can re-serialize to an exact replica
|
||||
return result
|
||||
|
||||
def __bytes__(self):
|
||||
def __bytes__(self) -> bytes:
|
||||
# Return early if we have a cache
|
||||
if self._bytes:
|
||||
return self._bytes
|
||||
|
||||
if self.type == DataElement.NIL:
|
||||
data = b''
|
||||
elif self.type == DataElement.UNSIGNED_INTEGER:
|
||||
if self.value < 0:
|
||||
raise InvalidArgumentError('UNSIGNED_INTEGER cannot be negative')
|
||||
match self.type:
|
||||
case DataElement.NIL:
|
||||
data = b''
|
||||
case DataElement.UNSIGNED_INTEGER:
|
||||
if self.value < 0:
|
||||
raise InvalidArgumentError('UNSIGNED_INTEGER cannot be negative')
|
||||
|
||||
if self.value_size == 1:
|
||||
data = struct.pack('B', self.value)
|
||||
elif self.value_size == 2:
|
||||
data = struct.pack('>H', self.value)
|
||||
elif self.value_size == 4:
|
||||
data = struct.pack('>I', self.value)
|
||||
elif self.value_size == 8:
|
||||
data = struct.pack('>Q', self.value)
|
||||
else:
|
||||
raise InvalidArgumentError('invalid value_size')
|
||||
elif self.type == DataElement.SIGNED_INTEGER:
|
||||
if self.value_size == 1:
|
||||
data = struct.pack('b', self.value)
|
||||
elif self.value_size == 2:
|
||||
data = struct.pack('>h', self.value)
|
||||
elif self.value_size == 4:
|
||||
data = struct.pack('>i', self.value)
|
||||
elif self.value_size == 8:
|
||||
data = struct.pack('>q', self.value)
|
||||
else:
|
||||
raise InvalidArgumentError('invalid value_size')
|
||||
elif self.type == DataElement.UUID:
|
||||
data = bytes(reversed(bytes(self.value)))
|
||||
elif self.type == DataElement.URL:
|
||||
data = self.value.encode('utf8')
|
||||
elif self.type == DataElement.BOOLEAN:
|
||||
data = bytes([1 if self.value else 0])
|
||||
elif self.type in (DataElement.SEQUENCE, DataElement.ALTERNATIVE):
|
||||
data = b''.join([bytes(element) for element in self.value])
|
||||
else:
|
||||
data = self.value
|
||||
match self.value_size:
|
||||
case 1:
|
||||
data = struct.pack('B', self.value)
|
||||
case 2:
|
||||
data = struct.pack('>H', self.value)
|
||||
case 4:
|
||||
data = struct.pack('>I', self.value)
|
||||
case 8:
|
||||
data = struct.pack('>Q', self.value)
|
||||
case invalid_length:
|
||||
raise InvalidArgumentError(
|
||||
f'invalid value_size of {invalid_length}'
|
||||
)
|
||||
case DataElement.SIGNED_INTEGER:
|
||||
match self.value_size:
|
||||
case 1:
|
||||
data = struct.pack('b', self.value)
|
||||
case 2:
|
||||
data = struct.pack('>h', self.value)
|
||||
case 4:
|
||||
data = struct.pack('>i', self.value)
|
||||
case 8:
|
||||
data = struct.pack('>q', self.value)
|
||||
case invalid_length:
|
||||
raise InvalidArgumentError(
|
||||
f'invalid value_size of {invalid_length}'
|
||||
)
|
||||
case DataElement.UUID:
|
||||
data = bytes(self.value)[::-1]
|
||||
case DataElement.URL:
|
||||
data = self.value.encode('utf8')
|
||||
case DataElement.BOOLEAN:
|
||||
data = bytes([1 if self.value else 0])
|
||||
case DataElement.SEQUENCE | DataElement.ALTERNATIVE:
|
||||
data = b''.join([bytes(element) for element in self.value])
|
||||
case _:
|
||||
data = self.value
|
||||
|
||||
size = len(data)
|
||||
size_bytes = b''
|
||||
if self.type == DataElement.NIL:
|
||||
if size != 0:
|
||||
raise InvalidArgumentError('NIL must be empty')
|
||||
size_index = 0
|
||||
elif self.type in (
|
||||
DataElement.UNSIGNED_INTEGER,
|
||||
DataElement.SIGNED_INTEGER,
|
||||
DataElement.UUID,
|
||||
):
|
||||
if size <= 1:
|
||||
match self.type:
|
||||
case DataElement.NIL:
|
||||
if size != 0:
|
||||
raise InvalidArgumentError('NIL must be empty')
|
||||
size_index = 0
|
||||
elif size == 2:
|
||||
size_index = 1
|
||||
elif size == 4:
|
||||
size_index = 2
|
||||
elif size == 8:
|
||||
size_index = 3
|
||||
elif size == 16:
|
||||
size_index = 4
|
||||
else:
|
||||
raise InvalidArgumentError('invalid data size')
|
||||
elif self.type in (
|
||||
DataElement.TEXT_STRING,
|
||||
DataElement.SEQUENCE,
|
||||
DataElement.ALTERNATIVE,
|
||||
DataElement.URL,
|
||||
):
|
||||
if size <= 0xFF:
|
||||
size_index = 5
|
||||
size_bytes = bytes([size])
|
||||
elif size <= 0xFFFF:
|
||||
size_index = 6
|
||||
size_bytes = struct.pack('>H', size)
|
||||
elif size <= 0xFFFFFFFF:
|
||||
size_index = 7
|
||||
size_bytes = struct.pack('>I', size)
|
||||
else:
|
||||
raise InvalidArgumentError('invalid data size')
|
||||
elif self.type == DataElement.BOOLEAN:
|
||||
if size != 1:
|
||||
raise InvalidArgumentError('boolean must be 1 byte')
|
||||
size_index = 0
|
||||
else:
|
||||
raise RuntimeError("internal error - self.type not supported")
|
||||
case (
|
||||
DataElement.UNSIGNED_INTEGER
|
||||
| DataElement.SIGNED_INTEGER
|
||||
| DataElement.UUID
|
||||
):
|
||||
if size <= 1:
|
||||
size_index = 0
|
||||
elif size == 2:
|
||||
size_index = 1
|
||||
elif size == 4:
|
||||
size_index = 2
|
||||
elif size == 8:
|
||||
size_index = 3
|
||||
elif size == 16:
|
||||
size_index = 4
|
||||
else:
|
||||
raise InvalidArgumentError('invalid data size')
|
||||
case (
|
||||
DataElement.TEXT_STRING
|
||||
| DataElement.SEQUENCE
|
||||
| DataElement.ALTERNATIVE
|
||||
| DataElement.URL
|
||||
):
|
||||
if size <= 0xFF:
|
||||
size_index = 5
|
||||
size_bytes = bytes([size])
|
||||
elif size <= 0xFFFF:
|
||||
size_index = 6
|
||||
size_bytes = struct.pack('>H', size)
|
||||
elif size <= 0xFFFFFFFF:
|
||||
size_index = 7
|
||||
size_bytes = struct.pack('>I', size)
|
||||
else:
|
||||
raise InvalidArgumentError('invalid data size')
|
||||
case DataElement.BOOLEAN:
|
||||
if size != 1:
|
||||
raise InvalidArgumentError('boolean must be 1 byte')
|
||||
size_index = 0
|
||||
case unsupported_type:
|
||||
raise core.InvalidPacketError(
|
||||
f"internal error - {unsupported_type} not supported"
|
||||
)
|
||||
|
||||
self._bytes = bytes([self.type << 3 | size_index]) + size_bytes + data
|
||||
return self._bytes
|
||||
|
||||
def to_string(self, pretty=False, indentation=0):
|
||||
def to_string(self, pretty: bool = False, indentation: int = 0) -> str:
|
||||
prefix = ' ' * indentation
|
||||
type_name = self.type.name
|
||||
if self.type == DataElement.NIL:
|
||||
value_string = ''
|
||||
elif self.type in (DataElement.SEQUENCE, DataElement.ALTERNATIVE):
|
||||
container_separator = '\n' if pretty else ''
|
||||
element_separator = '\n' if pretty else ','
|
||||
elements = [
|
||||
element.to_string(pretty, indentation + 1 if pretty else 0)
|
||||
for element in self.value
|
||||
]
|
||||
value_string = (
|
||||
f'[{container_separator}'
|
||||
f'{element_separator.join(elements)}'
|
||||
f'{container_separator}{prefix}]'
|
||||
)
|
||||
elif self.type in (DataElement.UNSIGNED_INTEGER, DataElement.SIGNED_INTEGER):
|
||||
value_string = f'{self.value}#{self.value_size}'
|
||||
elif isinstance(self.value, DataElement):
|
||||
value_string = self.value.to_string(pretty, indentation)
|
||||
else:
|
||||
value_string = str(self.value)
|
||||
match self.type:
|
||||
case DataElement.NIL:
|
||||
value_string = ''
|
||||
case DataElement.SEQUENCE | DataElement.ALTERNATIVE:
|
||||
container_separator = '\n' if pretty else ''
|
||||
element_separator = '\n' if pretty else ','
|
||||
elements = [
|
||||
element.to_string(pretty, indentation + 1 if pretty else 0)
|
||||
for element in self.value
|
||||
]
|
||||
value_string = (
|
||||
f'[{container_separator}'
|
||||
f'{element_separator.join(elements)}'
|
||||
f'{container_separator}{prefix}]'
|
||||
)
|
||||
case DataElement.UNSIGNED_INTEGER | DataElement.SIGNED_INTEGER:
|
||||
value_string = f'{self.value}#{self.value_size}'
|
||||
case _:
|
||||
if isinstance(self.value, DataElement):
|
||||
value_string = self.value.to_string(pretty, indentation)
|
||||
else:
|
||||
value_string = str(self.value)
|
||||
return f'{prefix}{type_name}({value_string})'
|
||||
|
||||
def __str__(self):
|
||||
def __str__(self) -> str:
|
||||
return self.to_string()
|
||||
|
||||
|
||||
class DataElementParser:
|
||||
def __init__(
|
||||
self, data: bytes, offset: int = 0, max_depth: int = _MAX_DATA_ELEMENT_NESTING
|
||||
) -> None:
|
||||
self.data = data
|
||||
self.offset = offset
|
||||
self.depth = 0
|
||||
self.max_depth = max_depth
|
||||
|
||||
def parse_next(self) -> DataElement:
|
||||
if self.offset >= len(self.data):
|
||||
raise core.InvalidStateError(
|
||||
f"offset {self.offset} exceeds len(data) {len(self.data)}"
|
||||
)
|
||||
start_offset = self.offset
|
||||
element_type = DataElement.Type(self.data[self.offset] >> 3)
|
||||
size_index = self.data[self.offset] & 7
|
||||
self.offset += 1
|
||||
|
||||
value_size: int
|
||||
match size_index:
|
||||
case 0:
|
||||
if element_type == DataElement.NIL:
|
||||
value_size = 0
|
||||
else:
|
||||
value_size = 1
|
||||
case 1:
|
||||
value_size = 2
|
||||
case 2:
|
||||
value_size = 4
|
||||
case 3:
|
||||
value_size = 8
|
||||
case 4:
|
||||
value_size = 16
|
||||
case 5:
|
||||
value_size = self.data[self.offset]
|
||||
self.offset += 1
|
||||
case 6:
|
||||
value_size = struct.unpack_from('>H', self.data, self.offset)[0]
|
||||
self.offset += 2
|
||||
case 7:
|
||||
value_size = struct.unpack_from('>I', self.data, self.offset)[0]
|
||||
self.offset += 4
|
||||
case _:
|
||||
raise core.UnreachableError()
|
||||
|
||||
value_start = self.offset
|
||||
value_end = self.offset + value_size
|
||||
|
||||
match element_type:
|
||||
case DataElement.NIL:
|
||||
result = DataElement(DataElement.NIL, None)
|
||||
case DataElement.UNSIGNED_INTEGER:
|
||||
result = DataElement(
|
||||
DataElement.UNSIGNED_INTEGER,
|
||||
DataElement.unsigned_integer_from_bytes(
|
||||
self.data, value_start, value_size
|
||||
),
|
||||
value_size=value_size,
|
||||
)
|
||||
case DataElement.SIGNED_INTEGER:
|
||||
result = DataElement(
|
||||
DataElement.SIGNED_INTEGER,
|
||||
DataElement.signed_integer_from_bytes(
|
||||
self.data, value_start, value_size
|
||||
),
|
||||
value_size=value_size,
|
||||
)
|
||||
case DataElement.UUID:
|
||||
result = DataElement(
|
||||
DataElement.UUID,
|
||||
core.UUID.from_bytes(self.data[value_start:value_end][::-1]),
|
||||
)
|
||||
case DataElement.TEXT_STRING:
|
||||
result = DataElement(
|
||||
DataElement.TEXT_STRING, self.data[value_start:value_end]
|
||||
)
|
||||
case DataElement.BOOLEAN:
|
||||
result = DataElement(DataElement.BOOLEAN, self.data[value_start] == 1)
|
||||
case DataElement.SEQUENCE | DataElement.ALTERNATIVE:
|
||||
self.offset = value_start
|
||||
result = DataElement(
|
||||
element_type,
|
||||
self._list_from_bytes(value_end),
|
||||
)
|
||||
if self.offset != value_end:
|
||||
logger.warning(
|
||||
"Expect parsing until offset %d, but ends at %d",
|
||||
value_end,
|
||||
self.offset,
|
||||
)
|
||||
case DataElement.URL:
|
||||
result = DataElement(
|
||||
DataElement.URL, self.data[value_start:value_end].decode('utf8')
|
||||
)
|
||||
case other_type:
|
||||
result = DataElement(other_type, self.data[value_start:value_end])
|
||||
|
||||
self.offset = value_end
|
||||
result._bytes = self.data[start_offset:value_end]
|
||||
|
||||
return result
|
||||
|
||||
def _list_from_bytes(self, end_offset: int) -> list[DataElement]:
    """Parse consecutive data elements until the offset reaches *end_offset*.

    Raises:
        InvalidPacketError: when element nesting exceeds the configured
        maximum depth (guards against recursion-exhaustion attacks from a
        malicious SDP peer).
    """
    if self.depth >= self.max_depth:
        raise InvalidPacketError(
            f"SDP data element nesting exceeds max depth ({self.max_depth})"
        )

    self.depth += 1
    parsed: list[DataElement] = []
    while self.offset < end_offset:
        parsed.append(self.parse_next())
    self.depth -= 1

    return parsed
|
||||
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
@dataclass
|
||||
class ServiceAttribute:
|
||||
@@ -594,7 +649,10 @@ class SDP_PDU:
|
||||
|
||||
@classmethod
|
||||
def from_bytes(cls, pdu: bytes) -> SDP_PDU:
|
||||
pdu_id, transaction_id, _parameters_length = struct.unpack_from('>BHH', pdu, 0)
|
||||
pdu_id, transaction_id, parameters_length = struct.unpack_from('>BHH', pdu, 0)
|
||||
|
||||
if len(pdu) != 5 + parameters_length:
|
||||
logger.warning("Expect %d bytes, got %d", 5 + parameters_length, len(pdu))
|
||||
|
||||
subclass = cls.subclasses.get(pdu_id)
|
||||
if not (subclass := cls.subclasses.get(pdu_id)):
|
||||
@@ -616,9 +674,11 @@ class SDP_PDU:
|
||||
|
||||
def __bytes__(self):
|
||||
if self._payload is None:
|
||||
self._payload = struct.pack(
|
||||
'>BHH', self.pdu_id, self.transaction_id, 0
|
||||
) + hci.HCI_Object.dict_to_bytes(self.__dict__, self.fields)
|
||||
parameters = hci.HCI_Object.dict_to_bytes(self.__dict__, self.fields)
|
||||
self._payload = (
|
||||
struct.pack('>BHH', self.pdu_id, self.transaction_id, len(parameters))
|
||||
+ parameters
|
||||
)
|
||||
return self._payload
|
||||
|
||||
@property
|
||||
|
||||
@@ -36,6 +36,7 @@ from bumble.colors import color
|
||||
from bumble.core import (
|
||||
AdvertisingData,
|
||||
InvalidArgumentError,
|
||||
InvalidPacketError,
|
||||
PhysicalTransport,
|
||||
ProtocolError,
|
||||
)
|
||||
@@ -215,6 +216,8 @@ class SMP_Command:
|
||||
|
||||
@classmethod
|
||||
def from_bytes(cls, pdu: bytes) -> SMP_Command:
|
||||
if not pdu:
|
||||
raise InvalidPacketError("Empty SMP PDU")
|
||||
code = CommandCode(pdu[0])
|
||||
|
||||
subclass = SMP_Command.smp_classes.get(code)
|
||||
|
||||
@@ -120,6 +120,31 @@ def test_messages(message: avdtp.Message):
|
||||
assert message.payload == parsed.payload
|
||||
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
@pytest.mark.parametrize(
    'pdu',
    (
        b'',  # empty PDU: pdu[0] would raise IndexError
        b'\x00',  # one-byte SINGLE_PACKET: pdu[1] would raise IndexError
        b'\x04',  # one-byte START_PACKET: pdu[1] would raise IndexError
        b'\x44\x10',  # two-byte START_PACKET: pdu[2] would raise IndexError
    ),
)
def test_message_assembler_truncated_pdu(pdu: bytes):
    """Truncated AVDTP PDUs from a remote peer must NOT raise IndexError —
    same DoS class as #912 (ATT empty PDU). The assembler is required to
    log + drop and stay alive so the L2CAP channel survives."""
    received: list = []

    assembler = avdtp.MessageAssembler(
        lambda transaction_label, message: received.append(
            (transaction_label, message)
        )
    )

    # Feeding the truncated PDU must neither raise nor complete a message.
    assembler.on_pdu(pdu)
    assert received == []
|
||||
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
def test_rtp():
|
||||
packet = bytes.fromhex(
|
||||
|
||||
@@ -171,14 +171,15 @@ class Source:
|
||||
|
||||
|
||||
class Sink:
|
||||
response: HCI_Event
|
||||
response: HCI_Event | None
|
||||
|
||||
def __init__(self, source: Source, response: HCI_Event) -> None:
|
||||
def __init__(self, source: Source, response: HCI_Event | None) -> None:
|
||||
self.source = source
|
||||
self.response = response
|
||||
|
||||
def on_packet(self, packet: bytes) -> None:
|
||||
self.source.sink.on_packet(bytes(self.response))
|
||||
if self.response is not None:
|
||||
self.source.sink.on_packet(bytes(self.response))
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@@ -228,6 +229,23 @@ async def test_send_sync_command() -> None:
|
||||
assert isinstance(response3.return_parameters, HCI_GenericReturnParameters)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_send_sync_command_timeout() -> None:
    """A sync HCI command whose response never arrives must time out, and the
    timeout must release the command semaphore so a later command can still be
    attempted (i.e. the second call times out too, instead of blocking forever).
    """
    source = Source()
    # A sink constructed with a None response swallows the command, so the
    # controller never answers and the response timeout must fire.
    sink = Sink(source, None)

    host = Host(source, sink)
    host.ready = True

    with pytest.raises(asyncio.TimeoutError):
        await host.send_sync_command(HCI_Reset_Command(), response_timeout=0.01)

    # The sending semaphore should have been released, so this should not block
    # indefinitely
    with pytest.raises(asyncio.TimeoutError):
        # NOTE: use the same bare HCI_Reset_Command reference as the first
        # call (was inconsistently spelled hci.HCI_Reset_Command here).
        await host.send_sync_command(HCI_Reset_Command(), response_timeout=0.01)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_send_async_command() -> None:
|
||||
source = Source()
|
||||
|
||||
@@ -18,9 +18,11 @@
|
||||
import asyncio
|
||||
import logging
|
||||
import os
|
||||
import re
|
||||
|
||||
import pytest
|
||||
|
||||
from bumble import sdp
|
||||
from bumble.core import BT_L2CAP_PROTOCOL_ID, UUID
|
||||
from bumble.sdp import (
|
||||
SDP_BROWSE_GROUP_LIST_ATTRIBUTE_ID,
|
||||
@@ -206,6 +208,16 @@ def sdp_records(record_count=1):
|
||||
}
|
||||
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
def test_pdu_parameter_length(caplog) -> None:
    """A well-formed SDP PDU must round-trip through from_bytes without
    triggering the parameter-length mismatch warning."""
    caplog.set_level(logging.WARNING)

    error_response = sdp.SDP_ErrorResponse(
        transaction_id=0, error_code=sdp.ErrorCode.INVALID_SDP_VERSION
    )
    reparsed = sdp.SDP_PDU.from_bytes(bytes(error_response))

    assert reparsed == error_response
    # The serialized PDU carries a correct parameters_length, so the parser
    # must not log its length-mismatch warning.
    assert re.search(r"Expect \d+ bytes, got \d+", caplog.text) is None
|
||||
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
@pytest.mark.asyncio
|
||||
async def test_service_search():
|
||||
@@ -428,3 +440,43 @@ async def run():
|
||||
if __name__ == '__main__':
|
||||
logging.basicConfig(level=os.environ.get('BUMBLE_LOGLEVEL', 'INFO').upper())
|
||||
asyncio.run(run())
|
||||
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
def test_nested_sequence_recursion_guard():
    """Regression test: deeply-nested SDP SEQUENCE/ALTERNATIVE must not crash
    the parser with RecursionError. Instead a ValueError is raised once the
    configured nesting limit is exceeded.

    Root cause: DataElement.from_bytes -> list_from_bytes -> (constructor
    dispatching back to list_from_bytes for SEQUENCE/ALTERNATIVE) recursed
    without a depth limit. A malicious SDP peer could craft a PDU exceeding
    Python's default recursion limit (~1000 frames) and crash the host.
    """
    # Wrap an empty SEQUENCE in layer upon layer of tag 0x36 (SEQUENCE with
    # a 2-byte big-endian length) until the 16-bit length field saturates.
    payload = b"\x35\x00"  # innermost: empty SEQUENCE terminator
    for _ in range(1500):
        length = len(payload)
        if length >= 65535:
            break
        payload = bytes([0x36, length >> 8, length & 0xFF]) + payload

    with pytest.raises(ValueError, match="nesting exceeds max depth"):
        DataElement.from_bytes(payload)
|
||||
|
||||
|
||||
def test_nested_sequence_within_limit_still_works():
    """Nested-but-reasonable SDP SEQUENCEs must still parse correctly."""
    nesting_levels = 16  # stays under the 32-depth limit

    # Build SEQUENCE(SEQUENCE(...(UNSIGNED_INTEGER)...)) nesting_levels deep.
    element = DataElement.unsigned_integer(1, value_size=2)
    for _ in range(nesting_levels):
        element = DataElement.sequence([element])

    parsed = DataElement.from_bytes(bytes(element))

    # Unwrap each level to confirm structural integrity was preserved.
    node = parsed
    for _ in range(nesting_levels):
        assert node.type == DataElement.SEQUENCE
        node = node.value[0]
    assert node.type == DataElement.UNSIGNED_INTEGER
    assert node.value == 1
|
||||
|
||||
Reference in New Issue
Block a user