Merge pull request #760 from zxzxwu/ipv6

Enhance transports
This commit is contained in:
zxzxwu
2025-08-21 14:31:50 +08:00
committed by GitHub
10 changed files with 268 additions and 51 deletions

View File

@@ -185,12 +185,18 @@ async def _open_transport(scheme: str, spec: Optional[str]) -> Transport:
return await open_android_netsim_transport(spec) return await open_android_netsim_transport(spec)
if scheme == 'unix': if scheme in ('unix', 'unix-client'):
from bumble.transport.unix import open_unix_client_transport from bumble.transport.unix import open_unix_client_transport
assert spec assert spec
return await open_unix_client_transport(spec) return await open_unix_client_transport(spec)
if scheme == 'unix-server':
from bumble.transport.unix import open_unix_server_transport
assert spec
return await open_unix_server_transport(spec)
raise TransportSpecError('unknown transport scheme') raise TransportSpecError('unknown transport scheme')

View File

@@ -77,21 +77,17 @@ async def open_android_emulator_transport(spec: Optional[str]) -> Transport:
# Parse the parameters # Parse the parameters
mode = 'host' mode = 'host'
server_host = 'localhost' server_address = 'localhost:8554'
server_port = '8554'
if spec: if spec:
params = spec.split(',') params = spec.split(',')
for param in params: for param in params:
if param.startswith('mode='): if param.startswith('mode='):
mode = param.split('=')[1] mode = param.split('=')[1]
elif ':' in param:
server_host, server_port = param.split(':')
else: else:
raise TransportSpecError('invalid parameter') server_address = param
# Connect to the gRPC server # Connect to the gRPC server
server_address = f'{server_host}:{server_port}' logger.debug('connecting to gRPC server at %s', server_address)
logger.debug(f'connecting to gRPC server at {server_address}')
channel = grpc.aio.insecure_channel(server_address) channel = grpc.aio.insecure_channel(server_address)
service: Union[EmulatedBluetoothServiceStub, VhciForwardingServiceStub] service: Union[EmulatedBluetoothServiceStub, VhciForwardingServiceStub]

View File

@@ -145,8 +145,6 @@ def publish_grpc_port(grpc_port: int, instance_number: int) -> bool:
async def open_android_netsim_controller_transport( async def open_android_netsim_controller_transport(
server_host: Optional[str], server_port: int, options: dict[str, str] server_host: Optional[str], server_port: int, options: dict[str, str]
) -> Transport: ) -> Transport:
if not server_port:
raise TransportSpecError('invalid port')
if server_host == '_' or not server_host: if server_host == '_' or not server_host:
server_host = 'localhost' server_host = 'localhost'
@@ -168,14 +166,16 @@ async def open_android_netsim_controller_transport(
await self.pump_loop() await self.pump_loop()
except asyncio.CancelledError: except asyncio.CancelledError:
logger.debug('Pump task canceled') logger.debug('Pump task canceled')
self.done.set_result(None) if not self.done.done():
self.done.set_result(None)
async def pump_loop(self): async def pump_loop(self):
while True: while True:
request = await self.context.read() request = await self.context.read()
if request == grpc.aio.EOF: if request == grpc.aio.EOF:
logger.debug('End of request stream') logger.debug('End of request stream')
self.done.set_result(None) if not self.done.done():
self.done.set_result(None)
return return
# If we're not initialized yet, wait for a init packet. # If we're not initialized yet, wait for a init packet.
@@ -220,6 +220,8 @@ async def open_android_netsim_controller_transport(
async def wait_for_termination(self): async def wait_for_termination(self):
await self.done await self.done
server_address = f'{server_host}:{server_port}'
class Server(PacketStreamerServicer, ParserSource): class Server(PacketStreamerServicer, ParserSource):
def __init__(self): def __init__(self):
PacketStreamerServicer.__init__(self) PacketStreamerServicer.__init__(self)
@@ -230,8 +232,8 @@ async def open_android_netsim_controller_transport(
# a server listening on that port, we get an exception. # a server listening on that port, we get an exception.
self.grpc_server = grpc.aio.server(options=(('grpc.so_reuseport', 0),)) self.grpc_server = grpc.aio.server(options=(('grpc.so_reuseport', 0),))
add_PacketStreamerServicer_to_server(self, self.grpc_server) add_PacketStreamerServicer_to_server(self, self.grpc_server)
self.grpc_server.add_insecure_port(f'{server_host}:{server_port}') self.port = self.grpc_server.add_insecure_port(server_address)
logger.debug(f'gRPC server listening on {server_host}:{server_port}') logger.debug('gRPC server listening on %s', server_address)
async def start(self): async def start(self):
logger.debug('Starting gRPC server') logger.debug('Starting gRPC server')
@@ -443,7 +445,7 @@ async def open_android_netsim_transport(spec: Optional[str]) -> Transport:
params = spec.split(',') if spec else [] params = spec.split(',') if spec else []
if params and ':' in params[0]: if params and ':' in params[0]:
# Explicit <host>:<port> # Explicit <host>:<port>
host, port_str = params[0].split(':') host, port_str = params[0].rsplit(':', maxsplit=1)
port = int(port_str) port = int(port_str)
params_offset = 1 params_offset = 1
else: else:

View File

@@ -41,7 +41,7 @@ async def open_tcp_client_transport(spec: str) -> Transport:
logger.debug(f'connection lost: {exc}') logger.debug(f'connection lost: {exc}')
self.on_transport_lost() self.on_transport_lost()
remote_host, remote_port = spec.split(':') remote_host, remote_port = spec.rsplit(':', maxsplit=1)
tcp_transport, packet_source = await asyncio.get_running_loop().create_connection( tcp_transport, packet_source = await asyncio.get_running_loop().create_connection(
TcpPacketSource, TcpPacketSource,
host=remote_host, host=remote_host,

View File

@@ -29,13 +29,6 @@ logger = logging.getLogger(__name__)
# ----------------------------------------------------------------------------- # -----------------------------------------------------------------------------
# A pass-through function to ease mock testing.
async def _create_server(*args, **kw_args):
await asyncio.get_running_loop().create_server(*args, **kw_args)
async def open_tcp_server_transport(spec: str) -> Transport: async def open_tcp_server_transport(spec: str) -> Transport:
''' '''
Open a TCP server transport. Open a TCP server transport.
@@ -46,13 +39,15 @@ async def open_tcp_server_transport(spec: str) -> Transport:
Example: _:9001 Example: _:9001
''' '''
local_host, local_port = spec.split(':') local_host, local_port = spec.rsplit(':', maxsplit=1)
return await _open_tcp_server_transport_impl( return await _open_tcp_server_transport_impl(
host=local_host if local_host != '_' else None, port=int(local_port) host=local_host if local_host != '_' else None, port=int(local_port)
) )
async def open_tcp_server_transport_with_socket(sock: socket.socket) -> Transport: async def open_tcp_server_transport_with_socket(
sock: socket.socket,
) -> Transport:
''' '''
Open a TCP server transport with an existing socket. Open a TCP server transport with an existing socket.
@@ -63,8 +58,9 @@ async def open_tcp_server_transport_with_socket(sock: socket.socket) -> Transpor
async def _open_tcp_server_transport_impl(**kwargs) -> Transport: async def _open_tcp_server_transport_impl(**kwargs) -> Transport:
class TcpServerTransport(Transport): class TcpServerTransport(Transport):
async def close(self): def __init__(self, source, sink, server):
await super().close() self.server = server
super().__init__(source, sink)
class TcpServerProtocol(asyncio.BaseProtocol): class TcpServerProtocol(asyncio.BaseProtocol):
def __init__(self, packet_source, packet_sink): def __init__(self, packet_source, packet_sink):
@@ -102,8 +98,8 @@ async def _open_tcp_server_transport_impl(**kwargs) -> Transport:
packet_source = StreamPacketSource() packet_source = StreamPacketSource()
packet_sink = TcpServerPacketSink() packet_sink = TcpServerPacketSink()
await _create_server( server = await asyncio.get_running_loop().create_server(
lambda: TcpServerProtocol(packet_source, packet_sink), **kwargs lambda: TcpServerProtocol(packet_source, packet_sink), **kwargs
) )
return TcpServerTransport(packet_source, packet_sink) return TcpServerTransport(packet_source, packet_sink, server)

View File

@@ -51,8 +51,8 @@ async def open_udp_transport(spec: str) -> Transport:
self.transport.close() self.transport.close()
local, remote = spec.split(',') local, remote = spec.split(',')
local_host, local_port = local.split(':') local_host, local_port = local.rsplit(':', maxsplit=1)
remote_host, remote_port = remote.split(':') remote_host, remote_port = remote.rsplit(':', maxsplit=1)
( (
udp_transport, udp_transport,
packet_source, packet_source,

View File

@@ -54,3 +54,69 @@ async def open_unix_client_transport(spec: str) -> Transport:
packet_sink = StreamPacketSink(unix_transport) packet_sink = StreamPacketSink(unix_transport)
return Transport(packet_source, packet_sink) return Transport(packet_source, packet_sink)
# -----------------------------------------------------------------------------
async def open_unix_server_transport(spec: str) -> Transport:
'''Open a UNIX socket server transport.
The parameter is the path of unix socket. For abstract socket, the first character
needs to be '@'.
Example:
* /tmp/hci.socket
* @hci_socket
'''
# For abstract socket, the first character should be null character.
if spec.startswith('@'):
spec = '\0' + spec[1:]
class UnixServerTransport(Transport):
def __init__(self, source, sink, server):
self.server = server
super().__init__(source, sink)
async def close(self):
await super().close()
class UnixServerProtocol(asyncio.BaseProtocol):
def __init__(self, packet_source, packet_sink):
self.packet_source = packet_source
self.packet_sink = packet_sink
# Called when a new connection is established
def connection_made(self, transport):
peer_name = transport.get_extra_info('peer_name')
logger.debug('connection from %s', peer_name)
self.packet_sink.transport = transport
# Called when the client is disconnected
def connection_lost(self, error):
logger.debug('connection lost: %s', error)
self.packet_sink.transport = None
def eof_received(self):
logger.debug('connection end')
self.packet_sink.transport = None
# Called when data is received on the socket
def data_received(self, data):
self.packet_source.data_received(data)
class UnixServerPacketSink:
def __init__(self):
self.transport = None
def on_packet(self, packet):
if self.transport:
self.transport.write(packet)
else:
logger.debug('no client, dropping packet')
packet_source = StreamPacketSource()
packet_sink = UnixServerPacketSink()
server = await asyncio.get_running_loop().create_unix_server(
lambda: UnixServerProtocol(packet_source, packet_sink), spec
)
return UnixServerTransport(packet_source, packet_sink, server)

View File

@@ -82,7 +82,7 @@ async def open_ws_server_transport(spec: str) -> Transport:
return return
return await self.connection.send(packet) return await self.connection.send(packet)
local_host, local_port = spec.split(':') local_host, local_port = spec.rsplit(':', maxsplit=1)
transport = WsServerTransport() transport = WsServerTransport()
await transport.serve(local_host, local_port) await transport.serve(local_host, local_port)
return transport return transport

View File

@@ -16,8 +16,7 @@ import asyncio
import os import os
import pytest import pytest
import socket import socket
import unittest from unittest import mock
from unittest.mock import ANY, patch
from bumble.transport.tcp_server import ( from bumble.transport.tcp_server import (
open_tcp_server_transport, open_tcp_server_transport,
@@ -25,28 +24,23 @@ from bumble.transport.tcp_server import (
) )
class OpenTcpServerTransportTests(unittest.TestCase): async def test_open_with_spec():
def setUp(self): with mock.patch.object(asyncio.get_running_loop(), 'create_server') as m:
self.patcher = patch('bumble.transport.tcp_server._create_server') await open_tcp_server_transport('localhost:32100')
self.mock_create_server = self.patcher.start() m.assert_awaited_once_with(mock.ANY, host='localhost', port=32100)
def tearDown(self):
self.patcher.stop()
def test_open_with_spec(self): async def test_open_with_port_only_spec():
asyncio.run(open_tcp_server_transport('localhost:32100')) with mock.patch.object(asyncio.get_running_loop(), 'create_server') as m:
self.mock_create_server.assert_awaited_once_with( await open_tcp_server_transport('_:32100')
ANY, host='localhost', port=32100 m.assert_awaited_once_with(mock.ANY, host=None, port=32100)
)
def test_open_with_port_only_spec(self):
asyncio.run(open_tcp_server_transport('_:32100'))
self.mock_create_server.assert_awaited_once_with(ANY, host=None, port=32100)
def test_open_with_socket(self): async def test_open_with_socket():
with mock.patch.object(asyncio.get_running_loop(), 'create_server') as m:
with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as sock: with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as sock:
asyncio.run(open_tcp_server_transport_with_socket(sock=sock)) await open_tcp_server_transport_with_socket(sock=sock)
self.mock_create_server.assert_awaited_once_with(ANY, sock=sock) m.assert_awaited_once_with(mock.ANY, sock=sock)
@pytest.mark.skipif( @pytest.mark.skipif(

View File

@@ -17,9 +17,41 @@
# ----------------------------------------------------------------------------- # -----------------------------------------------------------------------------
import random import random
import os import os
import socket
import sys
import pytest
from bumble import controller
from bumble import device
from bumble import hci
from bumble import link
from bumble import transport
from bumble.transport.common import PacketParser from bumble.transport.common import PacketParser
# -----------------------------------------------------------------------------
def _make_controller_from_transport(transport: transport.Transport):
return controller.Controller(
name="server",
host_sink=transport.sink,
host_source=transport.source,
link=link.LocalLink(),
)
# -----------------------------------------------------------------------------
def _make_device_from_transport(
transport: transport.Transport, address: str = "11:22:33:44:55:66"
):
return device.Device.with_hci(
name="client",
address=hci.Address(address),
hci_sink=transport.sink,
hci_source=transport.source,
)
# ----------------------------------------------------------------------------- # -----------------------------------------------------------------------------
class Sink: class Sink:
def __init__(self): def __init__(self):
@@ -71,6 +103,131 @@ def test_parser_extensions():
assert len(sink.packets) == 1 assert len(sink.packets) == 1
# -----------------------------------------------------------------------------
@pytest.mark.parametrize(
"address,",
("127.0.0.1", "::1"),
)
async def test_tcp_connection(address):
server_transport = await transport.open_transport(f"tcp-server:{address}:0")
port = server_transport.server.sockets[0].getsockname()[1]
_make_controller_from_transport(server_transport)
client_transport = await transport.open_transport(f"tcp-client:{address}:{port}")
client_device = _make_device_from_transport(client_transport)
await client_device.power_on()
await client_transport.close()
await server_transport.close()
# -----------------------------------------------------------------------------
@pytest.mark.parametrize(
"address, family",
(("127.0.0.1", socket.AF_INET), ("::1", socket.AF_INET6)),
)
async def test_udp_connection(address, family):
# Pick empty ports
ports = []
for _ in range(2):
sock = socket.socket(family=family, type=socket.SOCK_DGRAM)
sock.bind((address, 0))
ports.append(sock.getsockname()[1])
sock.close()
server_transport = await transport.open_transport(
f"udp:{address}:{ports[0]},{address}:{ports[1]}"
)
_make_controller_from_transport(server_transport)
client_transport = await transport.open_transport(
f"udp:{address}:{ports[1]},{address}:{ports[0]}"
)
client_device = _make_device_from_transport(client_transport)
await client_device.power_on()
await client_transport.close()
await server_transport.close()
# -----------------------------------------------------------------------------
@pytest.mark.parametrize(
"server_address, client_address",
(
("127.0.0.1", "ws://127.0.0.1"),
("::1", "ws://[::1]"),
),
)
async def test_ws_connection(server_address, client_address):
server_transport = await transport.open_transport(f"ws-server:{server_address}:0")
port = server_transport.server.sockets[0].getsockname()[1]
_make_controller_from_transport(server_transport)
client_transport = await transport.open_transport(
f"ws-client:{client_address}:{port}"
)
client_device = _make_device_from_transport(client_transport)
await client_device.power_on()
await client_transport.close()
await server_transport.close()
# -----------------------------------------------------------------------------
@pytest.mark.skipif(
sys.platform != 'linux', reason='Unix socket is only fully supported on Linux'
)
async def test_unix_connection_file(tmpdir):
path = str(tmpdir / 'bumble.sock')
server_transport = await transport.open_transport(f"unix-server:{path}")
_make_controller_from_transport(server_transport)
client_transport = await transport.open_transport(f"unix-client:{path}")
client_device = _make_device_from_transport(client_transport)
await client_device.power_on()
await client_transport.close()
await server_transport.close()
# -----------------------------------------------------------------------------
@pytest.mark.skipif(
sys.platform != 'linux', reason='Unix socket is only fully supported on Linux'
)
async def test_unix_connection_abstract():
server_transport = await transport.open_transport("unix-server:@bumble.test.sock")
_make_controller_from_transport(server_transport)
client_transport = await transport.open_transport("unix-client:@bumble.test.sock")
client_device = _make_device_from_transport(client_transport)
await client_device.power_on()
await client_transport.close()
await server_transport.close()
# -----------------------------------------------------------------------------
@pytest.mark.parametrize(
"address,",
("127.0.0.1",),
)
async def test_android_netsim_connection(address):
controller_transport = await transport.open_transport(
"android-netsim:_:0,mode=controller"
)
port = controller_transport.source.port
_make_controller_from_transport(controller_transport)
client_transport = await transport.open_transport(
f"android-netsim:{address}:{port},mode=host"
)
client_device = _make_device_from_transport(client_transport)
await client_device.power_on()
await client_transport.close()
await controller_transport.close()
# ----------------------------------------------------------------------------- # -----------------------------------------------------------------------------
if __name__ == '__main__': if __name__ == '__main__':
test_parser() test_parser()