diff --git a/bumble/avdtp.py b/bumble/avdtp.py index 9e9b836..2854593 100644 --- a/bumble/avdtp.py +++ b/bumble/avdtp.py @@ -1477,8 +1477,23 @@ class Protocol(utils.EventEmitter): handler = getattr(self, handler_name, None) if handler: try: - response = handler(message) - self.send_message(transaction_label, response) + result = handler(message) + if asyncio.iscoroutine(result): + + async def wait_and_send() -> None: + try: + response = await result + if response: + self.send_message(transaction_label, response) + except Exception: + logger.exception( + color("!!! Exception in handler:", "red") + ) + + utils.cancel_on_event(self, self.EVENT_CLOSE, wait_and_send()) + else: + if result: + self.send_message(transaction_label, result) except Exception: logger.exception(color("!!! Exception in handler:", "red")) else: @@ -1559,7 +1574,7 @@ class Protocol(utils.EventEmitter): async def send_command(self, command: Message): # TODO: support timeouts # Send the command - (transaction_label, transaction_result) = await self.start_transaction() + transaction_label, transaction_result = await self.start_transaction() self.send_message(transaction_label, command) # Wait for the response @@ -1624,14 +1639,14 @@ class Protocol(utils.EventEmitter): async def abort(self, seid: int) -> Abort_Response: return await self.send_command(Abort_Command(seid)) - def on_discover_command(self, command: Discover_Command) -> Message | None: + async def on_discover_command(self, command: Discover_Command) -> Message | None: endpoint_infos = [ EndPointInfo(endpoint.seid, 0, endpoint.media_type, endpoint.tsep) for endpoint in self.local_endpoints ] return Discover_Response(endpoint_infos) - def on_get_capabilities_command( + async def on_get_capabilities_command( self, command: Get_Capabilities_Command ) -> Message | None: endpoint = self.get_local_endpoint_by_seid(command.acp_seid) @@ -1640,7 +1655,7 @@ class Protocol(utils.EventEmitter): return Get_Capabilities_Response(endpoint.capabilities) - def on_get_all_capabilities_command( + async def on_get_all_capabilities_command( self, command: Get_All_Capabilities_Command ) -> Message | None: endpoint = self.get_local_endpoint_by_seid(command.acp_seid) @@ -1649,7 +1664,7 @@ class Protocol(utils.EventEmitter): return Get_All_Capabilities_Response(endpoint.capabilities) - def on_set_configuration_command( + async def on_set_configuration_command( self, command: Set_Configuration_Command ) -> Message | None: endpoint = self.get_local_endpoint_by_seid(command.acp_seid) @@ -1664,10 +1679,10 @@ class Protocol(utils.EventEmitter): stream = Stream(self, endpoint, StreamEndPointProxy(self, command.int_seid)) self.streams[command.acp_seid] = stream - result = stream.on_set_configuration_command(command.capabilities) + result = await stream.on_set_configuration_command(command.capabilities) return result or Set_Configuration_Response() - def on_get_configuration_command( + async def on_get_configuration_command( self, command: Get_Configuration_Command ) -> Message | None: endpoint = self.get_local_endpoint_by_seid(command.acp_seid) @@ -1676,29 +1691,31 @@ class Protocol(utils.EventEmitter): if endpoint.stream is None: return Get_Configuration_Reject(AVDTP_BAD_STATE_ERROR) - return endpoint.stream.on_get_configuration_command() + return await endpoint.stream.on_get_configuration_command() - def on_reconfigure_command(self, command: Reconfigure_Command) -> Message | None: + async def on_reconfigure_command( + self, command: Reconfigure_Command + ) -> Message | None: endpoint = self.get_local_endpoint_by_seid(command.acp_seid) if endpoint is None: return Reconfigure_Reject(error_code=AVDTP_BAD_ACP_SEID_ERROR) if endpoint.stream is None: return Reconfigure_Reject(error_code=AVDTP_BAD_STATE_ERROR) - result = endpoint.stream.on_reconfigure_command(command.capabilities) + result = await endpoint.stream.on_reconfigure_command(command.capabilities) return result or Reconfigure_Response() - def on_open_command(self, command: Open_Command) -> Message | None: + async def on_open_command(self, command: Open_Command) -> Message | None: endpoint = self.get_local_endpoint_by_seid(command.acp_seid) if endpoint is None: return Open_Reject(AVDTP_BAD_ACP_SEID_ERROR) if endpoint.stream is None: return Open_Reject(AVDTP_BAD_STATE_ERROR) - result = endpoint.stream.on_open_command() + result = await endpoint.stream.on_open_command() return result or Open_Response() - def on_start_command(self, command: Start_Command) -> Message | None: + async def on_start_command(self, command: Start_Command) -> Message | None: for seid in command.acp_seids: endpoint = self.get_local_endpoint_by_seid(seid) if endpoint is None: @@ -1712,12 +1729,12 @@ class Protocol(utils.EventEmitter): endpoint = self.get_local_endpoint_by_seid(seid) if not endpoint or not endpoint.stream: raise InvalidStateError("Should already be checked!") - if (result := endpoint.stream.on_start_command()) is not None: + if (result := await endpoint.stream.on_start_command()) is not None: return result return Start_Response() - def on_suspend_command(self, command: Suspend_Command) -> Message | None: + async def on_suspend_command(self, command: Suspend_Command) -> Message | None: for seid in command.acp_seids: endpoint = self.get_local_endpoint_by_seid(seid) if endpoint is None: @@ -1731,45 +1748,47 @@ class Protocol(utils.EventEmitter): endpoint = self.get_local_endpoint_by_seid(seid) if not endpoint or not endpoint.stream: raise InvalidStateError("Should already be checked!") - if (result := endpoint.stream.on_suspend_command()) is not None: + if (result := await endpoint.stream.on_suspend_command()) is not None: return result return Suspend_Response() - def on_close_command(self, command: Close_Command) -> Message | None: + async def on_close_command(self, command: Close_Command) -> Message | None: endpoint = self.get_local_endpoint_by_seid(command.acp_seid) if endpoint is None: return Close_Reject(AVDTP_BAD_ACP_SEID_ERROR) if endpoint.stream is None: return Close_Reject(AVDTP_BAD_STATE_ERROR) - result = endpoint.stream.on_close_command() + result = await endpoint.stream.on_close_command() return result or Close_Response() - def on_abort_command(self, command: Abort_Command) -> Message | None: + async def on_abort_command(self, command: Abort_Command) -> Message | None: endpoint = self.get_local_endpoint_by_seid(command.acp_seid) if endpoint is None or endpoint.stream is None: return Abort_Response() - endpoint.stream.on_abort_command() + await endpoint.stream.on_abort_command() return Abort_Response() - def on_security_control_command( + async def on_security_control_command( self, command: Security_Control_Command ) -> Message | None: endpoint = self.get_local_endpoint_by_seid(command.acp_seid) if endpoint is None: return Security_Control_Reject(AVDTP_BAD_ACP_SEID_ERROR) - result = endpoint.on_security_control_command(command.data) + result = await endpoint.on_security_control_command(command.data) return result or Security_Control_Response() - def on_delayreport_command(self, command: DelayReport_Command) -> Message | None: + async def on_delayreport_command( + self, command: DelayReport_Command + ) -> Message | None: endpoint = self.get_local_endpoint_by_seid(command.acp_seid) if endpoint is None: return DelayReport_Reject(AVDTP_BAD_ACP_SEID_ERROR) - result = endpoint.on_delayreport_command(command.delay) + result = await endpoint.on_delayreport_command(command.delay) return result or DelayReport_Response() @@ -1932,20 +1951,20 @@ class Stream: self.change_state(State.IDLE) - def on_set_configuration_command( + async def on_set_configuration_command( self, configuration: Iterable[ServiceCapabilities] ) -> Message | None: if self.state != State.IDLE: return Set_Configuration_Reject(error_code=AVDTP_BAD_STATE_ERROR) - result = self.local_endpoint.on_set_configuration_command(configuration) + result = await self.local_endpoint.on_set_configuration_command(configuration) if result is not None: return result self.change_state(State.CONFIGURED) return None - def on_get_configuration_command(self) -> Message | None: + async def on_get_configuration_command(self) -> Message | None: if self.state not in ( State.CONFIGURED, State.OPEN, @@ -1953,25 +1972,25 @@ class Stream: ): return Get_Configuration_Reject(error_code=AVDTP_BAD_STATE_ERROR) - return self.local_endpoint.on_get_configuration_command() + return await self.local_endpoint.on_get_configuration_command() - def on_reconfigure_command( + async def on_reconfigure_command( self, configuration: Iterable[ServiceCapabilities] ) -> Message | None: if self.state != State.OPEN: return Reconfigure_Reject(error_code=AVDTP_BAD_STATE_ERROR) - result = self.local_endpoint.on_reconfigure_command(configuration) + result = await self.local_endpoint.on_reconfigure_command(configuration) if result is not None: return result return None - def on_open_command(self) -> Message | None: + async def on_open_command(self) -> Message | None: if self.state != State.CONFIGURED: return Open_Reject(AVDTP_BAD_STATE_ERROR) - result = self.local_endpoint.on_open_command() + result = await self.local_endpoint.on_open_command() if result is not None: return result @@ -1981,7 +2000,7 @@ class Stream: self.change_state(State.OPEN) return None - def on_start_command(self) -> Message | None: + async def on_start_command(self) -> Message | None: if self.state != State.OPEN: return Open_Reject(AVDTP_BAD_STATE_ERROR) @@ -1990,29 +2009,29 @@ class Stream: logger.warning('received start command before RTP channel establishment') return Open_Reject(AVDTP_BAD_STATE_ERROR) - result = self.local_endpoint.on_start_command() + result = await self.local_endpoint.on_start_command() if result is not None: return result self.change_state(State.STREAMING) return None - def on_suspend_command(self) -> Message | None: + async def on_suspend_command(self) -> Message | None: if self.state != State.STREAMING: return Open_Reject(AVDTP_BAD_STATE_ERROR) - result = self.local_endpoint.on_suspend_command() + result = await self.local_endpoint.on_suspend_command() if result is not None: return result self.change_state(State.OPEN) return None - def on_close_command(self) -> Message | None: + async def on_close_command(self) -> Message | None: if self.state not in (State.OPEN, State.STREAMING): return Open_Reject(AVDTP_BAD_STATE_ERROR) - result = self.local_endpoint.on_close_command() + result = await self.local_endpoint.on_close_command() if result is not None: return result @@ -2027,7 +2046,8 @@ class Stream: return None - def on_abort_command(self) -> Message | None: + async def on_abort_command(self) -> Message | None: + await self.local_endpoint.on_abort_command() if self.rtp_channel is None: # No need to wait self.change_state(State.IDLE) @@ -2179,13 +2199,13 @@ class LocalStreamEndPoint(StreamEndPoint, utils.EventEmitter): async def close(self) -> None: """[Source Only] Handles when receiving close command.""" - def on_reconfigure_command( + async def on_reconfigure_command( self, command: Iterable[ServiceCapabilities] ) -> Message | None: del command # unused. return None - def on_set_configuration_command( + async def on_set_configuration_command( self, configuration: Iterable[ServiceCapabilities] ) -> Message | None: logger.debug( @@ -2196,34 +2216,34 @@ class LocalStreamEndPoint(StreamEndPoint, utils.EventEmitter): self.emit(self.EVENT_CONFIGURATION) return None - def on_get_configuration_command(self) -> Message | None: + async def on_get_configuration_command(self) -> Message | None: return Get_Configuration_Response(self.configuration) - def on_open_command(self) -> Message | None: + async def on_open_command(self) -> Message | None: self.emit(self.EVENT_OPEN) return None - def on_start_command(self) -> Message | None: + async def on_start_command(self) -> Message | None: self.emit(self.EVENT_START) return None - def on_suspend_command(self) -> Message | None: + async def on_suspend_command(self) -> Message | None: self.emit(self.EVENT_SUSPEND) return None - def on_close_command(self) -> Message | None: + async def on_close_command(self) -> Message | None: self.emit(self.EVENT_CLOSE) return None - def on_abort_command(self) -> Message | None: + async def on_abort_command(self) -> Message | None: self.emit(self.EVENT_ABORT) return None - def on_delayreport_command(self, delay: int) -> Message | None: + async def on_delayreport_command(self, delay: int) -> Message | None: self.emit(self.EVENT_DELAY_REPORT, delay) return None - def on_security_control_command(self, data: bytes) -> Message | None: + async def on_security_control_command(self, data: bytes) -> Message | None: self.emit(self.EVENT_SECURITY_CONTROL, data) return None @@ -2275,13 +2295,13 @@ class LocalSource(LocalStreamEndPoint): self.emit(self.EVENT_STOP) @override - def on_start_command(self) -> Message | None: - asyncio.create_task(self.start()) + async def on_start_command(self) -> Message | None: + await self.start() return None @override - def on_suspend_command(self) -> Message | None: - asyncio.create_task(self.stop()) + async def on_suspend_command(self) -> Message | None: + await self.stop() return None