Compare commits

..

1 Commits

Author SHA1 Message Date
uael a7c87e7ad2 asha: import ASHA Pandora service from AOSP 2023-10-03 19:42:18 -07:00
19 changed files with 805 additions and 2383 deletions
-43
View File
@@ -1,43 +0,0 @@
name: Python Avatar
on:
push:
branches: [ main ]
pull_request:
branches: [ main ]
permissions:
contents: read
jobs:
test:
name: Avatar [${{ matrix.shard }}]
runs-on: ubuntu-latest
strategy:
matrix:
shard: [
1/24, 2/24, 3/24, 4/24,
5/24, 6/24, 7/24, 8/24,
9/24, 10/24, 11/24, 12/24,
13/24, 14/24, 15/24, 16/24,
17/24, 18/24, 19/24, 20/24,
21/24, 22/24, 23/24, 24/24,
]
steps:
- uses: actions/checkout@v3
- name: Set Up Python 3.11
uses: actions/setup-python@v4
with:
python-version: 3.11
- name: Install
run: |
python -m pip install --upgrade pip
python -m pip install .[avatar]
- name: Rootcanal
run: nohup python -m rootcanal > rootcanal.log &
- name: Test
run: |
avatar --list | grep -Ev '^=' > test-names.txt
timeout 5m avatar --test-beds bumble.bumbles --tests $(split test-names.txt -n l/${{ matrix.shard }})
- name: Rootcanal Logs
run: cat rootcanal.log
+1 -1
View File
@@ -306,7 +306,6 @@ async def pair(
# Expose a GATT characteristic that can be used to trigger pairing by
# responding with an authentication error when read
if mode == 'le':
device.le_enabled = True
device.add_service(
Service(
'50DB505C-8AC4-4738-8448-3B1D9CC09CC5',
@@ -327,6 +326,7 @@ async def pair(
# Select LE or Classic
if mode == 'classic':
device.classic_enabled = True
device.le_enabled = False
device.classic_smp_enabled = ctkd
# Get things going
+176 -302
View File
File diff suppressed because it is too large Load Diff
+173 -813
View File
File diff suppressed because it is too large Load Diff
+9 -92
View File
@@ -33,8 +33,6 @@ from typing import (
Tuple,
Type,
Union,
cast,
overload,
TYPE_CHECKING,
)
@@ -153,7 +151,6 @@ from .utils import (
CompositeEventEmitter,
setup_event_forwarding,
composite_listener,
deprecated,
)
from .keys import (
KeyStore,
@@ -673,7 +670,9 @@ class Connection(CompositeEventEmitter):
def send_l2cap_pdu(self, cid: int, pdu: bytes) -> None:
self.device.send_l2cap_pdu(self.handle, cid, pdu)
@deprecated("Please use create_l2cap_channel()")
def create_l2cap_connector(self, psm):
return self.device.create_l2cap_connector(self, psm)
async def open_l2cap_channel(
self,
psm,
@@ -683,23 +682,6 @@ class Connection(CompositeEventEmitter):
):
return await self.device.open_l2cap_channel(self, psm, max_credits, mtu, mps)
@overload
async def create_l2cap_channel(
self, spec: l2cap.ClassicChannelSpec
) -> l2cap.ClassicChannel:
...
@overload
async def create_l2cap_channel(
self, spec: l2cap.LeCreditBasedChannelSpec
) -> l2cap.LeCreditBasedChannel:
...
async def create_l2cap_channel(
self, spec: Union[l2cap.ClassicChannelSpec, l2cap.LeCreditBasedChannelSpec]
) -> Union[l2cap.ClassicChannel, l2cap.LeCreditBasedChannel]:
return await self.device.create_l2cap_channel(connection=self, spec=spec)
async def disconnect(
self, reason: int = HCI_REMOTE_USER_TERMINATED_CONNECTION_ERROR
) -> None:
@@ -1198,11 +1180,15 @@ class Device(CompositeEventEmitter):
return None
@deprecated("Please use create_l2cap_server()")
def create_l2cap_connector(self, connection, psm):
return lambda: self.l2cap_channel_manager.connect(connection, psm)
def create_l2cap_registrar(self, psm):
return lambda handler: self.register_l2cap_server(psm, handler)
def register_l2cap_server(self, psm, server) -> int:
return self.l2cap_channel_manager.register_server(psm, server)
@deprecated("Please use create_l2cap_server()")
def register_l2cap_channel_server(
self,
psm,
@@ -1215,7 +1201,6 @@ class Device(CompositeEventEmitter):
psm, server, max_credits, mtu, mps
)
@deprecated("Please use create_l2cap_channel()")
async def open_l2cap_channel(
self,
connection,
@@ -1228,74 +1213,6 @@ class Device(CompositeEventEmitter):
connection, psm, max_credits, mtu, mps
)
@overload
async def create_l2cap_channel(
self,
connection: Connection,
spec: l2cap.ClassicChannelSpec,
) -> l2cap.ClassicChannel:
...
@overload
async def create_l2cap_channel(
self,
connection: Connection,
spec: l2cap.LeCreditBasedChannelSpec,
) -> l2cap.LeCreditBasedChannel:
...
async def create_l2cap_channel(
self,
connection: Connection,
spec: Union[l2cap.ClassicChannelSpec, l2cap.LeCreditBasedChannelSpec],
) -> Union[l2cap.ClassicChannel, l2cap.LeCreditBasedChannel]:
if isinstance(spec, l2cap.ClassicChannelSpec):
return await self.l2cap_channel_manager.create_classic_channel(
connection=connection, spec=spec
)
if isinstance(spec, l2cap.LeCreditBasedChannelSpec):
return await self.l2cap_channel_manager.create_le_credit_based_channel(
connection=connection, spec=spec
)
@overload
def create_l2cap_server(
self,
spec: l2cap.ClassicChannelSpec,
handler: Optional[Callable[[l2cap.ClassicChannel], Any]] = None,
) -> l2cap.ClassicChannelServer:
...
@overload
def create_l2cap_server(
self,
spec: l2cap.LeCreditBasedChannelSpec,
handler: Optional[Callable[[l2cap.LeCreditBasedChannel], Any]] = None,
) -> l2cap.LeCreditBasedChannelServer:
...
def create_l2cap_server(
self,
spec: Union[l2cap.ClassicChannelSpec, l2cap.LeCreditBasedChannelSpec],
handler: Union[
Callable[[l2cap.ClassicChannel], Any],
Callable[[l2cap.LeCreditBasedChannel], Any],
None,
] = None,
) -> Union[l2cap.ClassicChannelServer, l2cap.LeCreditBasedChannelServer]:
if isinstance(spec, l2cap.ClassicChannelSpec):
return self.l2cap_channel_manager.create_classic_server(
spec=spec,
handler=cast(Callable[[l2cap.ClassicChannel], Any], handler),
)
elif isinstance(spec, l2cap.LeCreditBasedChannelSpec):
return self.l2cap_channel_manager.create_le_credit_based_server(
handler=cast(Callable[[l2cap.LeCreditBasedChannel], Any], handler),
spec=spec,
)
else:
raise ValueError(f'Unexpected mode {spec}')
def send_l2cap_pdu(self, connection_handle: int, cid: int, pdu: bytes) -> None:
self.host.send_l2cap_pdu(connection_handle, cid, pdu)
+76 -219
View File
@@ -17,7 +17,6 @@
# -----------------------------------------------------------------------------
from __future__ import annotations
import asyncio
import dataclasses
import enum
import logging
import struct
@@ -39,7 +38,6 @@ from typing import (
TYPE_CHECKING,
)
from .utils import deprecated
from .colors import color
from .core import BT_CENTRAL_ROLE, InvalidStateError, ProtocolError
from .hci import (
@@ -169,34 +167,6 @@ L2CAP_MTU_CONFIGURATION_PARAMETER_TYPE = 0x01
# pylint: disable=invalid-name
@dataclasses.dataclass
class ClassicChannelSpec:
psm: Optional[int] = None
mtu: int = L2CAP_MIN_BR_EDR_MTU
@dataclasses.dataclass
class LeCreditBasedChannelSpec:
psm: Optional[int] = None
mtu: int = L2CAP_LE_CREDIT_BASED_CONNECTION_DEFAULT_MTU
mps: int = L2CAP_LE_CREDIT_BASED_CONNECTION_DEFAULT_MPS
max_credits: int = L2CAP_LE_CREDIT_BASED_CONNECTION_DEFAULT_INITIAL_CREDITS
def __post_init__(self):
if (
self.max_credits < 1
or self.max_credits > L2CAP_LE_CREDIT_BASED_CONNECTION_MAX_CREDITS
):
raise ValueError('max credits out of range')
if self.mtu < L2CAP_LE_CREDIT_BASED_CONNECTION_MIN_MTU:
raise ValueError('MTU too small')
if (
self.mps < L2CAP_LE_CREDIT_BASED_CONNECTION_MIN_MPS
or self.mps > L2CAP_LE_CREDIT_BASED_CONNECTION_MAX_MPS
):
raise ValueError('MPS out of range')
class L2CAP_PDU:
'''
See Bluetooth spec @ Vol 3, Part A - 3 DATA PACKET FORMAT
@@ -706,7 +676,7 @@ class L2CAP_LE_Flow_Control_Credit(L2CAP_Control_Frame):
# -----------------------------------------------------------------------------
class ClassicChannel(EventEmitter):
class Channel(EventEmitter):
class State(enum.IntEnum):
# States
CLOSED = 0x00
@@ -1020,7 +990,7 @@ class ClassicChannel(EventEmitter):
# -----------------------------------------------------------------------------
class LeCreditBasedChannel(EventEmitter):
class LeConnectionOrientedChannel(EventEmitter):
"""
LE Credit-based Connection Oriented Channel
"""
@@ -1034,7 +1004,7 @@ class LeCreditBasedChannel(EventEmitter):
CONNECTION_ERROR = 5
out_queue: Deque[bytes]
connection_result: Optional[asyncio.Future[LeCreditBasedChannel]]
connection_result: Optional[asyncio.Future[LeConnectionOrientedChannel]]
disconnection_result: Optional[asyncio.Future[None]]
out_sdu: Optional[bytes]
state: State
@@ -1101,7 +1071,7 @@ class LeCreditBasedChannel(EventEmitter):
def send_control_frame(self, frame: L2CAP_Control_Frame) -> None:
self.manager.send_control_frame(self.connection, L2CAP_LE_SIGNALING_CID, frame)
async def connect(self) -> LeCreditBasedChannel:
async def connect(self) -> LeConnectionOrientedChannel:
# Check that we're in the right state
if self.state != self.State.INIT:
raise InvalidStateError('not in a connectable state')
@@ -1372,67 +1342,15 @@ class LeCreditBasedChannel(EventEmitter):
)
# -----------------------------------------------------------------------------
class ClassicChannelServer(EventEmitter):
def __init__(
self,
manager: ChannelManager,
psm: int,
handler: Optional[Callable[[ClassicChannel], Any]],
mtu: int,
) -> None:
super().__init__()
self.manager = manager
self.handler = handler
self.psm = psm
self.mtu = mtu
def on_connection(self, channel: ClassicChannel) -> None:
self.emit('connection', channel)
if self.handler:
self.handler(channel)
def close(self) -> None:
if self.psm in self.manager.servers:
del self.manager.servers[self.psm]
# -----------------------------------------------------------------------------
class LeCreditBasedChannelServer(EventEmitter):
def __init__(
self,
manager: ChannelManager,
psm: int,
handler: Optional[Callable[[LeCreditBasedChannel], Any]],
max_credits: int,
mtu: int,
mps: int,
) -> None:
super().__init__()
self.manager = manager
self.handler = handler
self.psm = psm
self.max_credits = max_credits
self.mtu = mtu
self.mps = mps
def on_connection(self, channel: LeCreditBasedChannel) -> None:
self.emit('connection', channel)
if self.handler:
self.handler(channel)
def close(self) -> None:
if self.psm in self.manager.le_coc_servers:
del self.manager.le_coc_servers[self.psm]
# -----------------------------------------------------------------------------
class ChannelManager:
identifiers: Dict[int, int]
channels: Dict[int, Dict[int, Union[ClassicChannel, LeCreditBasedChannel]]]
servers: Dict[int, ClassicChannelServer]
le_coc_channels: Dict[int, Dict[int, LeCreditBasedChannel]]
le_coc_servers: Dict[int, LeCreditBasedChannelServer]
channels: Dict[int, Dict[int, Union[Channel, LeConnectionOrientedChannel]]]
servers: Dict[int, Callable[[Channel], Any]]
le_coc_channels: Dict[int, Dict[int, LeConnectionOrientedChannel]]
le_coc_servers: Dict[
int, Tuple[Callable[[LeConnectionOrientedChannel], Any], int, int, int]
]
le_coc_requests: Dict[int, L2CAP_LE_Credit_Based_Connection_Request]
fixed_channels: Dict[int, Optional[Callable[[int, bytes], Any]]]
_host: Optional[Host]
@@ -1511,6 +1429,21 @@ class ChannelManager:
raise RuntimeError('no free CID')
@staticmethod
def check_le_coc_parameters(max_credits: int, mtu: int, mps: int) -> None:
if (
max_credits < 1
or max_credits > L2CAP_LE_CREDIT_BASED_CONNECTION_MAX_CREDITS
):
raise ValueError('max credits out of range')
if mtu < L2CAP_LE_CREDIT_BASED_CONNECTION_MIN_MTU:
raise ValueError('MTU too small')
if (
mps < L2CAP_LE_CREDIT_BASED_CONNECTION_MIN_MPS
or mps > L2CAP_LE_CREDIT_BASED_CONNECTION_MAX_MPS
):
raise ValueError('MPS out of range')
def next_identifier(self, connection: Connection) -> int:
identifier = (self.identifiers.setdefault(connection.handle, 0) + 1) % 256
self.identifiers[connection.handle] = identifier
@@ -1525,22 +1458,8 @@ class ChannelManager:
if cid in self.fixed_channels:
del self.fixed_channels[cid]
@deprecated("Please use create_classic_channel_server")
def register_server(
self,
psm: int,
server: Callable[[ClassicChannel], Any],
) -> int:
return self.create_classic_server(
handler=server, spec=ClassicChannelSpec(psm=psm)
).psm
def create_classic_server(
self,
spec: ClassicChannelSpec,
handler: Optional[Callable[[ClassicChannel], Any]] = None,
) -> ClassicChannelServer:
if spec.psm is None:
def register_server(self, psm: int, server: Callable[[Channel], Any]) -> int:
if psm == 0:
# Find a free PSM
for candidate in range(
L2CAP_PSM_DYNAMIC_RANGE_START, L2CAP_PSM_DYNAMIC_RANGE_END + 1, 2
@@ -1549,75 +1468,62 @@ class ChannelManager:
continue
if candidate in self.servers:
continue
spec.psm = candidate
psm = candidate
break
else:
raise InvalidStateError('no free PSM')
else:
# Check that the PSM isn't already in use
if spec.psm in self.servers:
if psm in self.servers:
raise ValueError('PSM already in use')
# Check that the PSM is valid
if spec.psm % 2 == 0:
if psm % 2 == 0:
raise ValueError('invalid PSM (not odd)')
check = spec.psm >> 8
check = psm >> 8
while check:
if check % 2 != 0:
raise ValueError('invalid PSM')
check >>= 8
self.servers[spec.psm] = ClassicChannelServer(self, spec.psm, handler, spec.mtu)
self.servers[psm] = server
return self.servers[spec.psm]
return psm
@deprecated("Please use create_le_credit_based_server()")
def register_le_coc_server(
self,
psm: int,
server: Callable[[LeCreditBasedChannel], Any],
max_credits: int,
mtu: int,
mps: int,
server: Callable[[LeConnectionOrientedChannel], Any],
max_credits: int = L2CAP_LE_CREDIT_BASED_CONNECTION_DEFAULT_INITIAL_CREDITS,
mtu: int = L2CAP_LE_CREDIT_BASED_CONNECTION_DEFAULT_MTU,
mps: int = L2CAP_LE_CREDIT_BASED_CONNECTION_DEFAULT_MPS,
) -> int:
return self.create_le_credit_based_server(
spec=LeCreditBasedChannelSpec(
psm=None if psm == 0 else psm, mtu=mtu, mps=mps, max_credits=max_credits
),
handler=server,
).psm
self.check_le_coc_parameters(max_credits, mtu, mps)
def create_le_credit_based_server(
self,
spec: LeCreditBasedChannelSpec,
handler: Optional[Callable[[LeCreditBasedChannel], Any]] = None,
) -> LeCreditBasedChannelServer:
if spec.psm is None:
if psm == 0:
# Find a free PSM
for candidate in range(
L2CAP_LE_PSM_DYNAMIC_RANGE_START, L2CAP_LE_PSM_DYNAMIC_RANGE_END + 1
):
if candidate in self.le_coc_servers:
continue
spec.psm = candidate
psm = candidate
break
else:
raise InvalidStateError('no free PSM')
else:
# Check that the PSM isn't already in use
if spec.psm in self.le_coc_servers:
if psm in self.le_coc_servers:
raise ValueError('PSM already in use')
self.le_coc_servers[spec.psm] = LeCreditBasedChannelServer(
self,
spec.psm,
handler,
max_credits=spec.max_credits,
mtu=spec.mtu,
mps=spec.mps,
self.le_coc_servers[psm] = (
server,
max_credits or L2CAP_LE_CREDIT_BASED_CONNECTION_DEFAULT_INITIAL_CREDITS,
mtu or L2CAP_LE_CREDIT_BASED_CONNECTION_DEFAULT_MTU,
mps or L2CAP_LE_CREDIT_BASED_CONNECTION_DEFAULT_MPS,
)
return self.le_coc_servers[spec.psm]
return psm
def on_disconnection(self, connection_handle: int, _reason: int) -> None:
logger.debug(f'disconnection from {connection_handle}, cleaning up channels')
@@ -1744,13 +1650,13 @@ class ChannelManager:
logger.debug(
f'creating server channel with cid={source_cid} for psm {request.psm}'
)
channel = ClassicChannel(
self, connection, cid, request.psm, source_cid, server.mtu
channel = Channel(
self, connection, cid, request.psm, source_cid, L2CAP_MIN_BR_EDR_MTU
)
connection_channels[source_cid] = channel
# Notify
server.on_connection(channel)
server(channel)
channel.on_connection_request(request)
else:
logger.warning(
@@ -1972,7 +1878,7 @@ class ChannelManager:
self, connection: Connection, cid: int, request
) -> None:
if request.le_psm in self.le_coc_servers:
server = self.le_coc_servers[request.le_psm]
(server, max_credits, mtu, mps) = self.le_coc_servers[request.le_psm]
# Check that the CID isn't already used
le_connection_channels = self.le_coc_channels.setdefault(
@@ -1986,8 +1892,8 @@ class ChannelManager:
L2CAP_LE_Credit_Based_Connection_Response(
identifier=request.identifier,
destination_cid=0,
mtu=server.mtu,
mps=server.mps,
mtu=mtu,
mps=mps,
initial_credits=0,
# pylint: disable=line-too-long
result=L2CAP_LE_Credit_Based_Connection_Response.CONNECTION_REFUSED_SOURCE_CID_ALREADY_ALLOCATED,
@@ -2005,8 +1911,8 @@ class ChannelManager:
L2CAP_LE_Credit_Based_Connection_Response(
identifier=request.identifier,
destination_cid=0,
mtu=server.mtu,
mps=server.mps,
mtu=mtu,
mps=mps,
initial_credits=0,
# pylint: disable=line-too-long
result=L2CAP_LE_Credit_Based_Connection_Response.CONNECTION_REFUSED_NO_RESOURCES_AVAILABLE,
@@ -2019,18 +1925,18 @@ class ChannelManager:
f'creating LE CoC server channel with cid={source_cid} for psm '
f'{request.le_psm}'
)
channel = LeCreditBasedChannel(
channel = LeConnectionOrientedChannel(
self,
connection,
request.le_psm,
source_cid,
request.source_cid,
server.mtu,
server.mps,
mtu,
mps,
request.initial_credits,
request.mtu,
request.mps,
server.max_credits,
max_credits,
True,
)
connection_channels[source_cid] = channel
@@ -2043,16 +1949,16 @@ class ChannelManager:
L2CAP_LE_Credit_Based_Connection_Response(
identifier=request.identifier,
destination_cid=source_cid,
mtu=server.mtu,
mps=server.mps,
initial_credits=server.max_credits,
mtu=mtu,
mps=mps,
initial_credits=max_credits,
# pylint: disable=line-too-long
result=L2CAP_LE_Credit_Based_Connection_Response.CONNECTION_SUCCESSFUL,
),
)
# Notify
server.on_connection(channel)
server(channel)
else:
logger.info(
f'No LE server for connection 0x{connection.handle:04X} '
@@ -2107,51 +2013,37 @@ class ChannelManager:
channel.on_credits(credit.credits)
def on_channel_closed(self, channel: ClassicChannel) -> None:
def on_channel_closed(self, channel: Channel) -> None:
connection_channels = self.channels.get(channel.connection.handle)
if connection_channels:
if channel.source_cid in connection_channels:
del connection_channels[channel.source_cid]
@deprecated("Please use create_le_credit_based_channel()")
async def open_le_coc(
self, connection: Connection, psm: int, max_credits: int, mtu: int, mps: int
) -> LeCreditBasedChannel:
return await self.create_le_credit_based_channel(
connection=connection,
spec=LeCreditBasedChannelSpec(
psm=psm, max_credits=max_credits, mtu=mtu, mps=mps
),
)
) -> LeConnectionOrientedChannel:
self.check_le_coc_parameters(max_credits, mtu, mps)
async def create_le_credit_based_channel(
self,
connection: Connection,
spec: LeCreditBasedChannelSpec,
) -> LeCreditBasedChannel:
# Find a free CID for the new channel
connection_channels = self.channels.setdefault(connection.handle, {})
source_cid = self.find_free_le_cid(connection_channels)
if source_cid is None: # Should never happen!
raise RuntimeError('all CIDs already in use')
if spec.psm is None:
raise ValueError('PSM cannot be None')
# Create the channel
logger.debug(f'creating coc channel with cid={source_cid} for psm {spec.psm}')
channel = LeCreditBasedChannel(
logger.debug(f'creating coc channel with cid={source_cid} for psm {psm}')
channel = LeConnectionOrientedChannel(
manager=self,
connection=connection,
le_psm=spec.psm,
le_psm=psm,
source_cid=source_cid,
destination_cid=0,
mtu=spec.mtu,
mps=spec.mps,
mtu=mtu,
mps=mps,
credits=0,
peer_mtu=0,
peer_mps=0,
peer_credits=spec.max_credits,
peer_credits=max_credits,
connected=False,
)
connection_channels[source_cid] = channel
@@ -2170,15 +2062,7 @@ class ChannelManager:
return channel
@deprecated("Please use create_classic_channel()")
async def connect(self, connection: Connection, psm: int) -> ClassicChannel:
return await self.create_classic_channel(
connection=connection, spec=ClassicChannelSpec(psm=psm)
)
async def create_classic_channel(
self, connection: Connection, spec: ClassicChannelSpec
) -> ClassicChannel:
async def connect(self, connection: Connection, psm: int) -> Channel:
# NOTE: this implementation hard-codes BR/EDR
# Find a free CID for a new channel
@@ -2187,20 +2071,10 @@ class ChannelManager:
if source_cid is None: # Should never happen!
raise RuntimeError('all CIDs already in use')
if spec.psm is None:
raise ValueError('PSM cannot be None')
# Create the channel
logger.debug(
f'creating client channel with cid={source_cid} for psm {spec.psm}'
)
channel = ClassicChannel(
self,
connection,
L2CAP_SIGNALING_CID,
spec.psm,
source_cid,
spec.mtu,
logger.debug(f'creating client channel with cid={source_cid} for psm {psm}')
channel = Channel(
self, connection, L2CAP_SIGNALING_CID, psm, source_cid, L2CAP_MIN_BR_EDR_MTU
)
connection_channels[source_cid] = channel
@@ -2212,20 +2086,3 @@ class ChannelManager:
raise e
return channel
# -----------------------------------------------------------------------------
# Deprecated Classes
# -----------------------------------------------------------------------------
class Channel(ClassicChannel):
@deprecated("Please use ClassicChannel")
def __init__(self, *args, **kwargs) -> None:
super().__init__(*args, **kwargs)
class LeConnectionOrientedChannel(LeCreditBasedChannel):
@deprecated("Please use LeCreditBasedChannel")
def __init__(self, *args, **kwargs) -> None:
super().__init__(*args, **kwargs)
+3
View File
@@ -24,8 +24,10 @@ import grpc.aio
from .config import Config
from .device import PandoraDevice
from .asha import AshaService
from .host import HostService
from .security import SecurityService, SecurityStorageService
from pandora.asha_grpc_aio import add_ASHAServicer_to_server
from pandora.host_grpc_aio import add_HostServicer_to_server
from pandora.security_grpc_aio import (
add_SecurityServicer_to_server,
@@ -68,6 +70,7 @@ async def serve(
config.load_from_dict(bumble.config.get('server', {}))
# add Pandora services to the gRPC server.
add_ASHAServicer_to_server(AshaService(bumble.device), server)
add_HostServicer_to_server(
HostService(server, bumble.device, config), server
)
+96
View File
@@ -0,0 +1,96 @@
# Copyright 2022 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import asyncio
import grpc
import logging
from bumble.decoder import G722Decoder
from bumble.device import Connection, Device
from bumble.pandora import utils
from bumble.profiles import asha_service
from google.protobuf.empty_pb2 import Empty # pytype: disable=pyi-error
from pandora.asha_grpc_aio import ASHAServicer
from pandora.asha_pb2 import CaptureAudioRequest, CaptureAudioResponse, RegisterRequest
from typing import AsyncGenerator, Optional
class AshaService(ASHAServicer):
DECODE_FRAME_LENGTH = 80
device: Device
asha_service: Optional[asha_service.AshaService]
def __init__(self, device: Device) -> None:
self.log = utils.BumbleServerLoggerAdapter(
logging.getLogger(), {"service_name": "Asha", "device": device}
)
self.device = device
self.asha_service = None
@utils.rpc
async def Register(
self, request: RegisterRequest, context: grpc.ServicerContext
) -> Empty:
logging.info("Register")
if self.asha_service:
self.asha_service.capability = request.capability
self.asha_service.hisyncid = request.hisyncid
else:
self.asha_service = asha_service.AshaService(
request.capability, request.hisyncid, self.device
)
self.device.add_service(self.asha_service) # type: ignore[no-untyped-call]
return Empty()
@utils.rpc
async def CaptureAudio(
self, request: CaptureAudioRequest, context: grpc.ServicerContext
) -> AsyncGenerator[CaptureAudioResponse, None]:
connection_handle = int.from_bytes(request.connection.cookie.value, "big")
logging.info(f"CaptureAudioData connection_handle:{connection_handle}")
if not (connection := self.device.lookup_connection(connection_handle)):
raise RuntimeError(
f"Unknown connection for connection_handle:{connection_handle}"
)
decoder = G722Decoder() # type: ignore
queue: asyncio.Queue[bytes] = asyncio.Queue()
def on_data(asha_connection: Connection, data: bytes) -> None:
if asha_connection == connection:
queue.put_nowait(data)
self.asha_service.on("data", on_data) # type: ignore
try:
while data := await queue.get():
output_bytes = bytearray()
# First byte is sequence number, last 160 bytes are audio payload.
audio_payload = data[1:]
data_length = int(len(audio_payload) / AshaService.DECODE_FRAME_LENGTH)
for i in range(0, data_length):
input_data = audio_payload[
i
* AshaService.DECODE_FRAME_LENGTH : i
* AshaService.DECODE_FRAME_LENGTH
+ AshaService.DECODE_FRAME_LENGTH
]
decoded_data = decoder.decode_frame(input_data)
output_bytes.extend(decoded_data)
yield CaptureAudioResponse(data=bytes(output_bytes))
finally:
self.asha_service.remove_listener("data", on_data) # type: ignore
+13 -10
View File
@@ -450,18 +450,21 @@ class SecurityService(SecurityServicer):
'security_request': pair,
}
with contextlib.closing(EventWatcher()) as watcher:
# register event handlers
for event, listener in listeners.items():
watcher.on(connection, event, listener)
# register event handlers
for event, listener in listeners.items():
connection.on(event, listener)
# security level already reached
if self.reached_security_level(connection, level):
return WaitSecurityResponse(success=empty_pb2.Empty())
# security level already reached
if self.reached_security_level(connection, level):
return WaitSecurityResponse(success=empty_pb2.Empty())
self.log.debug('Wait for security...')
kwargs = {}
kwargs[await wait_for_security] = empty_pb2.Empty()
self.log.debug('Wait for security...')
kwargs = {}
kwargs[await wait_for_security] = empty_pb2.Empty()
# remove event handlers
for event, listener in listeners.items():
connection.remove_listener(event, listener) # type: ignore
# wait for `authenticate` to finish if any
if authenticate_task is not None:
+57 -44
View File
@@ -32,6 +32,7 @@ from ..gatt import (
Characteristic,
CharacteristicValue,
)
from ..l2cap import Channel
from ..utils import AsyncRunner
# -----------------------------------------------------------------------------
@@ -52,46 +53,48 @@ class AshaService(TemplateService):
SUPPORTED_CODEC_ID = [0x02, 0x01] # Codec IDs [G.722 at 16 kHz]
RENDER_DELAY = [00, 00]
def __init__(self, capability: int, hisyncid: List[int], device: Device, psm=0):
def __init__(
self, capability: int, hisyncid: List[int], device: Device, psm: int = 0
) -> None:
self.hisyncid = hisyncid
self.capability = capability # Device Capabilities [Left, Monaural]
self.device = device
self.audio_out_data = b''
self.psm = psm # a non-zero psm is mainly for testing purpose
self.audio_out_data = b""
self.psm: int = psm # a non-zero psm is mainly for testing purpose
# Handler for volume control
def on_volume_write(connection, value):
logger.info(f'--- VOLUME Write:{value[0]}')
self.emit('volume', connection, value[0])
def on_volume_write(connection: Connection, value: bytes) -> None:
logger.info(f"--- VOLUME Write:{value[0]}")
self.emit("volume", connection, value[0])
# Handler for audio control commands
def on_audio_control_point_write(connection: Connection, value):
logger.info(f'--- AUDIO CONTROL POINT Write:{value.hex()}')
def on_audio_control_point_write(connection: Connection, value: bytes) -> None:
logger.info(f"--- AUDIO CONTROL POINT Write:{value.hex()}")
opcode = value[0]
if opcode == AshaService.OPCODE_START:
# Start
audio_type = ('Unknown', 'Ringtone', 'Phone Call', 'Media')[value[2]]
audio_type = ("Unknown", "Ringtone", "Phone Call", "Media")[value[2]]
logger.info(
f'### START: codec={value[1]}, '
f'audio_type={audio_type}, '
f'volume={value[3]}, '
f'otherstate={value[4]}'
f"### START: codec={value[1]}, "
f"audio_type={audio_type}, "
f"volume={value[3]}, "
f"otherstate={value[4]}"
)
self.emit(
'start',
"start",
connection,
{
'codec': value[1],
'audiotype': value[2],
'volume': value[3],
'otherstate': value[4],
"codec": value[1],
"audiotype": value[2],
"volume": value[3],
"otherstate": value[4],
},
)
elif opcode == AshaService.OPCODE_STOP:
logger.info('### STOP')
self.emit('stop', connection)
logger.info("### STOP")
self.emit("stop", connection)
elif opcode == AshaService.OPCODE_STATUS:
logger.info(f'### STATUS: connected={value[1]}')
logger.info(f"### STATUS: connected={value[1]}")
# OPCODE_STATUS does not need audio status point update
if opcode != AshaService.OPCODE_STATUS:
@@ -101,49 +104,59 @@ class AshaService(TemplateService):
)
)
def on_read_only_properties_read(connection: Connection) -> bytes:
value = (
bytes(
[
AshaService.PROTOCOL_VERSION, # Version
self.capability,
]
)
+ bytes(self.hisyncid)
+ bytes(AshaService.FEATURE_MAP)
+ bytes(AshaService.RENDER_DELAY)
+ bytes(AshaService.RESERVED_FOR_FUTURE_USE)
+ bytes(AshaService.SUPPORTED_CODEC_ID)
)
self.emit("read_only_properties", connection, value)
return value
def on_le_psm_out_read(connection: Connection) -> bytes:
self.emit("le_psm_out", connection, self.psm)
return struct.pack("<H", self.psm)
self.read_only_properties_characteristic = Characteristic(
GATT_ASHA_READ_ONLY_PROPERTIES_CHARACTERISTIC,
Characteristic.Properties.READ,
Characteristic.READ,
Characteristic.READABLE,
bytes(
[
AshaService.PROTOCOL_VERSION, # Version
self.capability,
]
)
+ bytes(self.hisyncid)
+ bytes(AshaService.FEATURE_MAP)
+ bytes(AshaService.RENDER_DELAY)
+ bytes(AshaService.RESERVED_FOR_FUTURE_USE)
+ bytes(AshaService.SUPPORTED_CODEC_ID),
CharacteristicValue(read=on_read_only_properties_read),
)
self.audio_control_point_characteristic = Characteristic(
GATT_ASHA_AUDIO_CONTROL_POINT_CHARACTERISTIC,
Characteristic.Properties.WRITE
| Characteristic.Properties.WRITE_WITHOUT_RESPONSE,
Characteristic.WRITE | Characteristic.WRITE_WITHOUT_RESPONSE,
Characteristic.WRITEABLE,
CharacteristicValue(write=on_audio_control_point_write),
)
self.audio_status_characteristic = Characteristic(
GATT_ASHA_AUDIO_STATUS_CHARACTERISTIC,
Characteristic.Properties.READ | Characteristic.Properties.NOTIFY,
Characteristic.READ | Characteristic.NOTIFY,
Characteristic.READABLE,
bytes([0]),
)
self.volume_characteristic = Characteristic(
GATT_ASHA_VOLUME_CHARACTERISTIC,
Characteristic.Properties.WRITE_WITHOUT_RESPONSE,
Characteristic.WRITE_WITHOUT_RESPONSE,
Characteristic.WRITEABLE,
CharacteristicValue(write=on_volume_write),
)
# Register an L2CAP CoC server
def on_coc(channel):
def on_data(data):
logging.debug(f'<<< data received:{data}')
def on_coc(channel: Channel) -> None:
def on_data(data: bytes) -> None:
logging.debug(f"data received:{data.hex()}")
self.emit('data', channel.connection, data)
self.emit("data", channel.connection, data)
self.audio_out_data += data
channel.sink = on_data
@@ -152,9 +165,9 @@ class AshaService(TemplateService):
self.psm = self.device.register_l2cap_channel_server(self.psm, on_coc, 8)
self.le_psm_out_characteristic = Characteristic(
GATT_ASHA_LE_PSM_OUT_CHARACTERISTIC,
Characteristic.Properties.READ,
Characteristic.READ,
Characteristic.READABLE,
struct.pack('<H', self.psm),
CharacteristicValue(read=on_le_psm_out_read),
)
characteristics = [
@@ -167,7 +180,7 @@ class AshaService(TemplateService):
super().__init__(characteristics)
def get_advertising_data(self):
def get_advertising_data(self) -> bytes:
# Advertisement only uses 4 least significant bytes of the HiSyncId.
return bytes(
AdvertisingData(
+4 -4
View File
@@ -674,7 +674,7 @@ class Multiplexer(EventEmitter):
acceptor: Optional[Callable[[int], bool]]
dlcs: Dict[int, DLC]
def __init__(self, l2cap_channel: l2cap.ClassicChannel, role: Role) -> None:
def __init__(self, l2cap_channel: l2cap.Channel, role: Role) -> None:
super().__init__()
self.role = role
self.l2cap_channel = l2cap_channel
@@ -887,7 +887,7 @@ class Multiplexer(EventEmitter):
# -----------------------------------------------------------------------------
class Client:
multiplexer: Optional[Multiplexer]
l2cap_channel: Optional[l2cap.ClassicChannel]
l2cap_channel: Optional[l2cap.Channel]
def __init__(self, device: Device, connection: Connection) -> None:
self.device = device
@@ -960,11 +960,11 @@ class Server(EventEmitter):
self.acceptors[channel] = acceptor
return channel
def on_connection(self, l2cap_channel: l2cap.ClassicChannel) -> None:
def on_connection(self, l2cap_channel: l2cap.Channel) -> None:
logger.debug(f'+++ new L2CAP connection: {l2cap_channel}')
l2cap_channel.on('open', lambda: self.on_l2cap_channel_open(l2cap_channel))
def on_l2cap_channel_open(self, l2cap_channel: l2cap.ClassicChannel) -> None:
def on_l2cap_channel_open(self, l2cap_channel: l2cap.Channel) -> None:
logger.debug(f'$$$ L2CAP channel open: {l2cap_channel}')
# Create a new multiplexer for the channel
+2 -2
View File
@@ -758,7 +758,7 @@ class SDP_ServiceSearchAttributeResponse(SDP_PDU):
# -----------------------------------------------------------------------------
class Client:
channel: Optional[l2cap.ClassicChannel]
channel: Optional[l2cap.Channel]
def __init__(self, device: Device) -> None:
self.device = device
@@ -921,7 +921,7 @@ class Client:
# -----------------------------------------------------------------------------
class Server:
CONTINUATION_STATE = bytes([0x01, 0x43])
channel: Optional[l2cap.ClassicChannel]
channel: Optional[l2cap.Channel]
Service = NewType('Service', List[ServiceAttribute])
service_records: Dict[int, Service]
current_response: Union[None, bytes, Tuple[int, List[int]]]
-17
View File
@@ -21,7 +21,6 @@ import logging
import traceback
import collections
import sys
import warnings
from typing import (
Awaitable,
Set,
@@ -428,19 +427,3 @@ def wrap_async(function):
Wraps the provided function in an async function.
"""
return partial(async_call, function)
def deprecated(msg: str):
"""
Throw deprecation warning before execution
"""
def wrapper(function):
@wraps(function)
def inner(*args, **kwargs):
warnings.warn(msg, DeprecationWarning)
return function(*args, **kwargs)
return inner
return wrapper
+3 -1
View File
@@ -70,7 +70,9 @@ async fn main() -> PyResult<()> {
let mut seen_adv_cache = seen_adv_clone.lock().unwrap();
let expiry_duration = time::Duration::from_secs(cli.dedup_expiry_secs);
let advs_from_addr = seen_adv_cache.entry(addr_bytes).or_default();
let advs_from_addr = seen_adv_cache
.entry(addr_bytes)
.or_insert_with(collections::HashMap::new);
// we expect cache hits to be the norm, so we do a separate lookup to avoid cloning
// on every lookup with entry()
let show = if let Some(prev) = advs_from_addr.get_mut(&data_units) {
+5 -2
View File
@@ -143,7 +143,10 @@ pub(crate) fn probe(verbose: bool) -> anyhow::Result<()> {
);
if let Some(s) = serial {
println!("{:26}{}", " Serial:".green(), s);
device_serials_by_id.entry(device_id).or_default().insert(s);
device_serials_by_id
.entry(device_id)
.or_insert(HashSet::new())
.insert(s);
}
if let Some(m) = mfg {
println!("{:26}{}", " Manufacturer:".green(), m);
@@ -311,7 +314,7 @@ impl ClassInfo {
self.protocol,
self.protocol_name()
.map(|s| format!(" [{}]", s))
.unwrap_or_default()
.unwrap_or_else(String::new)
)
}
}
File diff suppressed because it is too large Load Diff
-4
View File
@@ -91,13 +91,9 @@ development =
mypy == 1.5.0
nox >= 2022
pylint == 2.15.8
pyyaml >= 6.0
types-appdirs >= 1.4.3
types-invoke >= 1.7.3
types-protobuf >= 4.21.0
avatar =
pandora-avatar == 0.0.5
rootcanal == 1.3.0 ; python_version>='3.10'
documentation =
mkdocs >= 1.4.0
mkdocs-material >= 8.5.6
+2 -4
View File
@@ -45,14 +45,12 @@ def test_messages():
]
message = Get_Capabilities_Response(capabilities)
parsed = Message.create(
AVDTP_GET_CAPABILITIES, Message.MessageType.RESPONSE_ACCEPT, message.payload
AVDTP_GET_CAPABILITIES, Message.RESPONSE_ACCEPT, message.payload
)
assert message.payload == parsed.payload
message = Set_Configuration_Command(3, 4, capabilities)
parsed = Message.create(
AVDTP_SET_CONFIGURATION, Message.MessageType.COMMAND, message.payload
)
parsed = Message.create(AVDTP_SET_CONFIGURATION, Message.COMMAND, message.payload)
assert message.payload == parsed.payload
+13 -13
View File
@@ -14,25 +14,25 @@
# -----------------------------------------------------------------------------
# This script generates a python-syntax list of dictionary entries for the
# company IDs listed at:
# https://bitbucket.org/bluetooth-SIG/public/src/main/assigned_numbers/company_identifiers/company_identifiers.yaml
# The input to this script is the YAML file that can be obtained at that URL
# company IDs listed at: https://www.bluetooth.com/specifications/assigned-numbers/company-identifiers/
# The input to this script is the CSV file that can be obtained at that URL
# -----------------------------------------------------------------------------
# -----------------------------------------------------------------------------
# Imports
# -----------------------------------------------------------------------------
import sys
import yaml
import csv
# -----------------------------------------------------------------------------
with open(sys.argv[1], "r") as yaml_file:
root = yaml.safe_load(yaml_file)
companies = {}
for company in root["company_identifiers"]:
companies[company["value"]] = company["name"]
with open(sys.argv[1], newline='') as csvfile:
reader = csv.reader(csvfile, delimiter=',', quotechar='"')
lines = []
for row in reader:
if len(row) == 3 and row[1].startswith('0x'):
company_id = row[1]
company_name = row[2]
escaped_company_name = company_name.replace('"', '\\"')
lines.append(f' {company_id}: "{escaped_company_name}"')
for company_id in sorted(companies.keys()):
company_name = companies[company_id]
escaped_company_name = company_name.replace('"', '\\"')
print(f' 0x{company_id:04X}: "{escaped_company_name}",')
print(',\n'.join(reversed(lines)))