Merge pull request #288 from zxzxwu/l2cap_states

L2CAP: Refactor states to enums
This commit is contained in:
zxzxwu
2023-09-21 15:42:21 +08:00
committed by GitHub

View File

@@ -17,6 +17,7 @@
# ----------------------------------------------------------------------------- # -----------------------------------------------------------------------------
from __future__ import annotations from __future__ import annotations
import asyncio import asyncio
import enum
import logging import logging
import struct import struct
@@ -676,56 +677,35 @@ class L2CAP_LE_Flow_Control_Credit(L2CAP_Control_Frame):
# ----------------------------------------------------------------------------- # -----------------------------------------------------------------------------
class Channel(EventEmitter): class Channel(EventEmitter):
# States class State(enum.IntEnum):
CLOSED = 0x00 # States
WAIT_CONNECT = 0x01 CLOSED = 0x00
WAIT_CONNECT_RSP = 0x02 WAIT_CONNECT = 0x01
OPEN = 0x03 WAIT_CONNECT_RSP = 0x02
WAIT_DISCONNECT = 0x04 OPEN = 0x03
WAIT_CREATE = 0x05 WAIT_DISCONNECT = 0x04
WAIT_CREATE_RSP = 0x06 WAIT_CREATE = 0x05
WAIT_MOVE = 0x07 WAIT_CREATE_RSP = 0x06
WAIT_MOVE_RSP = 0x08 WAIT_MOVE = 0x07
WAIT_MOVE_CONFIRM = 0x09 WAIT_MOVE_RSP = 0x08
WAIT_CONFIRM_RSP = 0x0A WAIT_MOVE_CONFIRM = 0x09
WAIT_CONFIRM_RSP = 0x0A
# CONFIG substates # CONFIG substates
WAIT_CONFIG = 0x10 WAIT_CONFIG = 0x10
WAIT_SEND_CONFIG = 0x11 WAIT_SEND_CONFIG = 0x11
WAIT_CONFIG_REQ_RSP = 0x12 WAIT_CONFIG_REQ_RSP = 0x12
WAIT_CONFIG_RSP = 0x13 WAIT_CONFIG_RSP = 0x13
WAIT_CONFIG_REQ = 0x14 WAIT_CONFIG_REQ = 0x14
WAIT_IND_FINAL_RSP = 0x15 WAIT_IND_FINAL_RSP = 0x15
WAIT_FINAL_RSP = 0x16 WAIT_FINAL_RSP = 0x16
WAIT_CONTROL_IND = 0x17 WAIT_CONTROL_IND = 0x17
STATE_NAMES = {
CLOSED: 'CLOSED',
WAIT_CONNECT: 'WAIT_CONNECT',
WAIT_CONNECT_RSP: 'WAIT_CONNECT_RSP',
OPEN: 'OPEN',
WAIT_DISCONNECT: 'WAIT_DISCONNECT',
WAIT_CREATE: 'WAIT_CREATE',
WAIT_CREATE_RSP: 'WAIT_CREATE_RSP',
WAIT_MOVE: 'WAIT_MOVE',
WAIT_MOVE_RSP: 'WAIT_MOVE_RSP',
WAIT_MOVE_CONFIRM: 'WAIT_MOVE_CONFIRM',
WAIT_CONFIRM_RSP: 'WAIT_CONFIRM_RSP',
WAIT_CONFIG: 'WAIT_CONFIG',
WAIT_SEND_CONFIG: 'WAIT_SEND_CONFIG',
WAIT_CONFIG_REQ_RSP: 'WAIT_CONFIG_REQ_RSP',
WAIT_CONFIG_RSP: 'WAIT_CONFIG_RSP',
WAIT_CONFIG_REQ: 'WAIT_CONFIG_REQ',
WAIT_IND_FINAL_RSP: 'WAIT_IND_FINAL_RSP',
WAIT_FINAL_RSP: 'WAIT_FINAL_RSP',
WAIT_CONTROL_IND: 'WAIT_CONTROL_IND',
}
connection_result: Optional[asyncio.Future[None]] connection_result: Optional[asyncio.Future[None]]
disconnection_result: Optional[asyncio.Future[None]] disconnection_result: Optional[asyncio.Future[None]]
response: Optional[asyncio.Future[bytes]] response: Optional[asyncio.Future[bytes]]
sink: Optional[Callable[[bytes], Any]] sink: Optional[Callable[[bytes], Any]]
state: int state: State
connection: Connection connection: Connection
def __init__( def __init__(
@@ -741,7 +721,7 @@ class Channel(EventEmitter):
self.manager = manager self.manager = manager
self.connection = connection self.connection = connection
self.signaling_cid = signaling_cid self.signaling_cid = signaling_cid
self.state = Channel.CLOSED self.state = self.State.CLOSED
self.mtu = mtu self.mtu = mtu
self.psm = psm self.psm = psm
self.source_cid = source_cid self.source_cid = source_cid
@@ -751,10 +731,8 @@ class Channel(EventEmitter):
self.disconnection_result = None self.disconnection_result = None
self.sink = None self.sink = None
def change_state(self, new_state: int) -> None: def _change_state(self, new_state: State) -> None:
logger.debug( logger.debug(f'{self} state change -> {color(new_state.name, "cyan")}')
f'{self} state change -> {color(Channel.STATE_NAMES[new_state], "cyan")}'
)
self.state = new_state self.state = new_state
def send_pdu(self, pdu: Union[SupportsBytes, bytes]) -> None: def send_pdu(self, pdu: Union[SupportsBytes, bytes]) -> None:
@@ -767,7 +745,7 @@ class Channel(EventEmitter):
# Check that there isn't already a request pending # Check that there isn't already a request pending
if self.response: if self.response:
raise InvalidStateError('request already pending') raise InvalidStateError('request already pending')
if self.state != Channel.OPEN: if self.state != self.State.OPEN:
raise InvalidStateError('channel not open') raise InvalidStateError('channel not open')
self.response = asyncio.get_running_loop().create_future() self.response = asyncio.get_running_loop().create_future()
@@ -787,14 +765,14 @@ class Channel(EventEmitter):
) )
async def connect(self) -> None: async def connect(self) -> None:
if self.state != Channel.CLOSED: if self.state != self.State.CLOSED:
raise InvalidStateError('invalid state') raise InvalidStateError('invalid state')
# Check that we can start a new connection # Check that we can start a new connection
if self.connection_result: if self.connection_result:
raise RuntimeError('connection already pending') raise RuntimeError('connection already pending')
self.change_state(Channel.WAIT_CONNECT_RSP) self._change_state(self.State.WAIT_CONNECT_RSP)
self.send_control_frame( self.send_control_frame(
L2CAP_Connection_Request( L2CAP_Connection_Request(
identifier=self.manager.next_identifier(self.connection), identifier=self.manager.next_identifier(self.connection),
@@ -814,10 +792,10 @@ class Channel(EventEmitter):
self.connection_result = None self.connection_result = None
async def disconnect(self) -> None: async def disconnect(self) -> None:
if self.state != Channel.OPEN: if self.state != self.State.OPEN:
raise InvalidStateError('invalid state') raise InvalidStateError('invalid state')
self.change_state(Channel.WAIT_DISCONNECT) self._change_state(self.State.WAIT_DISCONNECT)
self.send_control_frame( self.send_control_frame(
L2CAP_Disconnection_Request( L2CAP_Disconnection_Request(
identifier=self.manager.next_identifier(self.connection), identifier=self.manager.next_identifier(self.connection),
@@ -832,8 +810,8 @@ class Channel(EventEmitter):
return await self.disconnection_result return await self.disconnection_result
def abort(self) -> None: def abort(self) -> None:
if self.state == self.OPEN: if self.state == self.State.OPEN:
self.change_state(self.CLOSED) self._change_state(self.State.CLOSED)
self.emit('close') self.emit('close')
def send_configure_request(self) -> None: def send_configure_request(self) -> None:
@@ -856,7 +834,7 @@ class Channel(EventEmitter):
def on_connection_request(self, request) -> None: def on_connection_request(self, request) -> None:
self.destination_cid = request.source_cid self.destination_cid = request.source_cid
self.change_state(Channel.WAIT_CONNECT) self._change_state(self.State.WAIT_CONNECT)
self.send_control_frame( self.send_control_frame(
L2CAP_Connection_Response( L2CAP_Connection_Response(
identifier=request.identifier, identifier=request.identifier,
@@ -866,24 +844,24 @@ class Channel(EventEmitter):
status=0x0000, status=0x0000,
) )
) )
self.change_state(Channel.WAIT_CONFIG) self._change_state(self.State.WAIT_CONFIG)
self.send_configure_request() self.send_configure_request()
self.change_state(Channel.WAIT_CONFIG_REQ_RSP) self._change_state(self.State.WAIT_CONFIG_REQ_RSP)
def on_connection_response(self, response): def on_connection_response(self, response):
if self.state != Channel.WAIT_CONNECT_RSP: if self.state != self.State.WAIT_CONNECT_RSP:
logger.warning(color('invalid state', 'red')) logger.warning(color('invalid state', 'red'))
return return
if response.result == L2CAP_Connection_Response.CONNECTION_SUCCESSFUL: if response.result == L2CAP_Connection_Response.CONNECTION_SUCCESSFUL:
self.destination_cid = response.destination_cid self.destination_cid = response.destination_cid
self.change_state(Channel.WAIT_CONFIG) self._change_state(self.State.WAIT_CONFIG)
self.send_configure_request() self.send_configure_request()
self.change_state(Channel.WAIT_CONFIG_REQ_RSP) self._change_state(self.State.WAIT_CONFIG_REQ_RSP)
elif response.result == L2CAP_Connection_Response.CONNECTION_PENDING: elif response.result == L2CAP_Connection_Response.CONNECTION_PENDING:
pass pass
else: else:
self.change_state(Channel.CLOSED) self._change_state(self.State.CLOSED)
self.connection_result.set_exception( self.connection_result.set_exception(
ProtocolError( ProtocolError(
response.result, response.result,
@@ -895,9 +873,9 @@ class Channel(EventEmitter):
def on_configure_request(self, request) -> None: def on_configure_request(self, request) -> None:
if self.state not in ( if self.state not in (
Channel.WAIT_CONFIG, self.State.WAIT_CONFIG,
Channel.WAIT_CONFIG_REQ, self.State.WAIT_CONFIG_REQ,
Channel.WAIT_CONFIG_REQ_RSP, self.State.WAIT_CONFIG_REQ_RSP,
): ):
logger.warning(color('invalid state', 'red')) logger.warning(color('invalid state', 'red'))
return return
@@ -918,25 +896,28 @@ class Channel(EventEmitter):
options=request.options, # TODO: don't accept everything blindly options=request.options, # TODO: don't accept everything blindly
) )
) )
if self.state == Channel.WAIT_CONFIG: if self.state == self.State.WAIT_CONFIG:
self.change_state(Channel.WAIT_SEND_CONFIG) self._change_state(self.State.WAIT_SEND_CONFIG)
self.send_configure_request() self.send_configure_request()
self.change_state(Channel.WAIT_CONFIG_RSP) self._change_state(self.State.WAIT_CONFIG_RSP)
elif self.state == Channel.WAIT_CONFIG_REQ: elif self.state == self.State.WAIT_CONFIG_REQ:
self.change_state(Channel.OPEN) self._change_state(self.State.OPEN)
if self.connection_result: if self.connection_result:
self.connection_result.set_result(None) self.connection_result.set_result(None)
self.connection_result = None self.connection_result = None
self.emit('open') self.emit('open')
elif self.state == Channel.WAIT_CONFIG_REQ_RSP: elif self.state == self.State.WAIT_CONFIG_REQ_RSP:
self.change_state(Channel.WAIT_CONFIG_RSP) self._change_state(self.State.WAIT_CONFIG_RSP)
def on_configure_response(self, response) -> None: def on_configure_response(self, response) -> None:
if response.result == L2CAP_Configure_Response.SUCCESS: if response.result == L2CAP_Configure_Response.SUCCESS:
if self.state == Channel.WAIT_CONFIG_REQ_RSP: if self.state == self.State.WAIT_CONFIG_REQ_RSP:
self.change_state(Channel.WAIT_CONFIG_REQ) self._change_state(self.State.WAIT_CONFIG_REQ)
elif self.state in (Channel.WAIT_CONFIG_RSP, Channel.WAIT_CONTROL_IND): elif self.state in (
self.change_state(Channel.OPEN) self.State.WAIT_CONFIG_RSP,
self.State.WAIT_CONTROL_IND,
):
self._change_state(self.State.OPEN)
if self.connection_result: if self.connection_result:
self.connection_result.set_result(None) self.connection_result.set_result(None)
self.connection_result = None self.connection_result = None
@@ -966,7 +947,7 @@ class Channel(EventEmitter):
# TODO: decide how to fail gracefully # TODO: decide how to fail gracefully
def on_disconnection_request(self, request) -> None: def on_disconnection_request(self, request) -> None:
if self.state in (Channel.OPEN, Channel.WAIT_DISCONNECT): if self.state in (self.State.OPEN, self.State.WAIT_DISCONNECT):
self.send_control_frame( self.send_control_frame(
L2CAP_Disconnection_Response( L2CAP_Disconnection_Response(
identifier=request.identifier, identifier=request.identifier,
@@ -974,14 +955,14 @@ class Channel(EventEmitter):
source_cid=request.source_cid, source_cid=request.source_cid,
) )
) )
self.change_state(Channel.CLOSED) self._change_state(self.State.CLOSED)
self.emit('close') self.emit('close')
self.manager.on_channel_closed(self) self.manager.on_channel_closed(self)
else: else:
logger.warning(color('invalid state', 'red')) logger.warning(color('invalid state', 'red'))
def on_disconnection_response(self, response) -> None: def on_disconnection_response(self, response) -> None:
if self.state != Channel.WAIT_DISCONNECT: if self.state != self.State.WAIT_DISCONNECT:
logger.warning(color('invalid state', 'red')) logger.warning(color('invalid state', 'red'))
return return
@@ -992,7 +973,7 @@ class Channel(EventEmitter):
logger.warning('unexpected source or destination CID') logger.warning('unexpected source or destination CID')
return return
self.change_state(Channel.CLOSED) self._change_state(self.State.CLOSED)
if self.disconnection_result: if self.disconnection_result:
self.disconnection_result.set_result(None) self.disconnection_result.set_result(None)
self.disconnection_result = None self.disconnection_result = None
@@ -1004,7 +985,7 @@ class Channel(EventEmitter):
f'Channel({self.source_cid}->{self.destination_cid}, ' f'Channel({self.source_cid}->{self.destination_cid}, '
f'PSM={self.psm}, ' f'PSM={self.psm}, '
f'MTU={self.mtu}, ' f'MTU={self.mtu}, '
f'state={Channel.STATE_NAMES[self.state]})' f'state={self.state.name})'
) )
@@ -1014,33 +995,21 @@ class LeConnectionOrientedChannel(EventEmitter):
LE Credit-based Connection Oriented Channel LE Credit-based Connection Oriented Channel
""" """
INIT = 0 class State(enum.IntEnum):
CONNECTED = 1 INIT = 0
CONNECTING = 2 CONNECTED = 1
DISCONNECTING = 3 CONNECTING = 2
DISCONNECTED = 4 DISCONNECTING = 3
CONNECTION_ERROR = 5 DISCONNECTED = 4
CONNECTION_ERROR = 5
STATE_NAMES = {
INIT: 'INIT',
CONNECTED: 'CONNECTED',
CONNECTING: 'CONNECTING',
DISCONNECTING: 'DISCONNECTING',
DISCONNECTED: 'DISCONNECTED',
CONNECTION_ERROR: 'CONNECTION_ERROR',
}
out_queue: Deque[bytes] out_queue: Deque[bytes]
connection_result: Optional[asyncio.Future[LeConnectionOrientedChannel]] connection_result: Optional[asyncio.Future[LeConnectionOrientedChannel]]
disconnection_result: Optional[asyncio.Future[None]] disconnection_result: Optional[asyncio.Future[None]]
out_sdu: Optional[bytes] out_sdu: Optional[bytes]
state: int state: State
connection: Connection connection: Connection
@staticmethod
def state_name(state: int) -> str:
return name_or_number(LeConnectionOrientedChannel.STATE_NAMES, state)
def __init__( def __init__(
self, self,
manager: ChannelManager, manager: ChannelManager,
@@ -1083,19 +1052,17 @@ class LeConnectionOrientedChannel(EventEmitter):
self.drained.set() self.drained.set()
if connected: if connected:
self.state = LeConnectionOrientedChannel.CONNECTED self.state = self.State.CONNECTED
else: else:
self.state = LeConnectionOrientedChannel.INIT self.state = self.State.INIT
def change_state(self, new_state: int) -> None: def _change_state(self, new_state: State) -> None:
logger.debug( logger.debug(f'{self} state change -> {color(new_state.name, "cyan")}')
f'{self} state change -> {color(self.state_name(new_state), "cyan")}'
)
self.state = new_state self.state = new_state
if new_state == self.CONNECTED: if new_state == self.State.CONNECTED:
self.emit('open') self.emit('open')
elif new_state == self.DISCONNECTED: elif new_state == self.State.DISCONNECTED:
self.emit('close') self.emit('close')
def send_pdu(self, pdu: Union[SupportsBytes, bytes]) -> None: def send_pdu(self, pdu: Union[SupportsBytes, bytes]) -> None:
@@ -1106,7 +1073,7 @@ class LeConnectionOrientedChannel(EventEmitter):
async def connect(self) -> LeConnectionOrientedChannel: async def connect(self) -> LeConnectionOrientedChannel:
# Check that we're in the right state # Check that we're in the right state
if self.state != self.INIT: if self.state != self.State.INIT:
raise InvalidStateError('not in a connectable state') raise InvalidStateError('not in a connectable state')
# Check that we can start a new connection # Check that we can start a new connection
@@ -1114,7 +1081,7 @@ class LeConnectionOrientedChannel(EventEmitter):
if identifier in self.manager.le_coc_requests: if identifier in self.manager.le_coc_requests:
raise RuntimeError('too many concurrent connection requests') raise RuntimeError('too many concurrent connection requests')
self.change_state(self.CONNECTING) self._change_state(self.State.CONNECTING)
request = L2CAP_LE_Credit_Based_Connection_Request( request = L2CAP_LE_Credit_Based_Connection_Request(
identifier=identifier, identifier=identifier,
le_psm=self.le_psm, le_psm=self.le_psm,
@@ -1134,10 +1101,10 @@ class LeConnectionOrientedChannel(EventEmitter):
async def disconnect(self) -> None: async def disconnect(self) -> None:
# Check that we're connected # Check that we're connected
if self.state != self.CONNECTED: if self.state != self.State.CONNECTED:
raise InvalidStateError('not connected') raise InvalidStateError('not connected')
self.change_state(self.DISCONNECTING) self._change_state(self.State.DISCONNECTING)
self.flush_output() self.flush_output()
self.send_control_frame( self.send_control_frame(
L2CAP_Disconnection_Request( L2CAP_Disconnection_Request(
@@ -1153,15 +1120,15 @@ class LeConnectionOrientedChannel(EventEmitter):
return await self.disconnection_result return await self.disconnection_result
def abort(self) -> None: def abort(self) -> None:
if self.state == self.CONNECTED: if self.state == self.State.CONNECTED:
self.change_state(self.DISCONNECTED) self._change_state(self.State.DISCONNECTED)
def on_pdu(self, pdu: bytes) -> None: def on_pdu(self, pdu: bytes) -> None:
if self.sink is None: if self.sink is None:
logger.warning('received pdu without a sink') logger.warning('received pdu without a sink')
return return
if self.state != self.CONNECTED: if self.state != self.State.CONNECTED:
logger.warning('received PDU while not connected, dropping') logger.warning('received PDU while not connected, dropping')
# Manage the peer credits # Manage the peer credits
@@ -1240,7 +1207,7 @@ class LeConnectionOrientedChannel(EventEmitter):
self.credits = response.initial_credits self.credits = response.initial_credits
self.connected = True self.connected = True
self.connection_result.set_result(self) self.connection_result.set_result(self)
self.change_state(self.CONNECTED) self._change_state(self.State.CONNECTED)
else: else:
self.connection_result.set_exception( self.connection_result.set_exception(
ProtocolError( ProtocolError(
@@ -1251,7 +1218,7 @@ class LeConnectionOrientedChannel(EventEmitter):
), ),
) )
) )
self.change_state(self.CONNECTION_ERROR) self._change_state(self.State.CONNECTION_ERROR)
# Cleanup # Cleanup
self.connection_result = None self.connection_result = None
@@ -1271,11 +1238,11 @@ class LeConnectionOrientedChannel(EventEmitter):
source_cid=request.source_cid, source_cid=request.source_cid,
) )
) )
self.change_state(self.DISCONNECTED) self._change_state(self.State.DISCONNECTED)
self.flush_output() self.flush_output()
def on_disconnection_response(self, response) -> None: def on_disconnection_response(self, response) -> None:
if self.state != self.DISCONNECTING: if self.state != self.State.DISCONNECTING:
logger.warning(color('invalid state', 'red')) logger.warning(color('invalid state', 'red'))
return return
@@ -1286,7 +1253,7 @@ class LeConnectionOrientedChannel(EventEmitter):
logger.warning('unexpected source or destination CID') logger.warning('unexpected source or destination CID')
return return
self.change_state(self.DISCONNECTED) self._change_state(self.State.DISCONNECTED)
if self.disconnection_result: if self.disconnection_result:
self.disconnection_result.set_result(None) self.disconnection_result.set_result(None)
self.disconnection_result = None self.disconnection_result = None
@@ -1339,7 +1306,7 @@ class LeConnectionOrientedChannel(EventEmitter):
return return
def write(self, data: bytes) -> None: def write(self, data: bytes) -> None:
if self.state != self.CONNECTED: if self.state != self.State.CONNECTED:
logger.warning('not connected, dropping data') logger.warning('not connected, dropping data')
return return
@@ -1367,7 +1334,7 @@ class LeConnectionOrientedChannel(EventEmitter):
def __str__(self) -> str: def __str__(self) -> str:
return ( return (
f'CoC({self.source_cid}->{self.destination_cid}, ' f'CoC({self.source_cid}->{self.destination_cid}, '
f'State={self.state_name(self.state)}, ' f'State={self.state.name}, '
f'PSM={self.le_psm}, ' f'PSM={self.le_psm}, '
f'MTU={self.mtu}/{self.peer_mtu}, ' f'MTU={self.mtu}/{self.peer_mtu}, '
f'MPS={self.mps}/{self.peer_mps}, ' f'MPS={self.mps}/{self.peer_mps}, '