forked from auracaster/bumble_mirror
@@ -185,12 +185,18 @@ async def _open_transport(scheme: str, spec: Optional[str]) -> Transport:
|
||||
|
||||
return await open_android_netsim_transport(spec)
|
||||
|
||||
if scheme == 'unix':
|
||||
if scheme in ('unix', 'unix-client'):
|
||||
from bumble.transport.unix import open_unix_client_transport
|
||||
|
||||
assert spec
|
||||
return await open_unix_client_transport(spec)
|
||||
|
||||
if scheme == 'unix-server':
|
||||
from bumble.transport.unix import open_unix_server_transport
|
||||
|
||||
assert spec
|
||||
return await open_unix_server_transport(spec)
|
||||
|
||||
raise TransportSpecError('unknown transport scheme')
|
||||
|
||||
|
||||
|
||||
@@ -77,21 +77,17 @@ async def open_android_emulator_transport(spec: Optional[str]) -> Transport:
|
||||
|
||||
# Parse the parameters
|
||||
mode = 'host'
|
||||
server_host = 'localhost'
|
||||
server_port = '8554'
|
||||
server_address = 'localhost:8554'
|
||||
if spec:
|
||||
params = spec.split(',')
|
||||
for param in params:
|
||||
if param.startswith('mode='):
|
||||
mode = param.split('=')[1]
|
||||
elif ':' in param:
|
||||
server_host, server_port = param.split(':')
|
||||
else:
|
||||
raise TransportSpecError('invalid parameter')
|
||||
server_address = param
|
||||
|
||||
# Connect to the gRPC server
|
||||
server_address = f'{server_host}:{server_port}'
|
||||
logger.debug(f'connecting to gRPC server at {server_address}')
|
||||
logger.debug('connecting to gRPC server at %s', server_address)
|
||||
channel = grpc.aio.insecure_channel(server_address)
|
||||
|
||||
service: Union[EmulatedBluetoothServiceStub, VhciForwardingServiceStub]
|
||||
|
||||
@@ -145,8 +145,6 @@ def publish_grpc_port(grpc_port: int, instance_number: int) -> bool:
|
||||
async def open_android_netsim_controller_transport(
|
||||
server_host: Optional[str], server_port: int, options: dict[str, str]
|
||||
) -> Transport:
|
||||
if not server_port:
|
||||
raise TransportSpecError('invalid port')
|
||||
if server_host == '_' or not server_host:
|
||||
server_host = 'localhost'
|
||||
|
||||
@@ -168,14 +166,16 @@ async def open_android_netsim_controller_transport(
|
||||
await self.pump_loop()
|
||||
except asyncio.CancelledError:
|
||||
logger.debug('Pump task canceled')
|
||||
self.done.set_result(None)
|
||||
if not self.done.done():
|
||||
self.done.set_result(None)
|
||||
|
||||
async def pump_loop(self):
|
||||
while True:
|
||||
request = await self.context.read()
|
||||
if request == grpc.aio.EOF:
|
||||
logger.debug('End of request stream')
|
||||
self.done.set_result(None)
|
||||
if not self.done.done():
|
||||
self.done.set_result(None)
|
||||
return
|
||||
|
||||
# If we're not initialized yet, wait for a init packet.
|
||||
@@ -220,6 +220,8 @@ async def open_android_netsim_controller_transport(
|
||||
async def wait_for_termination(self):
|
||||
await self.done
|
||||
|
||||
server_address = f'{server_host}:{server_port}'
|
||||
|
||||
class Server(PacketStreamerServicer, ParserSource):
|
||||
def __init__(self):
|
||||
PacketStreamerServicer.__init__(self)
|
||||
@@ -230,8 +232,8 @@ async def open_android_netsim_controller_transport(
|
||||
# a server listening on that port, we get an exception.
|
||||
self.grpc_server = grpc.aio.server(options=(('grpc.so_reuseport', 0),))
|
||||
add_PacketStreamerServicer_to_server(self, self.grpc_server)
|
||||
self.grpc_server.add_insecure_port(f'{server_host}:{server_port}')
|
||||
logger.debug(f'gRPC server listening on {server_host}:{server_port}')
|
||||
self.port = self.grpc_server.add_insecure_port(server_address)
|
||||
logger.debug('gRPC server listening on %s', server_address)
|
||||
|
||||
async def start(self):
|
||||
logger.debug('Starting gRPC server')
|
||||
@@ -443,7 +445,7 @@ async def open_android_netsim_transport(spec: Optional[str]) -> Transport:
|
||||
params = spec.split(',') if spec else []
|
||||
if params and ':' in params[0]:
|
||||
# Explicit <host>:<port>
|
||||
host, port_str = params[0].split(':')
|
||||
host, port_str = params[0].rsplit(':', maxsplit=1)
|
||||
port = int(port_str)
|
||||
params_offset = 1
|
||||
else:
|
||||
|
||||
@@ -41,7 +41,7 @@ async def open_tcp_client_transport(spec: str) -> Transport:
|
||||
logger.debug(f'connection lost: {exc}')
|
||||
self.on_transport_lost()
|
||||
|
||||
remote_host, remote_port = spec.split(':')
|
||||
remote_host, remote_port = spec.rsplit(':', maxsplit=1)
|
||||
tcp_transport, packet_source = await asyncio.get_running_loop().create_connection(
|
||||
TcpPacketSource,
|
||||
host=remote_host,
|
||||
|
||||
@@ -29,13 +29,6 @@ logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
|
||||
|
||||
# A pass-through function to ease mock testing.
|
||||
async def _create_server(*args, **kw_args):
|
||||
await asyncio.get_running_loop().create_server(*args, **kw_args)
|
||||
|
||||
|
||||
async def open_tcp_server_transport(spec: str) -> Transport:
|
||||
'''
|
||||
Open a TCP server transport.
|
||||
@@ -46,13 +39,15 @@ async def open_tcp_server_transport(spec: str) -> Transport:
|
||||
|
||||
Example: _:9001
|
||||
'''
|
||||
local_host, local_port = spec.split(':')
|
||||
local_host, local_port = spec.rsplit(':', maxsplit=1)
|
||||
return await _open_tcp_server_transport_impl(
|
||||
host=local_host if local_host != '_' else None, port=int(local_port)
|
||||
)
|
||||
|
||||
|
||||
async def open_tcp_server_transport_with_socket(sock: socket.socket) -> Transport:
|
||||
async def open_tcp_server_transport_with_socket(
|
||||
sock: socket.socket,
|
||||
) -> Transport:
|
||||
'''
|
||||
Open a TCP server transport with an existing socket.
|
||||
|
||||
@@ -63,8 +58,9 @@ async def open_tcp_server_transport_with_socket(sock: socket.socket) -> Transpor
|
||||
|
||||
async def _open_tcp_server_transport_impl(**kwargs) -> Transport:
|
||||
class TcpServerTransport(Transport):
|
||||
async def close(self):
|
||||
await super().close()
|
||||
def __init__(self, source, sink, server):
|
||||
self.server = server
|
||||
super().__init__(source, sink)
|
||||
|
||||
class TcpServerProtocol(asyncio.BaseProtocol):
|
||||
def __init__(self, packet_source, packet_sink):
|
||||
@@ -102,8 +98,8 @@ async def _open_tcp_server_transport_impl(**kwargs) -> Transport:
|
||||
|
||||
packet_source = StreamPacketSource()
|
||||
packet_sink = TcpServerPacketSink()
|
||||
await _create_server(
|
||||
server = await asyncio.get_running_loop().create_server(
|
||||
lambda: TcpServerProtocol(packet_source, packet_sink), **kwargs
|
||||
)
|
||||
|
||||
return TcpServerTransport(packet_source, packet_sink)
|
||||
return TcpServerTransport(packet_source, packet_sink, server)
|
||||
|
||||
@@ -51,8 +51,8 @@ async def open_udp_transport(spec: str) -> Transport:
|
||||
self.transport.close()
|
||||
|
||||
local, remote = spec.split(',')
|
||||
local_host, local_port = local.split(':')
|
||||
remote_host, remote_port = remote.split(':')
|
||||
local_host, local_port = local.rsplit(':', maxsplit=1)
|
||||
remote_host, remote_port = remote.rsplit(':', maxsplit=1)
|
||||
(
|
||||
udp_transport,
|
||||
packet_source,
|
||||
|
||||
@@ -54,3 +54,69 @@ async def open_unix_client_transport(spec: str) -> Transport:
|
||||
packet_sink = StreamPacketSink(unix_transport)
|
||||
|
||||
return Transport(packet_source, packet_sink)
|
||||
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
async def open_unix_server_transport(spec: str) -> Transport:
|
||||
'''Open a UNIX socket server transport.
|
||||
|
||||
The parameter is the path of unix socket. For abstract socket, the first character
|
||||
needs to be '@'.
|
||||
|
||||
Example:
|
||||
* /tmp/hci.socket
|
||||
* @hci_socket
|
||||
'''
|
||||
# For abstract socket, the first character should be null character.
|
||||
if spec.startswith('@'):
|
||||
spec = '\0' + spec[1:]
|
||||
|
||||
class UnixServerTransport(Transport):
|
||||
def __init__(self, source, sink, server):
|
||||
self.server = server
|
||||
super().__init__(source, sink)
|
||||
|
||||
async def close(self):
|
||||
await super().close()
|
||||
|
||||
class UnixServerProtocol(asyncio.BaseProtocol):
|
||||
def __init__(self, packet_source, packet_sink):
|
||||
self.packet_source = packet_source
|
||||
self.packet_sink = packet_sink
|
||||
|
||||
# Called when a new connection is established
|
||||
def connection_made(self, transport):
|
||||
peer_name = transport.get_extra_info('peer_name')
|
||||
logger.debug('connection from %s', peer_name)
|
||||
self.packet_sink.transport = transport
|
||||
|
||||
# Called when the client is disconnected
|
||||
def connection_lost(self, error):
|
||||
logger.debug('connection lost: %s', error)
|
||||
self.packet_sink.transport = None
|
||||
|
||||
def eof_received(self):
|
||||
logger.debug('connection end')
|
||||
self.packet_sink.transport = None
|
||||
|
||||
# Called when data is received on the socket
|
||||
def data_received(self, data):
|
||||
self.packet_source.data_received(data)
|
||||
|
||||
class UnixServerPacketSink:
|
||||
def __init__(self):
|
||||
self.transport = None
|
||||
|
||||
def on_packet(self, packet):
|
||||
if self.transport:
|
||||
self.transport.write(packet)
|
||||
else:
|
||||
logger.debug('no client, dropping packet')
|
||||
|
||||
packet_source = StreamPacketSource()
|
||||
packet_sink = UnixServerPacketSink()
|
||||
server = await asyncio.get_running_loop().create_unix_server(
|
||||
lambda: UnixServerProtocol(packet_source, packet_sink), spec
|
||||
)
|
||||
|
||||
return UnixServerTransport(packet_source, packet_sink, server)
|
||||
|
||||
@@ -82,7 +82,7 @@ async def open_ws_server_transport(spec: str) -> Transport:
|
||||
return
|
||||
return await self.connection.send(packet)
|
||||
|
||||
local_host, local_port = spec.split(':')
|
||||
local_host, local_port = spec.rsplit(':', maxsplit=1)
|
||||
transport = WsServerTransport()
|
||||
await transport.serve(local_host, local_port)
|
||||
return transport
|
||||
|
||||
@@ -16,8 +16,7 @@ import asyncio
|
||||
import os
|
||||
import pytest
|
||||
import socket
|
||||
import unittest
|
||||
from unittest.mock import ANY, patch
|
||||
from unittest import mock
|
||||
|
||||
from bumble.transport.tcp_server import (
|
||||
open_tcp_server_transport,
|
||||
@@ -25,28 +24,23 @@ from bumble.transport.tcp_server import (
|
||||
)
|
||||
|
||||
|
||||
class OpenTcpServerTransportTests(unittest.TestCase):
|
||||
def setUp(self):
|
||||
self.patcher = patch('bumble.transport.tcp_server._create_server')
|
||||
self.mock_create_server = self.patcher.start()
|
||||
async def test_open_with_spec():
|
||||
with mock.patch.object(asyncio.get_running_loop(), 'create_server') as m:
|
||||
await open_tcp_server_transport('localhost:32100')
|
||||
m.assert_awaited_once_with(mock.ANY, host='localhost', port=32100)
|
||||
|
||||
def tearDown(self):
|
||||
self.patcher.stop()
|
||||
|
||||
def test_open_with_spec(self):
|
||||
asyncio.run(open_tcp_server_transport('localhost:32100'))
|
||||
self.mock_create_server.assert_awaited_once_with(
|
||||
ANY, host='localhost', port=32100
|
||||
)
|
||||
async def test_open_with_port_only_spec():
|
||||
with mock.patch.object(asyncio.get_running_loop(), 'create_server') as m:
|
||||
await open_tcp_server_transport('_:32100')
|
||||
m.assert_awaited_once_with(mock.ANY, host=None, port=32100)
|
||||
|
||||
def test_open_with_port_only_spec(self):
|
||||
asyncio.run(open_tcp_server_transport('_:32100'))
|
||||
self.mock_create_server.assert_awaited_once_with(ANY, host=None, port=32100)
|
||||
|
||||
def test_open_with_socket(self):
|
||||
async def test_open_with_socket():
|
||||
with mock.patch.object(asyncio.get_running_loop(), 'create_server') as m:
|
||||
with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as sock:
|
||||
asyncio.run(open_tcp_server_transport_with_socket(sock=sock))
|
||||
self.mock_create_server.assert_awaited_once_with(ANY, sock=sock)
|
||||
await open_tcp_server_transport_with_socket(sock=sock)
|
||||
m.assert_awaited_once_with(mock.ANY, sock=sock)
|
||||
|
||||
|
||||
@pytest.mark.skipif(
|
||||
|
||||
@@ -17,9 +17,41 @@
|
||||
# -----------------------------------------------------------------------------
|
||||
import random
|
||||
import os
|
||||
import socket
|
||||
import sys
|
||||
|
||||
import pytest
|
||||
|
||||
from bumble import controller
|
||||
from bumble import device
|
||||
from bumble import hci
|
||||
from bumble import link
|
||||
from bumble import transport
|
||||
from bumble.transport.common import PacketParser
|
||||
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
def _make_controller_from_transport(transport: transport.Transport):
|
||||
return controller.Controller(
|
||||
name="server",
|
||||
host_sink=transport.sink,
|
||||
host_source=transport.source,
|
||||
link=link.LocalLink(),
|
||||
)
|
||||
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
def _make_device_from_transport(
|
||||
transport: transport.Transport, address: str = "11:22:33:44:55:66"
|
||||
):
|
||||
return device.Device.with_hci(
|
||||
name="client",
|
||||
address=hci.Address(address),
|
||||
hci_sink=transport.sink,
|
||||
hci_source=transport.source,
|
||||
)
|
||||
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
class Sink:
|
||||
def __init__(self):
|
||||
@@ -71,6 +103,131 @@ def test_parser_extensions():
|
||||
assert len(sink.packets) == 1
|
||||
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
@pytest.mark.parametrize(
|
||||
"address,",
|
||||
("127.0.0.1", "::1"),
|
||||
)
|
||||
async def test_tcp_connection(address):
|
||||
server_transport = await transport.open_transport(f"tcp-server:{address}:0")
|
||||
port = server_transport.server.sockets[0].getsockname()[1]
|
||||
_make_controller_from_transport(server_transport)
|
||||
|
||||
client_transport = await transport.open_transport(f"tcp-client:{address}:{port}")
|
||||
client_device = _make_device_from_transport(client_transport)
|
||||
await client_device.power_on()
|
||||
|
||||
await client_transport.close()
|
||||
await server_transport.close()
|
||||
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
@pytest.mark.parametrize(
|
||||
"address, family",
|
||||
(("127.0.0.1", socket.AF_INET), ("::1", socket.AF_INET6)),
|
||||
)
|
||||
async def test_udp_connection(address, family):
|
||||
# Pick empty ports
|
||||
ports = []
|
||||
for _ in range(2):
|
||||
sock = socket.socket(family=family, type=socket.SOCK_DGRAM)
|
||||
sock.bind((address, 0))
|
||||
ports.append(sock.getsockname()[1])
|
||||
sock.close()
|
||||
|
||||
server_transport = await transport.open_transport(
|
||||
f"udp:{address}:{ports[0]},{address}:{ports[1]}"
|
||||
)
|
||||
_make_controller_from_transport(server_transport)
|
||||
|
||||
client_transport = await transport.open_transport(
|
||||
f"udp:{address}:{ports[1]},{address}:{ports[0]}"
|
||||
)
|
||||
client_device = _make_device_from_transport(client_transport)
|
||||
await client_device.power_on()
|
||||
|
||||
await client_transport.close()
|
||||
await server_transport.close()
|
||||
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
@pytest.mark.parametrize(
|
||||
"server_address, client_address",
|
||||
(
|
||||
("127.0.0.1", "ws://127.0.0.1"),
|
||||
("::1", "ws://[::1]"),
|
||||
),
|
||||
)
|
||||
async def test_ws_connection(server_address, client_address):
|
||||
server_transport = await transport.open_transport(f"ws-server:{server_address}:0")
|
||||
port = server_transport.server.sockets[0].getsockname()[1]
|
||||
_make_controller_from_transport(server_transport)
|
||||
|
||||
client_transport = await transport.open_transport(
|
||||
f"ws-client:{client_address}:{port}"
|
||||
)
|
||||
client_device = _make_device_from_transport(client_transport)
|
||||
await client_device.power_on()
|
||||
|
||||
await client_transport.close()
|
||||
await server_transport.close()
|
||||
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
@pytest.mark.skipif(
|
||||
sys.platform != 'linux', reason='Unix socket is only fully supported on Linux'
|
||||
)
|
||||
async def test_unix_connection_file(tmpdir):
|
||||
path = str(tmpdir / 'bumble.sock')
|
||||
server_transport = await transport.open_transport(f"unix-server:{path}")
|
||||
_make_controller_from_transport(server_transport)
|
||||
|
||||
client_transport = await transport.open_transport(f"unix-client:{path}")
|
||||
client_device = _make_device_from_transport(client_transport)
|
||||
await client_device.power_on()
|
||||
|
||||
await client_transport.close()
|
||||
await server_transport.close()
|
||||
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
@pytest.mark.skipif(
|
||||
sys.platform != 'linux', reason='Unix socket is only fully supported on Linux'
|
||||
)
|
||||
async def test_unix_connection_abstract():
|
||||
server_transport = await transport.open_transport("unix-server:@bumble.test.sock")
|
||||
_make_controller_from_transport(server_transport)
|
||||
|
||||
client_transport = await transport.open_transport("unix-client:@bumble.test.sock")
|
||||
client_device = _make_device_from_transport(client_transport)
|
||||
await client_device.power_on()
|
||||
|
||||
await client_transport.close()
|
||||
await server_transport.close()
|
||||
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
@pytest.mark.parametrize(
|
||||
"address,",
|
||||
("127.0.0.1",),
|
||||
)
|
||||
async def test_android_netsim_connection(address):
|
||||
controller_transport = await transport.open_transport(
|
||||
"android-netsim:_:0,mode=controller"
|
||||
)
|
||||
port = controller_transport.source.port
|
||||
_make_controller_from_transport(controller_transport)
|
||||
|
||||
client_transport = await transport.open_transport(
|
||||
f"android-netsim:{address}:{port},mode=host"
|
||||
)
|
||||
client_device = _make_device_from_transport(client_transport)
|
||||
await client_device.power_on()
|
||||
|
||||
await client_transport.close()
|
||||
await controller_transport.close()
|
||||
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
if __name__ == '__main__':
|
||||
test_parser()
|
||||
|
||||
Reference in New Issue
Block a user