Merge pull request #760 from zxzxwu/ipv6

Enhance transports
This commit is contained in:
zxzxwu
2025-08-21 14:31:50 +08:00
committed by GitHub
10 changed files with 268 additions and 51 deletions

View File

@@ -185,12 +185,18 @@ async def _open_transport(scheme: str, spec: Optional[str]) -> Transport:
return await open_android_netsim_transport(spec)
if scheme == 'unix':
if scheme in ('unix', 'unix-client'):
from bumble.transport.unix import open_unix_client_transport
assert spec
return await open_unix_client_transport(spec)
if scheme == 'unix-server':
from bumble.transport.unix import open_unix_server_transport
assert spec
return await open_unix_server_transport(spec)
raise TransportSpecError('unknown transport scheme')

View File

@@ -77,21 +77,17 @@ async def open_android_emulator_transport(spec: Optional[str]) -> Transport:
# Parse the parameters
mode = 'host'
server_host = 'localhost'
server_port = '8554'
server_address = 'localhost:8554'
if spec:
params = spec.split(',')
for param in params:
if param.startswith('mode='):
mode = param.split('=')[1]
elif ':' in param:
server_host, server_port = param.split(':')
else:
raise TransportSpecError('invalid parameter')
server_address = param
# Connect to the gRPC server
server_address = f'{server_host}:{server_port}'
logger.debug(f'connecting to gRPC server at {server_address}')
logger.debug('connecting to gRPC server at %s', server_address)
channel = grpc.aio.insecure_channel(server_address)
service: Union[EmulatedBluetoothServiceStub, VhciForwardingServiceStub]

View File

@@ -145,8 +145,6 @@ def publish_grpc_port(grpc_port: int, instance_number: int) -> bool:
async def open_android_netsim_controller_transport(
server_host: Optional[str], server_port: int, options: dict[str, str]
) -> Transport:
if not server_port:
raise TransportSpecError('invalid port')
if server_host == '_' or not server_host:
server_host = 'localhost'
@@ -168,14 +166,16 @@ async def open_android_netsim_controller_transport(
await self.pump_loop()
except asyncio.CancelledError:
logger.debug('Pump task canceled')
self.done.set_result(None)
if not self.done.done():
self.done.set_result(None)
async def pump_loop(self):
while True:
request = await self.context.read()
if request == grpc.aio.EOF:
logger.debug('End of request stream')
self.done.set_result(None)
if not self.done.done():
self.done.set_result(None)
return
# If we're not initialized yet, wait for a init packet.
@@ -220,6 +220,8 @@ async def open_android_netsim_controller_transport(
async def wait_for_termination(self):
await self.done
server_address = f'{server_host}:{server_port}'
class Server(PacketStreamerServicer, ParserSource):
def __init__(self):
PacketStreamerServicer.__init__(self)
@@ -230,8 +232,8 @@ async def open_android_netsim_controller_transport(
# a server listening on that port, we get an exception.
self.grpc_server = grpc.aio.server(options=(('grpc.so_reuseport', 0),))
add_PacketStreamerServicer_to_server(self, self.grpc_server)
self.grpc_server.add_insecure_port(f'{server_host}:{server_port}')
logger.debug(f'gRPC server listening on {server_host}:{server_port}')
self.port = self.grpc_server.add_insecure_port(server_address)
logger.debug('gRPC server listening on %s', server_address)
async def start(self):
logger.debug('Starting gRPC server')
@@ -443,7 +445,7 @@ async def open_android_netsim_transport(spec: Optional[str]) -> Transport:
params = spec.split(',') if spec else []
if params and ':' in params[0]:
# Explicit <host>:<port>
host, port_str = params[0].split(':')
host, port_str = params[0].rsplit(':', maxsplit=1)
port = int(port_str)
params_offset = 1
else:

View File

@@ -41,7 +41,7 @@ async def open_tcp_client_transport(spec: str) -> Transport:
logger.debug(f'connection lost: {exc}')
self.on_transport_lost()
remote_host, remote_port = spec.split(':')
remote_host, remote_port = spec.rsplit(':', maxsplit=1)
tcp_transport, packet_source = await asyncio.get_running_loop().create_connection(
TcpPacketSource,
host=remote_host,

View File

@@ -29,13 +29,6 @@ logger = logging.getLogger(__name__)
# -----------------------------------------------------------------------------
# A pass-through function to ease mock testing.
async def _create_server(*args, **kw_args):
await asyncio.get_running_loop().create_server(*args, **kw_args)
async def open_tcp_server_transport(spec: str) -> Transport:
'''
Open a TCP server transport.
@@ -46,13 +39,15 @@ async def open_tcp_server_transport(spec: str) -> Transport:
Example: _:9001
'''
local_host, local_port = spec.split(':')
local_host, local_port = spec.rsplit(':', maxsplit=1)
return await _open_tcp_server_transport_impl(
host=local_host if local_host != '_' else None, port=int(local_port)
)
async def open_tcp_server_transport_with_socket(sock: socket.socket) -> Transport:
async def open_tcp_server_transport_with_socket(
sock: socket.socket,
) -> Transport:
'''
Open a TCP server transport with an existing socket.
@@ -63,8 +58,9 @@ async def open_tcp_server_transport_with_socket(sock: socket.socket) -> Transpor
async def _open_tcp_server_transport_impl(**kwargs) -> Transport:
class TcpServerTransport(Transport):
async def close(self):
await super().close()
def __init__(self, source, sink, server):
self.server = server
super().__init__(source, sink)
class TcpServerProtocol(asyncio.BaseProtocol):
def __init__(self, packet_source, packet_sink):
@@ -102,8 +98,8 @@ async def _open_tcp_server_transport_impl(**kwargs) -> Transport:
packet_source = StreamPacketSource()
packet_sink = TcpServerPacketSink()
await _create_server(
server = await asyncio.get_running_loop().create_server(
lambda: TcpServerProtocol(packet_source, packet_sink), **kwargs
)
return TcpServerTransport(packet_source, packet_sink)
return TcpServerTransport(packet_source, packet_sink, server)

View File

@@ -51,8 +51,8 @@ async def open_udp_transport(spec: str) -> Transport:
self.transport.close()
local, remote = spec.split(',')
local_host, local_port = local.split(':')
remote_host, remote_port = remote.split(':')
local_host, local_port = local.rsplit(':', maxsplit=1)
remote_host, remote_port = remote.rsplit(':', maxsplit=1)
(
udp_transport,
packet_source,

View File

@@ -54,3 +54,69 @@ async def open_unix_client_transport(spec: str) -> Transport:
packet_sink = StreamPacketSink(unix_transport)
return Transport(packet_source, packet_sink)
# -----------------------------------------------------------------------------
async def open_unix_server_transport(spec: str) -> Transport:
'''Open a UNIX socket server transport.
The parameter is the path of unix socket. For abstract socket, the first character
needs to be '@'.
Example:
* /tmp/hci.socket
* @hci_socket
'''
# For abstract socket, the first character should be null character.
if spec.startswith('@'):
spec = '\0' + spec[1:]
class UnixServerTransport(Transport):
def __init__(self, source, sink, server):
self.server = server
super().__init__(source, sink)
async def close(self):
await super().close()
class UnixServerProtocol(asyncio.BaseProtocol):
def __init__(self, packet_source, packet_sink):
self.packet_source = packet_source
self.packet_sink = packet_sink
# Called when a new connection is established
def connection_made(self, transport):
peer_name = transport.get_extra_info('peer_name')
logger.debug('connection from %s', peer_name)
self.packet_sink.transport = transport
# Called when the client is disconnected
def connection_lost(self, error):
logger.debug('connection lost: %s', error)
self.packet_sink.transport = None
def eof_received(self):
logger.debug('connection end')
self.packet_sink.transport = None
# Called when data is received on the socket
def data_received(self, data):
self.packet_source.data_received(data)
class UnixServerPacketSink:
def __init__(self):
self.transport = None
def on_packet(self, packet):
if self.transport:
self.transport.write(packet)
else:
logger.debug('no client, dropping packet')
packet_source = StreamPacketSource()
packet_sink = UnixServerPacketSink()
server = await asyncio.get_running_loop().create_unix_server(
lambda: UnixServerProtocol(packet_source, packet_sink), spec
)
return UnixServerTransport(packet_source, packet_sink, server)

View File

@@ -82,7 +82,7 @@ async def open_ws_server_transport(spec: str) -> Transport:
return
return await self.connection.send(packet)
local_host, local_port = spec.split(':')
local_host, local_port = spec.rsplit(':', maxsplit=1)
transport = WsServerTransport()
await transport.serve(local_host, local_port)
return transport

View File

@@ -16,8 +16,7 @@ import asyncio
import os
import pytest
import socket
import unittest
from unittest.mock import ANY, patch
from unittest import mock
from bumble.transport.tcp_server import (
open_tcp_server_transport,
@@ -25,28 +24,23 @@ from bumble.transport.tcp_server import (
)
class OpenTcpServerTransportTests(unittest.TestCase):
def setUp(self):
self.patcher = patch('bumble.transport.tcp_server._create_server')
self.mock_create_server = self.patcher.start()
async def test_open_with_spec():
with mock.patch.object(asyncio.get_running_loop(), 'create_server') as m:
await open_tcp_server_transport('localhost:32100')
m.assert_awaited_once_with(mock.ANY, host='localhost', port=32100)
def tearDown(self):
self.patcher.stop()
def test_open_with_spec(self):
asyncio.run(open_tcp_server_transport('localhost:32100'))
self.mock_create_server.assert_awaited_once_with(
ANY, host='localhost', port=32100
)
async def test_open_with_port_only_spec():
with mock.patch.object(asyncio.get_running_loop(), 'create_server') as m:
await open_tcp_server_transport('_:32100')
m.assert_awaited_once_with(mock.ANY, host=None, port=32100)
def test_open_with_port_only_spec(self):
asyncio.run(open_tcp_server_transport('_:32100'))
self.mock_create_server.assert_awaited_once_with(ANY, host=None, port=32100)
def test_open_with_socket(self):
async def test_open_with_socket():
with mock.patch.object(asyncio.get_running_loop(), 'create_server') as m:
with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as sock:
asyncio.run(open_tcp_server_transport_with_socket(sock=sock))
self.mock_create_server.assert_awaited_once_with(ANY, sock=sock)
await open_tcp_server_transport_with_socket(sock=sock)
m.assert_awaited_once_with(mock.ANY, sock=sock)
@pytest.mark.skipif(

View File

@@ -17,9 +17,41 @@
# -----------------------------------------------------------------------------
import random
import os
import socket
import sys
import pytest
from bumble import controller
from bumble import device
from bumble import hci
from bumble import link
from bumble import transport
from bumble.transport.common import PacketParser
# -----------------------------------------------------------------------------
def _make_controller_from_transport(transport: transport.Transport):
return controller.Controller(
name="server",
host_sink=transport.sink,
host_source=transport.source,
link=link.LocalLink(),
)
# -----------------------------------------------------------------------------
def _make_device_from_transport(
transport: transport.Transport, address: str = "11:22:33:44:55:66"
):
return device.Device.with_hci(
name="client",
address=hci.Address(address),
hci_sink=transport.sink,
hci_source=transport.source,
)
# -----------------------------------------------------------------------------
class Sink:
def __init__(self):
@@ -71,6 +103,131 @@ def test_parser_extensions():
assert len(sink.packets) == 1
# -----------------------------------------------------------------------------
@pytest.mark.parametrize(
"address,",
("127.0.0.1", "::1"),
)
async def test_tcp_connection(address):
server_transport = await transport.open_transport(f"tcp-server:{address}:0")
port = server_transport.server.sockets[0].getsockname()[1]
_make_controller_from_transport(server_transport)
client_transport = await transport.open_transport(f"tcp-client:{address}:{port}")
client_device = _make_device_from_transport(client_transport)
await client_device.power_on()
await client_transport.close()
await server_transport.close()
# -----------------------------------------------------------------------------
@pytest.mark.parametrize(
"address, family",
(("127.0.0.1", socket.AF_INET), ("::1", socket.AF_INET6)),
)
async def test_udp_connection(address, family):
# Pick empty ports
ports = []
for _ in range(2):
sock = socket.socket(family=family, type=socket.SOCK_DGRAM)
sock.bind((address, 0))
ports.append(sock.getsockname()[1])
sock.close()
server_transport = await transport.open_transport(
f"udp:{address}:{ports[0]},{address}:{ports[1]}"
)
_make_controller_from_transport(server_transport)
client_transport = await transport.open_transport(
f"udp:{address}:{ports[1]},{address}:{ports[0]}"
)
client_device = _make_device_from_transport(client_transport)
await client_device.power_on()
await client_transport.close()
await server_transport.close()
# -----------------------------------------------------------------------------
@pytest.mark.parametrize(
"server_address, client_address",
(
("127.0.0.1", "ws://127.0.0.1"),
("::1", "ws://[::1]"),
),
)
async def test_ws_connection(server_address, client_address):
server_transport = await transport.open_transport(f"ws-server:{server_address}:0")
port = server_transport.server.sockets[0].getsockname()[1]
_make_controller_from_transport(server_transport)
client_transport = await transport.open_transport(
f"ws-client:{client_address}:{port}"
)
client_device = _make_device_from_transport(client_transport)
await client_device.power_on()
await client_transport.close()
await server_transport.close()
# -----------------------------------------------------------------------------
@pytest.mark.skipif(
sys.platform != 'linux', reason='Unix socket is only fully supported on Linux'
)
async def test_unix_connection_file(tmpdir):
path = str(tmpdir / 'bumble.sock')
server_transport = await transport.open_transport(f"unix-server:{path}")
_make_controller_from_transport(server_transport)
client_transport = await transport.open_transport(f"unix-client:{path}")
client_device = _make_device_from_transport(client_transport)
await client_device.power_on()
await client_transport.close()
await server_transport.close()
# -----------------------------------------------------------------------------
@pytest.mark.skipif(
sys.platform != 'linux', reason='Unix socket is only fully supported on Linux'
)
async def test_unix_connection_abstract():
server_transport = await transport.open_transport("unix-server:@bumble.test.sock")
_make_controller_from_transport(server_transport)
client_transport = await transport.open_transport("unix-client:@bumble.test.sock")
client_device = _make_device_from_transport(client_transport)
await client_device.power_on()
await client_transport.close()
await server_transport.close()
# -----------------------------------------------------------------------------
@pytest.mark.parametrize(
"address,",
("127.0.0.1",),
)
async def test_android_netsim_connection(address):
controller_transport = await transport.open_transport(
"android-netsim:_:0,mode=controller"
)
port = controller_transport.source.port
_make_controller_from_transport(controller_transport)
client_transport = await transport.open_transport(
f"android-netsim:{address}:{port},mode=host"
)
client_device = _make_device_from_transport(client_transport)
await client_device.power_on()
await client_transport.close()
await controller_transport.close()
# -----------------------------------------------------------------------------
if __name__ == '__main__':
test_parser()