mirror of
https://github.com/google/bumble.git
synced 2026-04-18 00:45:32 +00:00
@@ -185,12 +185,18 @@ async def _open_transport(scheme: str, spec: Optional[str]) -> Transport:
|
||||
|
||||
return await open_android_netsim_transport(spec)
|
||||
|
||||
if scheme == 'unix':
|
||||
if scheme in ('unix', 'unix-client'):
|
||||
from bumble.transport.unix import open_unix_client_transport
|
||||
|
||||
assert spec
|
||||
return await open_unix_client_transport(spec)
|
||||
|
||||
if scheme == 'unix-server':
|
||||
from bumble.transport.unix import open_unix_server_transport
|
||||
|
||||
assert spec
|
||||
return await open_unix_server_transport(spec)
|
||||
|
||||
raise TransportSpecError('unknown transport scheme')
|
||||
|
||||
|
||||
|
||||
@@ -77,21 +77,17 @@ async def open_android_emulator_transport(spec: Optional[str]) -> Transport:
|
||||
|
||||
# Parse the parameters
|
||||
mode = 'host'
|
||||
server_host = 'localhost'
|
||||
server_port = '8554'
|
||||
server_address = 'localhost:8554'
|
||||
if spec:
|
||||
params = spec.split(',')
|
||||
for param in params:
|
||||
if param.startswith('mode='):
|
||||
mode = param.split('=')[1]
|
||||
elif ':' in param:
|
||||
server_host, server_port = param.split(':')
|
||||
else:
|
||||
raise TransportSpecError('invalid parameter')
|
||||
server_address = param
|
||||
|
||||
# Connect to the gRPC server
|
||||
server_address = f'{server_host}:{server_port}'
|
||||
logger.debug(f'connecting to gRPC server at {server_address}')
|
||||
logger.debug('connecting to gRPC server at %s', server_address)
|
||||
channel = grpc.aio.insecure_channel(server_address)
|
||||
|
||||
service: Union[EmulatedBluetoothServiceStub, VhciForwardingServiceStub]
|
||||
|
||||
@@ -145,8 +145,6 @@ def publish_grpc_port(grpc_port: int, instance_number: int) -> bool:
|
||||
async def open_android_netsim_controller_transport(
|
||||
server_host: Optional[str], server_port: int, options: dict[str, str]
|
||||
) -> Transport:
|
||||
if not server_port:
|
||||
raise TransportSpecError('invalid port')
|
||||
if server_host == '_' or not server_host:
|
||||
server_host = 'localhost'
|
||||
|
||||
@@ -168,14 +166,16 @@ async def open_android_netsim_controller_transport(
|
||||
await self.pump_loop()
|
||||
except asyncio.CancelledError:
|
||||
logger.debug('Pump task canceled')
|
||||
self.done.set_result(None)
|
||||
if not self.done.done():
|
||||
self.done.set_result(None)
|
||||
|
||||
async def pump_loop(self):
|
||||
while True:
|
||||
request = await self.context.read()
|
||||
if request == grpc.aio.EOF:
|
||||
logger.debug('End of request stream')
|
||||
self.done.set_result(None)
|
||||
if not self.done.done():
|
||||
self.done.set_result(None)
|
||||
return
|
||||
|
||||
# If we're not initialized yet, wait for a init packet.
|
||||
@@ -220,6 +220,8 @@ async def open_android_netsim_controller_transport(
|
||||
async def wait_for_termination(self):
|
||||
await self.done
|
||||
|
||||
server_address = f'{server_host}:{server_port}'
|
||||
|
||||
class Server(PacketStreamerServicer, ParserSource):
|
||||
def __init__(self):
|
||||
PacketStreamerServicer.__init__(self)
|
||||
@@ -230,8 +232,8 @@ async def open_android_netsim_controller_transport(
|
||||
# a server listening on that port, we get an exception.
|
||||
self.grpc_server = grpc.aio.server(options=(('grpc.so_reuseport', 0),))
|
||||
add_PacketStreamerServicer_to_server(self, self.grpc_server)
|
||||
self.grpc_server.add_insecure_port(f'{server_host}:{server_port}')
|
||||
logger.debug(f'gRPC server listening on {server_host}:{server_port}')
|
||||
self.port = self.grpc_server.add_insecure_port(server_address)
|
||||
logger.debug('gRPC server listening on %s', server_address)
|
||||
|
||||
async def start(self):
|
||||
logger.debug('Starting gRPC server')
|
||||
@@ -443,7 +445,7 @@ async def open_android_netsim_transport(spec: Optional[str]) -> Transport:
|
||||
params = spec.split(',') if spec else []
|
||||
if params and ':' in params[0]:
|
||||
# Explicit <host>:<port>
|
||||
host, port_str = params[0].split(':')
|
||||
host, port_str = params[0].rsplit(':', maxsplit=1)
|
||||
port = int(port_str)
|
||||
params_offset = 1
|
||||
else:
|
||||
|
||||
@@ -41,7 +41,7 @@ async def open_tcp_client_transport(spec: str) -> Transport:
|
||||
logger.debug(f'connection lost: {exc}')
|
||||
self.on_transport_lost()
|
||||
|
||||
remote_host, remote_port = spec.split(':')
|
||||
remote_host, remote_port = spec.rsplit(':', maxsplit=1)
|
||||
tcp_transport, packet_source = await asyncio.get_running_loop().create_connection(
|
||||
TcpPacketSource,
|
||||
host=remote_host,
|
||||
|
||||
@@ -29,13 +29,6 @@ logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
|
||||
|
||||
# A pass-through function to ease mock testing.
|
||||
async def _create_server(*args, **kw_args):
|
||||
await asyncio.get_running_loop().create_server(*args, **kw_args)
|
||||
|
||||
|
||||
async def open_tcp_server_transport(spec: str) -> Transport:
|
||||
'''
|
||||
Open a TCP server transport.
|
||||
@@ -46,13 +39,15 @@ async def open_tcp_server_transport(spec: str) -> Transport:
|
||||
|
||||
Example: _:9001
|
||||
'''
|
||||
local_host, local_port = spec.split(':')
|
||||
local_host, local_port = spec.rsplit(':', maxsplit=1)
|
||||
return await _open_tcp_server_transport_impl(
|
||||
host=local_host if local_host != '_' else None, port=int(local_port)
|
||||
)
|
||||
|
||||
|
||||
async def open_tcp_server_transport_with_socket(sock: socket.socket) -> Transport:
|
||||
async def open_tcp_server_transport_with_socket(
|
||||
sock: socket.socket,
|
||||
) -> Transport:
|
||||
'''
|
||||
Open a TCP server transport with an existing socket.
|
||||
|
||||
@@ -63,8 +58,9 @@ async def open_tcp_server_transport_with_socket(sock: socket.socket) -> Transpor
|
||||
|
||||
async def _open_tcp_server_transport_impl(**kwargs) -> Transport:
|
||||
class TcpServerTransport(Transport):
|
||||
async def close(self):
|
||||
await super().close()
|
||||
def __init__(self, source, sink, server):
|
||||
self.server = server
|
||||
super().__init__(source, sink)
|
||||
|
||||
class TcpServerProtocol(asyncio.BaseProtocol):
|
||||
def __init__(self, packet_source, packet_sink):
|
||||
@@ -102,8 +98,8 @@ async def _open_tcp_server_transport_impl(**kwargs) -> Transport:
|
||||
|
||||
packet_source = StreamPacketSource()
|
||||
packet_sink = TcpServerPacketSink()
|
||||
await _create_server(
|
||||
server = await asyncio.get_running_loop().create_server(
|
||||
lambda: TcpServerProtocol(packet_source, packet_sink), **kwargs
|
||||
)
|
||||
|
||||
return TcpServerTransport(packet_source, packet_sink)
|
||||
return TcpServerTransport(packet_source, packet_sink, server)
|
||||
|
||||
@@ -51,8 +51,8 @@ async def open_udp_transport(spec: str) -> Transport:
|
||||
self.transport.close()
|
||||
|
||||
local, remote = spec.split(',')
|
||||
local_host, local_port = local.split(':')
|
||||
remote_host, remote_port = remote.split(':')
|
||||
local_host, local_port = local.rsplit(':', maxsplit=1)
|
||||
remote_host, remote_port = remote.rsplit(':', maxsplit=1)
|
||||
(
|
||||
udp_transport,
|
||||
packet_source,
|
||||
|
||||
@@ -54,3 +54,69 @@ async def open_unix_client_transport(spec: str) -> Transport:
|
||||
packet_sink = StreamPacketSink(unix_transport)
|
||||
|
||||
return Transport(packet_source, packet_sink)
|
||||
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
async def open_unix_server_transport(spec: str) -> Transport:
|
||||
'''Open a UNIX socket server transport.
|
||||
|
||||
The parameter is the path of unix socket. For abstract socket, the first character
|
||||
needs to be '@'.
|
||||
|
||||
Example:
|
||||
* /tmp/hci.socket
|
||||
* @hci_socket
|
||||
'''
|
||||
# For abstract socket, the first character should be null character.
|
||||
if spec.startswith('@'):
|
||||
spec = '\0' + spec[1:]
|
||||
|
||||
class UnixServerTransport(Transport):
|
||||
def __init__(self, source, sink, server):
|
||||
self.server = server
|
||||
super().__init__(source, sink)
|
||||
|
||||
async def close(self):
|
||||
await super().close()
|
||||
|
||||
class UnixServerProtocol(asyncio.BaseProtocol):
|
||||
def __init__(self, packet_source, packet_sink):
|
||||
self.packet_source = packet_source
|
||||
self.packet_sink = packet_sink
|
||||
|
||||
# Called when a new connection is established
|
||||
def connection_made(self, transport):
|
||||
peer_name = transport.get_extra_info('peer_name')
|
||||
logger.debug('connection from %s', peer_name)
|
||||
self.packet_sink.transport = transport
|
||||
|
||||
# Called when the client is disconnected
|
||||
def connection_lost(self, error):
|
||||
logger.debug('connection lost: %s', error)
|
||||
self.packet_sink.transport = None
|
||||
|
||||
def eof_received(self):
|
||||
logger.debug('connection end')
|
||||
self.packet_sink.transport = None
|
||||
|
||||
# Called when data is received on the socket
|
||||
def data_received(self, data):
|
||||
self.packet_source.data_received(data)
|
||||
|
||||
class UnixServerPacketSink:
|
||||
def __init__(self):
|
||||
self.transport = None
|
||||
|
||||
def on_packet(self, packet):
|
||||
if self.transport:
|
||||
self.transport.write(packet)
|
||||
else:
|
||||
logger.debug('no client, dropping packet')
|
||||
|
||||
packet_source = StreamPacketSource()
|
||||
packet_sink = UnixServerPacketSink()
|
||||
server = await asyncio.get_running_loop().create_unix_server(
|
||||
lambda: UnixServerProtocol(packet_source, packet_sink), spec
|
||||
)
|
||||
|
||||
return UnixServerTransport(packet_source, packet_sink, server)
|
||||
|
||||
@@ -82,7 +82,7 @@ async def open_ws_server_transport(spec: str) -> Transport:
|
||||
return
|
||||
return await self.connection.send(packet)
|
||||
|
||||
local_host, local_port = spec.split(':')
|
||||
local_host, local_port = spec.rsplit(':', maxsplit=1)
|
||||
transport = WsServerTransport()
|
||||
await transport.serve(local_host, local_port)
|
||||
return transport
|
||||
|
||||
Reference in New Issue
Block a user