Merge pull request #920 from zxzxwu/avdtp

AVDTP: Make all handlers async
This commit is contained in:
Josh Wu
2026-05-07 17:39:24 +08:00
committed by GitHub

View File

@@ -1477,8 +1477,23 @@ class Protocol(utils.EventEmitter):
handler = getattr(self, handler_name, None) handler = getattr(self, handler_name, None)
if handler: if handler:
try: try:
response = handler(message) result = handler(message)
self.send_message(transaction_label, response) if asyncio.iscoroutine(result):
async def wait_and_send() -> None:
try:
response = await result
if response:
self.send_message(transaction_label, response)
except Exception:
logger.exception(
color("!!! Exception in handler:", "red")
)
utils.cancel_on_event(self, self.EVENT_CLOSE, wait_and_send())
else:
if result:
self.send_message(transaction_label, result)
except Exception: except Exception:
logger.exception(color("!!! Exception in handler:", "red")) logger.exception(color("!!! Exception in handler:", "red"))
else: else:
@@ -1559,7 +1574,7 @@ class Protocol(utils.EventEmitter):
async def send_command(self, command: Message): async def send_command(self, command: Message):
# TODO: support timeouts # TODO: support timeouts
# Send the command # Send the command
(transaction_label, transaction_result) = await self.start_transaction() transaction_label, transaction_result = await self.start_transaction()
self.send_message(transaction_label, command) self.send_message(transaction_label, command)
# Wait for the response # Wait for the response
@@ -1624,14 +1639,14 @@ class Protocol(utils.EventEmitter):
async def abort(self, seid: int) -> Abort_Response: async def abort(self, seid: int) -> Abort_Response:
return await self.send_command(Abort_Command(seid)) return await self.send_command(Abort_Command(seid))
def on_discover_command(self, command: Discover_Command) -> Message | None: async def on_discover_command(self, command: Discover_Command) -> Message | None:
endpoint_infos = [ endpoint_infos = [
EndPointInfo(endpoint.seid, 0, endpoint.media_type, endpoint.tsep) EndPointInfo(endpoint.seid, 0, endpoint.media_type, endpoint.tsep)
for endpoint in self.local_endpoints for endpoint in self.local_endpoints
] ]
return Discover_Response(endpoint_infos) return Discover_Response(endpoint_infos)
def on_get_capabilities_command( async def on_get_capabilities_command(
self, command: Get_Capabilities_Command self, command: Get_Capabilities_Command
) -> Message | None: ) -> Message | None:
endpoint = self.get_local_endpoint_by_seid(command.acp_seid) endpoint = self.get_local_endpoint_by_seid(command.acp_seid)
@@ -1640,7 +1655,7 @@ class Protocol(utils.EventEmitter):
return Get_Capabilities_Response(endpoint.capabilities) return Get_Capabilities_Response(endpoint.capabilities)
def on_get_all_capabilities_command( async def on_get_all_capabilities_command(
self, command: Get_All_Capabilities_Command self, command: Get_All_Capabilities_Command
) -> Message | None: ) -> Message | None:
endpoint = self.get_local_endpoint_by_seid(command.acp_seid) endpoint = self.get_local_endpoint_by_seid(command.acp_seid)
@@ -1649,7 +1664,7 @@ class Protocol(utils.EventEmitter):
return Get_All_Capabilities_Response(endpoint.capabilities) return Get_All_Capabilities_Response(endpoint.capabilities)
def on_set_configuration_command( async def on_set_configuration_command(
self, command: Set_Configuration_Command self, command: Set_Configuration_Command
) -> Message | None: ) -> Message | None:
endpoint = self.get_local_endpoint_by_seid(command.acp_seid) endpoint = self.get_local_endpoint_by_seid(command.acp_seid)
@@ -1664,10 +1679,10 @@ class Protocol(utils.EventEmitter):
stream = Stream(self, endpoint, StreamEndPointProxy(self, command.int_seid)) stream = Stream(self, endpoint, StreamEndPointProxy(self, command.int_seid))
self.streams[command.acp_seid] = stream self.streams[command.acp_seid] = stream
result = stream.on_set_configuration_command(command.capabilities) result = await stream.on_set_configuration_command(command.capabilities)
return result or Set_Configuration_Response() return result or Set_Configuration_Response()
def on_get_configuration_command( async def on_get_configuration_command(
self, command: Get_Configuration_Command self, command: Get_Configuration_Command
) -> Message | None: ) -> Message | None:
endpoint = self.get_local_endpoint_by_seid(command.acp_seid) endpoint = self.get_local_endpoint_by_seid(command.acp_seid)
@@ -1676,29 +1691,31 @@ class Protocol(utils.EventEmitter):
if endpoint.stream is None: if endpoint.stream is None:
return Get_Configuration_Reject(AVDTP_BAD_STATE_ERROR) return Get_Configuration_Reject(AVDTP_BAD_STATE_ERROR)
return endpoint.stream.on_get_configuration_command() return await endpoint.stream.on_get_configuration_command()
def on_reconfigure_command(self, command: Reconfigure_Command) -> Message | None: async def on_reconfigure_command(
self, command: Reconfigure_Command
) -> Message | None:
endpoint = self.get_local_endpoint_by_seid(command.acp_seid) endpoint = self.get_local_endpoint_by_seid(command.acp_seid)
if endpoint is None: if endpoint is None:
return Reconfigure_Reject(error_code=AVDTP_BAD_ACP_SEID_ERROR) return Reconfigure_Reject(error_code=AVDTP_BAD_ACP_SEID_ERROR)
if endpoint.stream is None: if endpoint.stream is None:
return Reconfigure_Reject(error_code=AVDTP_BAD_STATE_ERROR) return Reconfigure_Reject(error_code=AVDTP_BAD_STATE_ERROR)
result = endpoint.stream.on_reconfigure_command(command.capabilities) result = await endpoint.stream.on_reconfigure_command(command.capabilities)
return result or Reconfigure_Response() return result or Reconfigure_Response()
def on_open_command(self, command: Open_Command) -> Message | None: async def on_open_command(self, command: Open_Command) -> Message | None:
endpoint = self.get_local_endpoint_by_seid(command.acp_seid) endpoint = self.get_local_endpoint_by_seid(command.acp_seid)
if endpoint is None: if endpoint is None:
return Open_Reject(AVDTP_BAD_ACP_SEID_ERROR) return Open_Reject(AVDTP_BAD_ACP_SEID_ERROR)
if endpoint.stream is None: if endpoint.stream is None:
return Open_Reject(AVDTP_BAD_STATE_ERROR) return Open_Reject(AVDTP_BAD_STATE_ERROR)
result = endpoint.stream.on_open_command() result = await endpoint.stream.on_open_command()
return result or Open_Response() return result or Open_Response()
def on_start_command(self, command: Start_Command) -> Message | None: async def on_start_command(self, command: Start_Command) -> Message | None:
for seid in command.acp_seids: for seid in command.acp_seids:
endpoint = self.get_local_endpoint_by_seid(seid) endpoint = self.get_local_endpoint_by_seid(seid)
if endpoint is None: if endpoint is None:
@@ -1712,12 +1729,12 @@ class Protocol(utils.EventEmitter):
endpoint = self.get_local_endpoint_by_seid(seid) endpoint = self.get_local_endpoint_by_seid(seid)
if not endpoint or not endpoint.stream: if not endpoint or not endpoint.stream:
raise InvalidStateError("Should already be checked!") raise InvalidStateError("Should already be checked!")
if (result := endpoint.stream.on_start_command()) is not None: if (result := await endpoint.stream.on_start_command()) is not None:
return result return result
return Start_Response() return Start_Response()
def on_suspend_command(self, command: Suspend_Command) -> Message | None: async def on_suspend_command(self, command: Suspend_Command) -> Message | None:
for seid in command.acp_seids: for seid in command.acp_seids:
endpoint = self.get_local_endpoint_by_seid(seid) endpoint = self.get_local_endpoint_by_seid(seid)
if endpoint is None: if endpoint is None:
@@ -1731,45 +1748,47 @@ class Protocol(utils.EventEmitter):
endpoint = self.get_local_endpoint_by_seid(seid) endpoint = self.get_local_endpoint_by_seid(seid)
if not endpoint or not endpoint.stream: if not endpoint or not endpoint.stream:
raise InvalidStateError("Should already be checked!") raise InvalidStateError("Should already be checked!")
if (result := endpoint.stream.on_suspend_command()) is not None: if (result := await endpoint.stream.on_suspend_command()) is not None:
return result return result
return Suspend_Response() return Suspend_Response()
def on_close_command(self, command: Close_Command) -> Message | None: async def on_close_command(self, command: Close_Command) -> Message | None:
endpoint = self.get_local_endpoint_by_seid(command.acp_seid) endpoint = self.get_local_endpoint_by_seid(command.acp_seid)
if endpoint is None: if endpoint is None:
return Close_Reject(AVDTP_BAD_ACP_SEID_ERROR) return Close_Reject(AVDTP_BAD_ACP_SEID_ERROR)
if endpoint.stream is None: if endpoint.stream is None:
return Close_Reject(AVDTP_BAD_STATE_ERROR) return Close_Reject(AVDTP_BAD_STATE_ERROR)
result = endpoint.stream.on_close_command() result = await endpoint.stream.on_close_command()
return result or Close_Response() return result or Close_Response()
def on_abort_command(self, command: Abort_Command) -> Message | None: async def on_abort_command(self, command: Abort_Command) -> Message | None:
endpoint = self.get_local_endpoint_by_seid(command.acp_seid) endpoint = self.get_local_endpoint_by_seid(command.acp_seid)
if endpoint is None or endpoint.stream is None: if endpoint is None or endpoint.stream is None:
return Abort_Response() return Abort_Response()
endpoint.stream.on_abort_command() await endpoint.stream.on_abort_command()
return Abort_Response() return Abort_Response()
def on_security_control_command( async def on_security_control_command(
self, command: Security_Control_Command self, command: Security_Control_Command
) -> Message | None: ) -> Message | None:
endpoint = self.get_local_endpoint_by_seid(command.acp_seid) endpoint = self.get_local_endpoint_by_seid(command.acp_seid)
if endpoint is None: if endpoint is None:
return Security_Control_Reject(AVDTP_BAD_ACP_SEID_ERROR) return Security_Control_Reject(AVDTP_BAD_ACP_SEID_ERROR)
result = endpoint.on_security_control_command(command.data) result = await endpoint.on_security_control_command(command.data)
return result or Security_Control_Response() return result or Security_Control_Response()
def on_delayreport_command(self, command: DelayReport_Command) -> Message | None: async def on_delayreport_command(
self, command: DelayReport_Command
) -> Message | None:
endpoint = self.get_local_endpoint_by_seid(command.acp_seid) endpoint = self.get_local_endpoint_by_seid(command.acp_seid)
if endpoint is None: if endpoint is None:
return DelayReport_Reject(AVDTP_BAD_ACP_SEID_ERROR) return DelayReport_Reject(AVDTP_BAD_ACP_SEID_ERROR)
result = endpoint.on_delayreport_command(command.delay) result = await endpoint.on_delayreport_command(command.delay)
return result or DelayReport_Response() return result or DelayReport_Response()
@@ -1932,20 +1951,20 @@ class Stream:
self.change_state(State.IDLE) self.change_state(State.IDLE)
def on_set_configuration_command( async def on_set_configuration_command(
self, configuration: Iterable[ServiceCapabilities] self, configuration: Iterable[ServiceCapabilities]
) -> Message | None: ) -> Message | None:
if self.state != State.IDLE: if self.state != State.IDLE:
return Set_Configuration_Reject(error_code=AVDTP_BAD_STATE_ERROR) return Set_Configuration_Reject(error_code=AVDTP_BAD_STATE_ERROR)
result = self.local_endpoint.on_set_configuration_command(configuration) result = await self.local_endpoint.on_set_configuration_command(configuration)
if result is not None: if result is not None:
return result return result
self.change_state(State.CONFIGURED) self.change_state(State.CONFIGURED)
return None return None
def on_get_configuration_command(self) -> Message | None: async def on_get_configuration_command(self) -> Message | None:
if self.state not in ( if self.state not in (
State.CONFIGURED, State.CONFIGURED,
State.OPEN, State.OPEN,
@@ -1953,25 +1972,25 @@ class Stream:
): ):
return Get_Configuration_Reject(error_code=AVDTP_BAD_STATE_ERROR) return Get_Configuration_Reject(error_code=AVDTP_BAD_STATE_ERROR)
return self.local_endpoint.on_get_configuration_command() return await self.local_endpoint.on_get_configuration_command()
def on_reconfigure_command( async def on_reconfigure_command(
self, configuration: Iterable[ServiceCapabilities] self, configuration: Iterable[ServiceCapabilities]
) -> Message | None: ) -> Message | None:
if self.state != State.OPEN: if self.state != State.OPEN:
return Reconfigure_Reject(error_code=AVDTP_BAD_STATE_ERROR) return Reconfigure_Reject(error_code=AVDTP_BAD_STATE_ERROR)
result = self.local_endpoint.on_reconfigure_command(configuration) result = await self.local_endpoint.on_reconfigure_command(configuration)
if result is not None: if result is not None:
return result return result
return None return None
def on_open_command(self) -> Message | None: async def on_open_command(self) -> Message | None:
if self.state != State.CONFIGURED: if self.state != State.CONFIGURED:
return Open_Reject(AVDTP_BAD_STATE_ERROR) return Open_Reject(AVDTP_BAD_STATE_ERROR)
result = self.local_endpoint.on_open_command() result = await self.local_endpoint.on_open_command()
if result is not None: if result is not None:
return result return result
@@ -1981,7 +2000,7 @@ class Stream:
self.change_state(State.OPEN) self.change_state(State.OPEN)
return None return None
def on_start_command(self) -> Message | None: async def on_start_command(self) -> Message | None:
if self.state != State.OPEN: if self.state != State.OPEN:
return Open_Reject(AVDTP_BAD_STATE_ERROR) return Open_Reject(AVDTP_BAD_STATE_ERROR)
@@ -1990,29 +2009,29 @@ class Stream:
logger.warning('received start command before RTP channel establishment') logger.warning('received start command before RTP channel establishment')
return Open_Reject(AVDTP_BAD_STATE_ERROR) return Open_Reject(AVDTP_BAD_STATE_ERROR)
result = self.local_endpoint.on_start_command() result = await self.local_endpoint.on_start_command()
if result is not None: if result is not None:
return result return result
self.change_state(State.STREAMING) self.change_state(State.STREAMING)
return None return None
def on_suspend_command(self) -> Message | None: async def on_suspend_command(self) -> Message | None:
if self.state != State.STREAMING: if self.state != State.STREAMING:
return Open_Reject(AVDTP_BAD_STATE_ERROR) return Open_Reject(AVDTP_BAD_STATE_ERROR)
result = self.local_endpoint.on_suspend_command() result = await self.local_endpoint.on_suspend_command()
if result is not None: if result is not None:
return result return result
self.change_state(State.OPEN) self.change_state(State.OPEN)
return None return None
def on_close_command(self) -> Message | None: async def on_close_command(self) -> Message | None:
if self.state not in (State.OPEN, State.STREAMING): if self.state not in (State.OPEN, State.STREAMING):
return Open_Reject(AVDTP_BAD_STATE_ERROR) return Open_Reject(AVDTP_BAD_STATE_ERROR)
result = self.local_endpoint.on_close_command() result = await self.local_endpoint.on_close_command()
if result is not None: if result is not None:
return result return result
@@ -2027,7 +2046,8 @@ class Stream:
return None return None
def on_abort_command(self) -> Message | None: async def on_abort_command(self) -> Message | None:
await self.local_endpoint.on_abort_command()
if self.rtp_channel is None: if self.rtp_channel is None:
# No need to wait # No need to wait
self.change_state(State.IDLE) self.change_state(State.IDLE)
@@ -2179,13 +2199,13 @@ class LocalStreamEndPoint(StreamEndPoint, utils.EventEmitter):
async def close(self) -> None: async def close(self) -> None:
"""[Source Only] Handles when receiving close command.""" """[Source Only] Handles when receiving close command."""
def on_reconfigure_command( async def on_reconfigure_command(
self, command: Iterable[ServiceCapabilities] self, command: Iterable[ServiceCapabilities]
) -> Message | None: ) -> Message | None:
del command # unused. del command # unused.
return None return None
def on_set_configuration_command( async def on_set_configuration_command(
self, configuration: Iterable[ServiceCapabilities] self, configuration: Iterable[ServiceCapabilities]
) -> Message | None: ) -> Message | None:
logger.debug( logger.debug(
@@ -2196,34 +2216,34 @@ class LocalStreamEndPoint(StreamEndPoint, utils.EventEmitter):
self.emit(self.EVENT_CONFIGURATION) self.emit(self.EVENT_CONFIGURATION)
return None return None
def on_get_configuration_command(self) -> Message | None: async def on_get_configuration_command(self) -> Message | None:
return Get_Configuration_Response(self.configuration) return Get_Configuration_Response(self.configuration)
def on_open_command(self) -> Message | None: async def on_open_command(self) -> Message | None:
self.emit(self.EVENT_OPEN) self.emit(self.EVENT_OPEN)
return None return None
def on_start_command(self) -> Message | None: async def on_start_command(self) -> Message | None:
self.emit(self.EVENT_START) self.emit(self.EVENT_START)
return None return None
def on_suspend_command(self) -> Message | None: async def on_suspend_command(self) -> Message | None:
self.emit(self.EVENT_SUSPEND) self.emit(self.EVENT_SUSPEND)
return None return None
def on_close_command(self) -> Message | None: async def on_close_command(self) -> Message | None:
self.emit(self.EVENT_CLOSE) self.emit(self.EVENT_CLOSE)
return None return None
def on_abort_command(self) -> Message | None: async def on_abort_command(self) -> Message | None:
self.emit(self.EVENT_ABORT) self.emit(self.EVENT_ABORT)
return None return None
def on_delayreport_command(self, delay: int) -> Message | None: async def on_delayreport_command(self, delay: int) -> Message | None:
self.emit(self.EVENT_DELAY_REPORT, delay) self.emit(self.EVENT_DELAY_REPORT, delay)
return None return None
def on_security_control_command(self, data: bytes) -> Message | None: async def on_security_control_command(self, data: bytes) -> Message | None:
self.emit(self.EVENT_SECURITY_CONTROL, data) self.emit(self.EVENT_SECURITY_CONTROL, data)
return None return None
@@ -2275,13 +2295,13 @@ class LocalSource(LocalStreamEndPoint):
self.emit(self.EVENT_STOP) self.emit(self.EVENT_STOP)
@override @override
def on_start_command(self) -> Message | None: async def on_start_command(self) -> Message | None:
asyncio.create_task(self.start()) await self.start()
return None return None
@override @override
def on_suspend_command(self) -> Message | None: async def on_suspend_command(self) -> Message | None:
asyncio.create_task(self.stop()) await self.stop()
return None return None