Merge pull request #288 from zxzxwu/l2cap_states

L2CAP: Refactor states to enums
This commit is contained in:
zxzxwu
2023-09-21 15:42:21 +08:00
committed by GitHub

View File

@@ -17,6 +17,7 @@
# -----------------------------------------------------------------------------
from __future__ import annotations
import asyncio
import enum
import logging
import struct
@@ -676,56 +677,35 @@ class L2CAP_LE_Flow_Control_Credit(L2CAP_Control_Frame):
# -----------------------------------------------------------------------------
class Channel(EventEmitter):
# States
CLOSED = 0x00
WAIT_CONNECT = 0x01
WAIT_CONNECT_RSP = 0x02
OPEN = 0x03
WAIT_DISCONNECT = 0x04
WAIT_CREATE = 0x05
WAIT_CREATE_RSP = 0x06
WAIT_MOVE = 0x07
WAIT_MOVE_RSP = 0x08
WAIT_MOVE_CONFIRM = 0x09
WAIT_CONFIRM_RSP = 0x0A
class State(enum.IntEnum):
# States
CLOSED = 0x00
WAIT_CONNECT = 0x01
WAIT_CONNECT_RSP = 0x02
OPEN = 0x03
WAIT_DISCONNECT = 0x04
WAIT_CREATE = 0x05
WAIT_CREATE_RSP = 0x06
WAIT_MOVE = 0x07
WAIT_MOVE_RSP = 0x08
WAIT_MOVE_CONFIRM = 0x09
WAIT_CONFIRM_RSP = 0x0A
# CONFIG substates
WAIT_CONFIG = 0x10
WAIT_SEND_CONFIG = 0x11
WAIT_CONFIG_REQ_RSP = 0x12
WAIT_CONFIG_RSP = 0x13
WAIT_CONFIG_REQ = 0x14
WAIT_IND_FINAL_RSP = 0x15
WAIT_FINAL_RSP = 0x16
WAIT_CONTROL_IND = 0x17
STATE_NAMES = {
CLOSED: 'CLOSED',
WAIT_CONNECT: 'WAIT_CONNECT',
WAIT_CONNECT_RSP: 'WAIT_CONNECT_RSP',
OPEN: 'OPEN',
WAIT_DISCONNECT: 'WAIT_DISCONNECT',
WAIT_CREATE: 'WAIT_CREATE',
WAIT_CREATE_RSP: 'WAIT_CREATE_RSP',
WAIT_MOVE: 'WAIT_MOVE',
WAIT_MOVE_RSP: 'WAIT_MOVE_RSP',
WAIT_MOVE_CONFIRM: 'WAIT_MOVE_CONFIRM',
WAIT_CONFIRM_RSP: 'WAIT_CONFIRM_RSP',
WAIT_CONFIG: 'WAIT_CONFIG',
WAIT_SEND_CONFIG: 'WAIT_SEND_CONFIG',
WAIT_CONFIG_REQ_RSP: 'WAIT_CONFIG_REQ_RSP',
WAIT_CONFIG_RSP: 'WAIT_CONFIG_RSP',
WAIT_CONFIG_REQ: 'WAIT_CONFIG_REQ',
WAIT_IND_FINAL_RSP: 'WAIT_IND_FINAL_RSP',
WAIT_FINAL_RSP: 'WAIT_FINAL_RSP',
WAIT_CONTROL_IND: 'WAIT_CONTROL_IND',
}
# CONFIG substates
WAIT_CONFIG = 0x10
WAIT_SEND_CONFIG = 0x11
WAIT_CONFIG_REQ_RSP = 0x12
WAIT_CONFIG_RSP = 0x13
WAIT_CONFIG_REQ = 0x14
WAIT_IND_FINAL_RSP = 0x15
WAIT_FINAL_RSP = 0x16
WAIT_CONTROL_IND = 0x17
connection_result: Optional[asyncio.Future[None]]
disconnection_result: Optional[asyncio.Future[None]]
response: Optional[asyncio.Future[bytes]]
sink: Optional[Callable[[bytes], Any]]
state: int
state: State
connection: Connection
def __init__(
@@ -741,7 +721,7 @@ class Channel(EventEmitter):
self.manager = manager
self.connection = connection
self.signaling_cid = signaling_cid
self.state = Channel.CLOSED
self.state = self.State.CLOSED
self.mtu = mtu
self.psm = psm
self.source_cid = source_cid
@@ -751,10 +731,8 @@ class Channel(EventEmitter):
self.disconnection_result = None
self.sink = None
def change_state(self, new_state: int) -> None:
logger.debug(
f'{self} state change -> {color(Channel.STATE_NAMES[new_state], "cyan")}'
)
def _change_state(self, new_state: State) -> None:
logger.debug(f'{self} state change -> {color(new_state.name, "cyan")}')
self.state = new_state
def send_pdu(self, pdu: Union[SupportsBytes, bytes]) -> None:
@@ -767,7 +745,7 @@ class Channel(EventEmitter):
# Check that there isn't already a request pending
if self.response:
raise InvalidStateError('request already pending')
if self.state != Channel.OPEN:
if self.state != self.State.OPEN:
raise InvalidStateError('channel not open')
self.response = asyncio.get_running_loop().create_future()
@@ -787,14 +765,14 @@ class Channel(EventEmitter):
)
async def connect(self) -> None:
if self.state != Channel.CLOSED:
if self.state != self.State.CLOSED:
raise InvalidStateError('invalid state')
# Check that we can start a new connection
if self.connection_result:
raise RuntimeError('connection already pending')
self.change_state(Channel.WAIT_CONNECT_RSP)
self._change_state(self.State.WAIT_CONNECT_RSP)
self.send_control_frame(
L2CAP_Connection_Request(
identifier=self.manager.next_identifier(self.connection),
@@ -814,10 +792,10 @@ class Channel(EventEmitter):
self.connection_result = None
async def disconnect(self) -> None:
if self.state != Channel.OPEN:
if self.state != self.State.OPEN:
raise InvalidStateError('invalid state')
self.change_state(Channel.WAIT_DISCONNECT)
self._change_state(self.State.WAIT_DISCONNECT)
self.send_control_frame(
L2CAP_Disconnection_Request(
identifier=self.manager.next_identifier(self.connection),
@@ -832,8 +810,8 @@ class Channel(EventEmitter):
return await self.disconnection_result
def abort(self) -> None:
if self.state == self.OPEN:
self.change_state(self.CLOSED)
if self.state == self.State.OPEN:
self._change_state(self.State.CLOSED)
self.emit('close')
def send_configure_request(self) -> None:
@@ -856,7 +834,7 @@ class Channel(EventEmitter):
def on_connection_request(self, request) -> None:
self.destination_cid = request.source_cid
self.change_state(Channel.WAIT_CONNECT)
self._change_state(self.State.WAIT_CONNECT)
self.send_control_frame(
L2CAP_Connection_Response(
identifier=request.identifier,
@@ -866,24 +844,24 @@ class Channel(EventEmitter):
status=0x0000,
)
)
self.change_state(Channel.WAIT_CONFIG)
self._change_state(self.State.WAIT_CONFIG)
self.send_configure_request()
self.change_state(Channel.WAIT_CONFIG_REQ_RSP)
self._change_state(self.State.WAIT_CONFIG_REQ_RSP)
def on_connection_response(self, response):
if self.state != Channel.WAIT_CONNECT_RSP:
if self.state != self.State.WAIT_CONNECT_RSP:
logger.warning(color('invalid state', 'red'))
return
if response.result == L2CAP_Connection_Response.CONNECTION_SUCCESSFUL:
self.destination_cid = response.destination_cid
self.change_state(Channel.WAIT_CONFIG)
self._change_state(self.State.WAIT_CONFIG)
self.send_configure_request()
self.change_state(Channel.WAIT_CONFIG_REQ_RSP)
self._change_state(self.State.WAIT_CONFIG_REQ_RSP)
elif response.result == L2CAP_Connection_Response.CONNECTION_PENDING:
pass
else:
self.change_state(Channel.CLOSED)
self._change_state(self.State.CLOSED)
self.connection_result.set_exception(
ProtocolError(
response.result,
@@ -895,9 +873,9 @@ class Channel(EventEmitter):
def on_configure_request(self, request) -> None:
if self.state not in (
Channel.WAIT_CONFIG,
Channel.WAIT_CONFIG_REQ,
Channel.WAIT_CONFIG_REQ_RSP,
self.State.WAIT_CONFIG,
self.State.WAIT_CONFIG_REQ,
self.State.WAIT_CONFIG_REQ_RSP,
):
logger.warning(color('invalid state', 'red'))
return
@@ -918,25 +896,28 @@ class Channel(EventEmitter):
options=request.options, # TODO: don't accept everything blindly
)
)
if self.state == Channel.WAIT_CONFIG:
self.change_state(Channel.WAIT_SEND_CONFIG)
if self.state == self.State.WAIT_CONFIG:
self._change_state(self.State.WAIT_SEND_CONFIG)
self.send_configure_request()
self.change_state(Channel.WAIT_CONFIG_RSP)
elif self.state == Channel.WAIT_CONFIG_REQ:
self.change_state(Channel.OPEN)
self._change_state(self.State.WAIT_CONFIG_RSP)
elif self.state == self.State.WAIT_CONFIG_REQ:
self._change_state(self.State.OPEN)
if self.connection_result:
self.connection_result.set_result(None)
self.connection_result = None
self.emit('open')
elif self.state == Channel.WAIT_CONFIG_REQ_RSP:
self.change_state(Channel.WAIT_CONFIG_RSP)
elif self.state == self.State.WAIT_CONFIG_REQ_RSP:
self._change_state(self.State.WAIT_CONFIG_RSP)
def on_configure_response(self, response) -> None:
if response.result == L2CAP_Configure_Response.SUCCESS:
if self.state == Channel.WAIT_CONFIG_REQ_RSP:
self.change_state(Channel.WAIT_CONFIG_REQ)
elif self.state in (Channel.WAIT_CONFIG_RSP, Channel.WAIT_CONTROL_IND):
self.change_state(Channel.OPEN)
if self.state == self.State.WAIT_CONFIG_REQ_RSP:
self._change_state(self.State.WAIT_CONFIG_REQ)
elif self.state in (
self.State.WAIT_CONFIG_RSP,
self.State.WAIT_CONTROL_IND,
):
self._change_state(self.State.OPEN)
if self.connection_result:
self.connection_result.set_result(None)
self.connection_result = None
@@ -966,7 +947,7 @@ class Channel(EventEmitter):
# TODO: decide how to fail gracefully
def on_disconnection_request(self, request) -> None:
if self.state in (Channel.OPEN, Channel.WAIT_DISCONNECT):
if self.state in (self.State.OPEN, self.State.WAIT_DISCONNECT):
self.send_control_frame(
L2CAP_Disconnection_Response(
identifier=request.identifier,
@@ -974,14 +955,14 @@ class Channel(EventEmitter):
source_cid=request.source_cid,
)
)
self.change_state(Channel.CLOSED)
self._change_state(self.State.CLOSED)
self.emit('close')
self.manager.on_channel_closed(self)
else:
logger.warning(color('invalid state', 'red'))
def on_disconnection_response(self, response) -> None:
if self.state != Channel.WAIT_DISCONNECT:
if self.state != self.State.WAIT_DISCONNECT:
logger.warning(color('invalid state', 'red'))
return
@@ -992,7 +973,7 @@ class Channel(EventEmitter):
logger.warning('unexpected source or destination CID')
return
self.change_state(Channel.CLOSED)
self._change_state(self.State.CLOSED)
if self.disconnection_result:
self.disconnection_result.set_result(None)
self.disconnection_result = None
@@ -1004,7 +985,7 @@ class Channel(EventEmitter):
f'Channel({self.source_cid}->{self.destination_cid}, '
f'PSM={self.psm}, '
f'MTU={self.mtu}, '
f'state={Channel.STATE_NAMES[self.state]})'
f'state={self.state.name})'
)
@@ -1014,33 +995,21 @@ class LeConnectionOrientedChannel(EventEmitter):
LE Credit-based Connection Oriented Channel
"""
INIT = 0
CONNECTED = 1
CONNECTING = 2
DISCONNECTING = 3
DISCONNECTED = 4
CONNECTION_ERROR = 5
STATE_NAMES = {
INIT: 'INIT',
CONNECTED: 'CONNECTED',
CONNECTING: 'CONNECTING',
DISCONNECTING: 'DISCONNECTING',
DISCONNECTED: 'DISCONNECTED',
CONNECTION_ERROR: 'CONNECTION_ERROR',
}
class State(enum.IntEnum):
INIT = 0
CONNECTED = 1
CONNECTING = 2
DISCONNECTING = 3
DISCONNECTED = 4
CONNECTION_ERROR = 5
out_queue: Deque[bytes]
connection_result: Optional[asyncio.Future[LeConnectionOrientedChannel]]
disconnection_result: Optional[asyncio.Future[None]]
out_sdu: Optional[bytes]
state: int
state: State
connection: Connection
@staticmethod
def state_name(state: int) -> str:
return name_or_number(LeConnectionOrientedChannel.STATE_NAMES, state)
def __init__(
self,
manager: ChannelManager,
@@ -1083,19 +1052,17 @@ class LeConnectionOrientedChannel(EventEmitter):
self.drained.set()
if connected:
self.state = LeConnectionOrientedChannel.CONNECTED
self.state = self.State.CONNECTED
else:
self.state = LeConnectionOrientedChannel.INIT
self.state = self.State.INIT
def change_state(self, new_state: int) -> None:
logger.debug(
f'{self} state change -> {color(self.state_name(new_state), "cyan")}'
)
def _change_state(self, new_state: State) -> None:
logger.debug(f'{self} state change -> {color(new_state.name, "cyan")}')
self.state = new_state
if new_state == self.CONNECTED:
if new_state == self.State.CONNECTED:
self.emit('open')
elif new_state == self.DISCONNECTED:
elif new_state == self.State.DISCONNECTED:
self.emit('close')
def send_pdu(self, pdu: Union[SupportsBytes, bytes]) -> None:
@@ -1106,7 +1073,7 @@ class LeConnectionOrientedChannel(EventEmitter):
async def connect(self) -> LeConnectionOrientedChannel:
# Check that we're in the right state
if self.state != self.INIT:
if self.state != self.State.INIT:
raise InvalidStateError('not in a connectable state')
# Check that we can start a new connection
@@ -1114,7 +1081,7 @@ class LeConnectionOrientedChannel(EventEmitter):
if identifier in self.manager.le_coc_requests:
raise RuntimeError('too many concurrent connection requests')
self.change_state(self.CONNECTING)
self._change_state(self.State.CONNECTING)
request = L2CAP_LE_Credit_Based_Connection_Request(
identifier=identifier,
le_psm=self.le_psm,
@@ -1134,10 +1101,10 @@ class LeConnectionOrientedChannel(EventEmitter):
async def disconnect(self) -> None:
# Check that we're connected
if self.state != self.CONNECTED:
if self.state != self.State.CONNECTED:
raise InvalidStateError('not connected')
self.change_state(self.DISCONNECTING)
self._change_state(self.State.DISCONNECTING)
self.flush_output()
self.send_control_frame(
L2CAP_Disconnection_Request(
@@ -1153,15 +1120,15 @@ class LeConnectionOrientedChannel(EventEmitter):
return await self.disconnection_result
def abort(self) -> None:
if self.state == self.CONNECTED:
self.change_state(self.DISCONNECTED)
if self.state == self.State.CONNECTED:
self._change_state(self.State.DISCONNECTED)
def on_pdu(self, pdu: bytes) -> None:
if self.sink is None:
logger.warning('received pdu without a sink')
return
if self.state != self.CONNECTED:
if self.state != self.State.CONNECTED:
logger.warning('received PDU while not connected, dropping')
# Manage the peer credits
@@ -1240,7 +1207,7 @@ class LeConnectionOrientedChannel(EventEmitter):
self.credits = response.initial_credits
self.connected = True
self.connection_result.set_result(self)
self.change_state(self.CONNECTED)
self._change_state(self.State.CONNECTED)
else:
self.connection_result.set_exception(
ProtocolError(
@@ -1251,7 +1218,7 @@ class LeConnectionOrientedChannel(EventEmitter):
),
)
)
self.change_state(self.CONNECTION_ERROR)
self._change_state(self.State.CONNECTION_ERROR)
# Cleanup
self.connection_result = None
@@ -1271,11 +1238,11 @@ class LeConnectionOrientedChannel(EventEmitter):
source_cid=request.source_cid,
)
)
self.change_state(self.DISCONNECTED)
self._change_state(self.State.DISCONNECTED)
self.flush_output()
def on_disconnection_response(self, response) -> None:
if self.state != self.DISCONNECTING:
if self.state != self.State.DISCONNECTING:
logger.warning(color('invalid state', 'red'))
return
@@ -1286,7 +1253,7 @@ class LeConnectionOrientedChannel(EventEmitter):
logger.warning('unexpected source or destination CID')
return
self.change_state(self.DISCONNECTED)
self._change_state(self.State.DISCONNECTED)
if self.disconnection_result:
self.disconnection_result.set_result(None)
self.disconnection_result = None
@@ -1339,7 +1306,7 @@ class LeConnectionOrientedChannel(EventEmitter):
return
def write(self, data: bytes) -> None:
if self.state != self.CONNECTED:
if self.state != self.State.CONNECTED:
logger.warning('not connected, dropping data')
return
@@ -1367,7 +1334,7 @@ class LeConnectionOrientedChannel(EventEmitter):
def __str__(self) -> str:
return (
f'CoC({self.source_cid}->{self.destination_cid}, '
f'State={self.state_name(self.state)}, '
f'State={self.state.name}, '
f'PSM={self.le_psm}, '
f'MTU={self.mtu}/{self.peer_mtu}, '
f'MPS={self.mps}/{self.peer_mps}, '