diff --git a/bumble/transport/__init__.py b/bumble/transport/__init__.py index 5473771..96d3534 100644 --- a/bumble/transport/__init__.py +++ b/bumble/transport/__init__.py @@ -185,12 +185,18 @@ async def _open_transport(scheme: str, spec: Optional[str]) -> Transport: return await open_android_netsim_transport(spec) - if scheme == 'unix': + if scheme in ('unix', 'unix-client'): from bumble.transport.unix import open_unix_client_transport assert spec return await open_unix_client_transport(spec) + if scheme == 'unix-server': + from bumble.transport.unix import open_unix_server_transport + + assert spec + return await open_unix_server_transport(spec) + raise TransportSpecError('unknown transport scheme') diff --git a/bumble/transport/android_emulator.py b/bumble/transport/android_emulator.py index 5796412..2f0652d 100644 --- a/bumble/transport/android_emulator.py +++ b/bumble/transport/android_emulator.py @@ -77,21 +77,17 @@ async def open_android_emulator_transport(spec: Optional[str]) -> Transport: # Parse the parameters mode = 'host' - server_host = 'localhost' - server_port = '8554' + server_address = 'localhost:8554' if spec: params = spec.split(',') for param in params: if param.startswith('mode='): mode = param.split('=')[1] - elif ':' in param: - server_host, server_port = param.split(':') else: - raise TransportSpecError('invalid parameter') + server_address = param # Connect to the gRPC server - server_address = f'{server_host}:{server_port}' - logger.debug(f'connecting to gRPC server at {server_address}') + logger.debug('connecting to gRPC server at %s', server_address) channel = grpc.aio.insecure_channel(server_address) service: Union[EmulatedBluetoothServiceStub, VhciForwardingServiceStub] diff --git a/bumble/transport/android_netsim.py b/bumble/transport/android_netsim.py index 48df7da..d683b3d 100644 --- a/bumble/transport/android_netsim.py +++ b/bumble/transport/android_netsim.py @@ -443,7 +443,7 @@ async def open_android_netsim_transport(spec: Optional[str]) -> Transport: params = spec.split(',') if spec else [] if params and ':' in params[0]: # Explicit : - host, port_str = params[0].split(':') + host, port_str = params[0].rsplit(':', maxsplit=1) port = int(port_str) params_offset = 1 else: diff --git a/bumble/transport/tcp_client.py b/bumble/transport/tcp_client.py index 0d73e28..623b2d9 100644 --- a/bumble/transport/tcp_client.py +++ b/bumble/transport/tcp_client.py @@ -41,7 +41,7 @@ async def open_tcp_client_transport(spec: str) -> Transport: logger.debug(f'connection lost: {exc}') self.on_transport_lost() - remote_host, remote_port = spec.split(':') + remote_host, remote_port = spec.rsplit(':', maxsplit=1) tcp_transport, packet_source = await asyncio.get_running_loop().create_connection( TcpPacketSource, host=remote_host, diff --git a/bumble/transport/tcp_server.py b/bumble/transport/tcp_server.py index 58e15dc..f6ff69f 100644 --- a/bumble/transport/tcp_server.py +++ b/bumble/transport/tcp_server.py @@ -29,13 +29,6 @@ logger = logging.getLogger(__name__) # ----------------------------------------------------------------------------- - - -# A pass-through function to ease mock testing. -async def _create_server(*args, **kw_args): - await asyncio.get_running_loop().create_server(*args, **kw_args) - - async def open_tcp_server_transport(spec: str) -> Transport: ''' Open a TCP server transport. @@ -46,13 +39,15 @@ async def open_tcp_server_transport(spec: str) -> Transport: Example: _:9001 ''' - local_host, local_port = spec.split(':') + local_host, local_port = spec.rsplit(':', maxsplit=1) return await _open_tcp_server_transport_impl( host=local_host if local_host != '_' else None, port=int(local_port) ) -async def open_tcp_server_transport_with_socket(sock: socket.socket) -> Transport: +async def open_tcp_server_transport_with_socket( + sock: socket.socket, +) -> Transport: ''' Open a TCP server transport with an existing socket. @@ -63,8 +58,9 @@ async def open_tcp_server_transport_with_socket(sock: socket.socket) -> Transpor async def _open_tcp_server_transport_impl(**kwargs) -> Transport: class TcpServerTransport(Transport): - async def close(self): - await super().close() + def __init__(self, source, sink, server): + self.server = server + super().__init__(source, sink) class TcpServerProtocol(asyncio.BaseProtocol): def __init__(self, packet_source, packet_sink): @@ -102,8 +98,8 @@ async def _open_tcp_server_transport_impl(**kwargs) -> Transport: packet_source = StreamPacketSource() packet_sink = TcpServerPacketSink() - await _create_server( + server = await asyncio.get_running_loop().create_server( lambda: TcpServerProtocol(packet_source, packet_sink), **kwargs ) - return TcpServerTransport(packet_source, packet_sink) + return TcpServerTransport(packet_source, packet_sink, server) diff --git a/bumble/transport/udp.py b/bumble/transport/udp.py index 99f6665..51b10a4 100644 --- a/bumble/transport/udp.py +++ b/bumble/transport/udp.py @@ -51,8 +51,8 @@ async def open_udp_transport(spec: str) -> Transport: self.transport.close() local, remote = spec.split(',') - local_host, local_port = local.split(':') - remote_host, remote_port = remote.split(':') + local_host, local_port = local.rsplit(':', maxsplit=1) + remote_host, remote_port = remote.rsplit(':', maxsplit=1) ( udp_transport, packet_source, diff --git a/bumble/transport/unix.py b/bumble/transport/unix.py index 2e9ae4b..57b60a0 100644 --- a/bumble/transport/unix.py +++ b/bumble/transport/unix.py @@ -54,3 +54,69 @@ async def open_unix_client_transport(spec: str) -> Transport: packet_sink = StreamPacketSink(unix_transport) return Transport(packet_source, packet_sink) + + +# ----------------------------------------------------------------------------- +async def open_unix_server_transport(spec: str) -> Transport: + '''Open a UNIX socket server transport. + + The parameter is the path of unix socket. For abstract socket, the first character + needs to be '@'. + + Example: + * /tmp/hci.socket + * @hci_socket + ''' + # For abstract socket, the first character should be null character. + if spec.startswith('@'): + spec = '\0' + spec[1:] + + class UnixServerTransport(Transport): + def __init__(self, source, sink, server): + self.server = server + super().__init__(source, sink) + + async def close(self): + await super().close() + + class UnixServerProtocol(asyncio.BaseProtocol): + def __init__(self, packet_source, packet_sink): + self.packet_source = packet_source + self.packet_sink = packet_sink + + # Called when a new connection is established + def connection_made(self, transport): + peer_name = transport.get_extra_info('peer_name') + logger.debug('connection from %s', peer_name) + self.packet_sink.transport = transport + + # Called when the client is disconnected + def connection_lost(self, error): + logger.debug('connection lost: %s', error) + self.packet_sink.transport = None + + def eof_received(self): + logger.debug('connection end') + self.packet_sink.transport = None + + # Called when data is received on the socket + def data_received(self, data): + self.packet_source.data_received(data) + + class UnixServerPacketSink: + def __init__(self): + self.transport = None + + def on_packet(self, packet): + if self.transport: + self.transport.write(packet) + else: + logger.debug('no client, dropping packet') + + packet_source = StreamPacketSource() + packet_sink = UnixServerPacketSink() + server = await asyncio.get_running_loop().create_unix_server( + lambda: UnixServerProtocol(packet_source, packet_sink), spec + ) + + return UnixServerTransport(packet_source, packet_sink, server) diff --git a/bumble/transport/ws_server.py b/bumble/transport/ws_server.py index b45bc4e..0a35dda 100644 --- a/bumble/transport/ws_server.py +++ b/bumble/transport/ws_server.py @@ -82,7 +82,7 @@ async def open_ws_server_transport(spec: str) -> Transport: return return await self.connection.send(packet) - local_host, local_port = spec.split(':') + local_host, local_port = spec.rsplit(':', maxsplit=1) transport = WsServerTransport() await transport.serve(local_host, local_port) return transport diff --git a/tests/transport_tcp_server_test.py b/tests/transport_tcp_server_test.py index a5f015d..ff6ff27 100644 --- a/tests/transport_tcp_server_test.py +++ b/tests/transport_tcp_server_test.py @@ -16,8 +16,7 @@ import asyncio import os import pytest import socket -import unittest -from unittest.mock import ANY, patch +from unittest import mock from bumble.transport.tcp_server import ( open_tcp_server_transport, @@ -25,28 +24,23 @@ from bumble.transport.tcp_server import ( ) -class OpenTcpServerTransportTests(unittest.TestCase): - def setUp(self): - self.patcher = patch('bumble.transport.tcp_server._create_server') - self.mock_create_server = self.patcher.start() +async def test_open_with_spec(): + with mock.patch.object(asyncio.get_running_loop(), 'create_server') as m: + await open_tcp_server_transport('localhost:32100') + m.assert_awaited_once_with(mock.ANY, host='localhost', port=32100) - def tearDown(self): - self.patcher.stop() - def test_open_with_spec(self): - asyncio.run(open_tcp_server_transport('localhost:32100')) - self.mock_create_server.assert_awaited_once_with( - ANY, host='localhost', port=32100 - ) +async def test_open_with_port_only_spec(): + with mock.patch.object(asyncio.get_running_loop(), 'create_server') as m: + await open_tcp_server_transport('_:32100') + m.assert_awaited_once_with(mock.ANY, host=None, port=32100) - def test_open_with_port_only_spec(self): - asyncio.run(open_tcp_server_transport('_:32100')) - self.mock_create_server.assert_awaited_once_with(ANY, host=None, port=32100) - def test_open_with_socket(self): +async def test_open_with_socket(): + with mock.patch.object(asyncio.get_running_loop(), 'create_server') as m: with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as sock: - asyncio.run(open_tcp_server_transport_with_socket(sock=sock)) - self.mock_create_server.assert_awaited_once_with(ANY, sock=sock) + await open_tcp_server_transport_with_socket(sock=sock) + m.assert_awaited_once_with(mock.ANY, sock=sock) @pytest.mark.skipif( diff --git a/tests/transport_test.py b/tests/transport_test.py index cd3c5f2..58e94ef 100644 --- a/tests/transport_test.py +++ b/tests/transport_test.py @@ -17,9 +17,41 @@ # ----------------------------------------------------------------------------- import random import os +import socket +import sys + +import pytest + +from bumble import controller +from bumble import device +from bumble import hci +from bumble import link +from bumble import transport from bumble.transport.common import PacketParser +# ----------------------------------------------------------------------------- +def _make_controller_from_transport(transport: transport.Transport): + return controller.Controller( + name="server", + host_sink=transport.sink, + host_source=transport.source, + link=link.LocalLink(), + ) + + +# ----------------------------------------------------------------------------- +def _make_device_from_transport( + transport: transport.Transport, address: str = "11:22:33:44:55:66" +): + return device.Device.with_hci( + name="client", + address=hci.Address(address), + hci_sink=transport.sink, + hci_source=transport.source, + ) + + # ----------------------------------------------------------------------------- class Sink: def __init__(self): @@ -71,6 +103,109 @@ def test_parser_extensions(): assert len(sink.packets) == 1 +# ----------------------------------------------------------------------------- +@pytest.mark.parametrize( + "address,", + ("127.0.0.1", "::1"), +) +async def test_tcp_connection(address): + server_transport = await transport.open_transport(f"tcp-server:{address}:0") + port = server_transport.server.sockets[0].getsockname()[1] + _make_controller_from_transport(server_transport) + + client_transport = await transport.open_transport(f"tcp-client:{address}:{port}") + client_device = _make_device_from_transport(client_transport) + await client_device.power_on() + + await client_transport.close() + await server_transport.close() + + +# ----------------------------------------------------------------------------- +@pytest.mark.parametrize( + "address, family", + (("127.0.0.1", socket.AF_INET), ("::1", socket.AF_INET6)), +) +async def test_udp_connection(address, family): + # Pick empty ports + ports = [] + for _ in range(2): + sock = socket.socket(family=family, type=socket.SOCK_DGRAM) + sock.bind((address, 0)) + ports.append(sock.getsockname()[1]) + sock.close() + + server_transport = await transport.open_transport( + f"udp:{address}:{ports[0]},{address}:{ports[1]}" + ) + _make_controller_from_transport(server_transport) + + client_transport = await transport.open_transport( + f"udp:{address}:{ports[1]},{address}:{ports[0]}" + ) + client_device = _make_device_from_transport(client_transport) + await client_device.power_on() + + await client_transport.close() + await server_transport.close() + + +# ----------------------------------------------------------------------------- +@pytest.mark.parametrize( + "server_address, client_address", + ( + ("127.0.0.1", "ws://127.0.0.1"), + ("::1", "ws://[::1]"), + ), +) +async def test_ws_connection(server_address, client_address): + server_transport = await transport.open_transport(f"ws-server:{server_address}:0") + port = server_transport.server.sockets[0].getsockname()[1] + _make_controller_from_transport(server_transport) + + client_transport = await transport.open_transport( + f"ws-client:{client_address}:{port}" + ) + client_device = _make_device_from_transport(client_transport) + await client_device.power_on() + + await client_transport.close() + await server_transport.close() + + +# ----------------------------------------------------------------------------- +@pytest.mark.skipif( + sys.platform != 'linux', reason='Unix socket is only fully supported on Linux' +) +async def test_unix_connection_file(tmpdir): + path = str(tmpdir / 'bumble.sock') + server_transport = await transport.open_transport(f"unix-server:{path}") + _make_controller_from_transport(server_transport) + + client_transport = await transport.open_transport(f"unix-client:{path}") + client_device = _make_device_from_transport(client_transport) + await client_device.power_on() + + await client_transport.close() + await server_transport.close() + + +# ----------------------------------------------------------------------------- +@pytest.mark.skipif( + sys.platform != 'linux', reason='Unix socket is only fully supported on Linux' +) +async def test_unix_connection_abstract(): + server_transport = await transport.open_transport("unix-server:@bumble.test.sock") + _make_controller_from_transport(server_transport) + + client_transport = await transport.open_transport("unix-client:@bumble.test.sock") + client_device = _make_device_from_transport(client_transport) + await client_device.power_on() + + await client_transport.close() + await server_transport.close() + + # ----------------------------------------------------------------------------- if __name__ == '__main__': test_parser()