forked from auracaster/bumble_mirror
Merge pull request #288 from zxzxwu/l2cap_states
L2CAP: Refactor states to enums
This commit is contained in:
215
bumble/l2cap.py
215
bumble/l2cap.py
@@ -17,6 +17,7 @@
|
||||
# -----------------------------------------------------------------------------
|
||||
from __future__ import annotations
|
||||
import asyncio
|
||||
import enum
|
||||
import logging
|
||||
import struct
|
||||
|
||||
@@ -676,56 +677,35 @@ class L2CAP_LE_Flow_Control_Credit(L2CAP_Control_Frame):
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
class Channel(EventEmitter):
|
||||
# States
|
||||
CLOSED = 0x00
|
||||
WAIT_CONNECT = 0x01
|
||||
WAIT_CONNECT_RSP = 0x02
|
||||
OPEN = 0x03
|
||||
WAIT_DISCONNECT = 0x04
|
||||
WAIT_CREATE = 0x05
|
||||
WAIT_CREATE_RSP = 0x06
|
||||
WAIT_MOVE = 0x07
|
||||
WAIT_MOVE_RSP = 0x08
|
||||
WAIT_MOVE_CONFIRM = 0x09
|
||||
WAIT_CONFIRM_RSP = 0x0A
|
||||
class State(enum.IntEnum):
|
||||
# States
|
||||
CLOSED = 0x00
|
||||
WAIT_CONNECT = 0x01
|
||||
WAIT_CONNECT_RSP = 0x02
|
||||
OPEN = 0x03
|
||||
WAIT_DISCONNECT = 0x04
|
||||
WAIT_CREATE = 0x05
|
||||
WAIT_CREATE_RSP = 0x06
|
||||
WAIT_MOVE = 0x07
|
||||
WAIT_MOVE_RSP = 0x08
|
||||
WAIT_MOVE_CONFIRM = 0x09
|
||||
WAIT_CONFIRM_RSP = 0x0A
|
||||
|
||||
# CONFIG substates
|
||||
WAIT_CONFIG = 0x10
|
||||
WAIT_SEND_CONFIG = 0x11
|
||||
WAIT_CONFIG_REQ_RSP = 0x12
|
||||
WAIT_CONFIG_RSP = 0x13
|
||||
WAIT_CONFIG_REQ = 0x14
|
||||
WAIT_IND_FINAL_RSP = 0x15
|
||||
WAIT_FINAL_RSP = 0x16
|
||||
WAIT_CONTROL_IND = 0x17
|
||||
|
||||
STATE_NAMES = {
|
||||
CLOSED: 'CLOSED',
|
||||
WAIT_CONNECT: 'WAIT_CONNECT',
|
||||
WAIT_CONNECT_RSP: 'WAIT_CONNECT_RSP',
|
||||
OPEN: 'OPEN',
|
||||
WAIT_DISCONNECT: 'WAIT_DISCONNECT',
|
||||
WAIT_CREATE: 'WAIT_CREATE',
|
||||
WAIT_CREATE_RSP: 'WAIT_CREATE_RSP',
|
||||
WAIT_MOVE: 'WAIT_MOVE',
|
||||
WAIT_MOVE_RSP: 'WAIT_MOVE_RSP',
|
||||
WAIT_MOVE_CONFIRM: 'WAIT_MOVE_CONFIRM',
|
||||
WAIT_CONFIRM_RSP: 'WAIT_CONFIRM_RSP',
|
||||
WAIT_CONFIG: 'WAIT_CONFIG',
|
||||
WAIT_SEND_CONFIG: 'WAIT_SEND_CONFIG',
|
||||
WAIT_CONFIG_REQ_RSP: 'WAIT_CONFIG_REQ_RSP',
|
||||
WAIT_CONFIG_RSP: 'WAIT_CONFIG_RSP',
|
||||
WAIT_CONFIG_REQ: 'WAIT_CONFIG_REQ',
|
||||
WAIT_IND_FINAL_RSP: 'WAIT_IND_FINAL_RSP',
|
||||
WAIT_FINAL_RSP: 'WAIT_FINAL_RSP',
|
||||
WAIT_CONTROL_IND: 'WAIT_CONTROL_IND',
|
||||
}
|
||||
# CONFIG substates
|
||||
WAIT_CONFIG = 0x10
|
||||
WAIT_SEND_CONFIG = 0x11
|
||||
WAIT_CONFIG_REQ_RSP = 0x12
|
||||
WAIT_CONFIG_RSP = 0x13
|
||||
WAIT_CONFIG_REQ = 0x14
|
||||
WAIT_IND_FINAL_RSP = 0x15
|
||||
WAIT_FINAL_RSP = 0x16
|
||||
WAIT_CONTROL_IND = 0x17
|
||||
|
||||
connection_result: Optional[asyncio.Future[None]]
|
||||
disconnection_result: Optional[asyncio.Future[None]]
|
||||
response: Optional[asyncio.Future[bytes]]
|
||||
sink: Optional[Callable[[bytes], Any]]
|
||||
state: int
|
||||
state: State
|
||||
connection: Connection
|
||||
|
||||
def __init__(
|
||||
@@ -741,7 +721,7 @@ class Channel(EventEmitter):
|
||||
self.manager = manager
|
||||
self.connection = connection
|
||||
self.signaling_cid = signaling_cid
|
||||
self.state = Channel.CLOSED
|
||||
self.state = self.State.CLOSED
|
||||
self.mtu = mtu
|
||||
self.psm = psm
|
||||
self.source_cid = source_cid
|
||||
@@ -751,10 +731,8 @@ class Channel(EventEmitter):
|
||||
self.disconnection_result = None
|
||||
self.sink = None
|
||||
|
||||
def change_state(self, new_state: int) -> None:
|
||||
logger.debug(
|
||||
f'{self} state change -> {color(Channel.STATE_NAMES[new_state], "cyan")}'
|
||||
)
|
||||
def _change_state(self, new_state: State) -> None:
|
||||
logger.debug(f'{self} state change -> {color(new_state.name, "cyan")}')
|
||||
self.state = new_state
|
||||
|
||||
def send_pdu(self, pdu: Union[SupportsBytes, bytes]) -> None:
|
||||
@@ -767,7 +745,7 @@ class Channel(EventEmitter):
|
||||
# Check that there isn't already a request pending
|
||||
if self.response:
|
||||
raise InvalidStateError('request already pending')
|
||||
if self.state != Channel.OPEN:
|
||||
if self.state != self.State.OPEN:
|
||||
raise InvalidStateError('channel not open')
|
||||
|
||||
self.response = asyncio.get_running_loop().create_future()
|
||||
@@ -787,14 +765,14 @@ class Channel(EventEmitter):
|
||||
)
|
||||
|
||||
async def connect(self) -> None:
|
||||
if self.state != Channel.CLOSED:
|
||||
if self.state != self.State.CLOSED:
|
||||
raise InvalidStateError('invalid state')
|
||||
|
||||
# Check that we can start a new connection
|
||||
if self.connection_result:
|
||||
raise RuntimeError('connection already pending')
|
||||
|
||||
self.change_state(Channel.WAIT_CONNECT_RSP)
|
||||
self._change_state(self.State.WAIT_CONNECT_RSP)
|
||||
self.send_control_frame(
|
||||
L2CAP_Connection_Request(
|
||||
identifier=self.manager.next_identifier(self.connection),
|
||||
@@ -814,10 +792,10 @@ class Channel(EventEmitter):
|
||||
self.connection_result = None
|
||||
|
||||
async def disconnect(self) -> None:
|
||||
if self.state != Channel.OPEN:
|
||||
if self.state != self.State.OPEN:
|
||||
raise InvalidStateError('invalid state')
|
||||
|
||||
self.change_state(Channel.WAIT_DISCONNECT)
|
||||
self._change_state(self.State.WAIT_DISCONNECT)
|
||||
self.send_control_frame(
|
||||
L2CAP_Disconnection_Request(
|
||||
identifier=self.manager.next_identifier(self.connection),
|
||||
@@ -832,8 +810,8 @@ class Channel(EventEmitter):
|
||||
return await self.disconnection_result
|
||||
|
||||
def abort(self) -> None:
|
||||
if self.state == self.OPEN:
|
||||
self.change_state(self.CLOSED)
|
||||
if self.state == self.State.OPEN:
|
||||
self._change_state(self.State.CLOSED)
|
||||
self.emit('close')
|
||||
|
||||
def send_configure_request(self) -> None:
|
||||
@@ -856,7 +834,7 @@ class Channel(EventEmitter):
|
||||
|
||||
def on_connection_request(self, request) -> None:
|
||||
self.destination_cid = request.source_cid
|
||||
self.change_state(Channel.WAIT_CONNECT)
|
||||
self._change_state(self.State.WAIT_CONNECT)
|
||||
self.send_control_frame(
|
||||
L2CAP_Connection_Response(
|
||||
identifier=request.identifier,
|
||||
@@ -866,24 +844,24 @@ class Channel(EventEmitter):
|
||||
status=0x0000,
|
||||
)
|
||||
)
|
||||
self.change_state(Channel.WAIT_CONFIG)
|
||||
self._change_state(self.State.WAIT_CONFIG)
|
||||
self.send_configure_request()
|
||||
self.change_state(Channel.WAIT_CONFIG_REQ_RSP)
|
||||
self._change_state(self.State.WAIT_CONFIG_REQ_RSP)
|
||||
|
||||
def on_connection_response(self, response):
|
||||
if self.state != Channel.WAIT_CONNECT_RSP:
|
||||
if self.state != self.State.WAIT_CONNECT_RSP:
|
||||
logger.warning(color('invalid state', 'red'))
|
||||
return
|
||||
|
||||
if response.result == L2CAP_Connection_Response.CONNECTION_SUCCESSFUL:
|
||||
self.destination_cid = response.destination_cid
|
||||
self.change_state(Channel.WAIT_CONFIG)
|
||||
self._change_state(self.State.WAIT_CONFIG)
|
||||
self.send_configure_request()
|
||||
self.change_state(Channel.WAIT_CONFIG_REQ_RSP)
|
||||
self._change_state(self.State.WAIT_CONFIG_REQ_RSP)
|
||||
elif response.result == L2CAP_Connection_Response.CONNECTION_PENDING:
|
||||
pass
|
||||
else:
|
||||
self.change_state(Channel.CLOSED)
|
||||
self._change_state(self.State.CLOSED)
|
||||
self.connection_result.set_exception(
|
||||
ProtocolError(
|
||||
response.result,
|
||||
@@ -895,9 +873,9 @@ class Channel(EventEmitter):
|
||||
|
||||
def on_configure_request(self, request) -> None:
|
||||
if self.state not in (
|
||||
Channel.WAIT_CONFIG,
|
||||
Channel.WAIT_CONFIG_REQ,
|
||||
Channel.WAIT_CONFIG_REQ_RSP,
|
||||
self.State.WAIT_CONFIG,
|
||||
self.State.WAIT_CONFIG_REQ,
|
||||
self.State.WAIT_CONFIG_REQ_RSP,
|
||||
):
|
||||
logger.warning(color('invalid state', 'red'))
|
||||
return
|
||||
@@ -918,25 +896,28 @@ class Channel(EventEmitter):
|
||||
options=request.options, # TODO: don't accept everything blindly
|
||||
)
|
||||
)
|
||||
if self.state == Channel.WAIT_CONFIG:
|
||||
self.change_state(Channel.WAIT_SEND_CONFIG)
|
||||
if self.state == self.State.WAIT_CONFIG:
|
||||
self._change_state(self.State.WAIT_SEND_CONFIG)
|
||||
self.send_configure_request()
|
||||
self.change_state(Channel.WAIT_CONFIG_RSP)
|
||||
elif self.state == Channel.WAIT_CONFIG_REQ:
|
||||
self.change_state(Channel.OPEN)
|
||||
self._change_state(self.State.WAIT_CONFIG_RSP)
|
||||
elif self.state == self.State.WAIT_CONFIG_REQ:
|
||||
self._change_state(self.State.OPEN)
|
||||
if self.connection_result:
|
||||
self.connection_result.set_result(None)
|
||||
self.connection_result = None
|
||||
self.emit('open')
|
||||
elif self.state == Channel.WAIT_CONFIG_REQ_RSP:
|
||||
self.change_state(Channel.WAIT_CONFIG_RSP)
|
||||
elif self.state == self.State.WAIT_CONFIG_REQ_RSP:
|
||||
self._change_state(self.State.WAIT_CONFIG_RSP)
|
||||
|
||||
def on_configure_response(self, response) -> None:
|
||||
if response.result == L2CAP_Configure_Response.SUCCESS:
|
||||
if self.state == Channel.WAIT_CONFIG_REQ_RSP:
|
||||
self.change_state(Channel.WAIT_CONFIG_REQ)
|
||||
elif self.state in (Channel.WAIT_CONFIG_RSP, Channel.WAIT_CONTROL_IND):
|
||||
self.change_state(Channel.OPEN)
|
||||
if self.state == self.State.WAIT_CONFIG_REQ_RSP:
|
||||
self._change_state(self.State.WAIT_CONFIG_REQ)
|
||||
elif self.state in (
|
||||
self.State.WAIT_CONFIG_RSP,
|
||||
self.State.WAIT_CONTROL_IND,
|
||||
):
|
||||
self._change_state(self.State.OPEN)
|
||||
if self.connection_result:
|
||||
self.connection_result.set_result(None)
|
||||
self.connection_result = None
|
||||
@@ -966,7 +947,7 @@ class Channel(EventEmitter):
|
||||
# TODO: decide how to fail gracefully
|
||||
|
||||
def on_disconnection_request(self, request) -> None:
|
||||
if self.state in (Channel.OPEN, Channel.WAIT_DISCONNECT):
|
||||
if self.state in (self.State.OPEN, self.State.WAIT_DISCONNECT):
|
||||
self.send_control_frame(
|
||||
L2CAP_Disconnection_Response(
|
||||
identifier=request.identifier,
|
||||
@@ -974,14 +955,14 @@ class Channel(EventEmitter):
|
||||
source_cid=request.source_cid,
|
||||
)
|
||||
)
|
||||
self.change_state(Channel.CLOSED)
|
||||
self._change_state(self.State.CLOSED)
|
||||
self.emit('close')
|
||||
self.manager.on_channel_closed(self)
|
||||
else:
|
||||
logger.warning(color('invalid state', 'red'))
|
||||
|
||||
def on_disconnection_response(self, response) -> None:
|
||||
if self.state != Channel.WAIT_DISCONNECT:
|
||||
if self.state != self.State.WAIT_DISCONNECT:
|
||||
logger.warning(color('invalid state', 'red'))
|
||||
return
|
||||
|
||||
@@ -992,7 +973,7 @@ class Channel(EventEmitter):
|
||||
logger.warning('unexpected source or destination CID')
|
||||
return
|
||||
|
||||
self.change_state(Channel.CLOSED)
|
||||
self._change_state(self.State.CLOSED)
|
||||
if self.disconnection_result:
|
||||
self.disconnection_result.set_result(None)
|
||||
self.disconnection_result = None
|
||||
@@ -1004,7 +985,7 @@ class Channel(EventEmitter):
|
||||
f'Channel({self.source_cid}->{self.destination_cid}, '
|
||||
f'PSM={self.psm}, '
|
||||
f'MTU={self.mtu}, '
|
||||
f'state={Channel.STATE_NAMES[self.state]})'
|
||||
f'state={self.state.name})'
|
||||
)
|
||||
|
||||
|
||||
@@ -1014,33 +995,21 @@ class LeConnectionOrientedChannel(EventEmitter):
|
||||
LE Credit-based Connection Oriented Channel
|
||||
"""
|
||||
|
||||
INIT = 0
|
||||
CONNECTED = 1
|
||||
CONNECTING = 2
|
||||
DISCONNECTING = 3
|
||||
DISCONNECTED = 4
|
||||
CONNECTION_ERROR = 5
|
||||
|
||||
STATE_NAMES = {
|
||||
INIT: 'INIT',
|
||||
CONNECTED: 'CONNECTED',
|
||||
CONNECTING: 'CONNECTING',
|
||||
DISCONNECTING: 'DISCONNECTING',
|
||||
DISCONNECTED: 'DISCONNECTED',
|
||||
CONNECTION_ERROR: 'CONNECTION_ERROR',
|
||||
}
|
||||
class State(enum.IntEnum):
|
||||
INIT = 0
|
||||
CONNECTED = 1
|
||||
CONNECTING = 2
|
||||
DISCONNECTING = 3
|
||||
DISCONNECTED = 4
|
||||
CONNECTION_ERROR = 5
|
||||
|
||||
out_queue: Deque[bytes]
|
||||
connection_result: Optional[asyncio.Future[LeConnectionOrientedChannel]]
|
||||
disconnection_result: Optional[asyncio.Future[None]]
|
||||
out_sdu: Optional[bytes]
|
||||
state: int
|
||||
state: State
|
||||
connection: Connection
|
||||
|
||||
@staticmethod
|
||||
def state_name(state: int) -> str:
|
||||
return name_or_number(LeConnectionOrientedChannel.STATE_NAMES, state)
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
manager: ChannelManager,
|
||||
@@ -1083,19 +1052,17 @@ class LeConnectionOrientedChannel(EventEmitter):
|
||||
self.drained.set()
|
||||
|
||||
if connected:
|
||||
self.state = LeConnectionOrientedChannel.CONNECTED
|
||||
self.state = self.State.CONNECTED
|
||||
else:
|
||||
self.state = LeConnectionOrientedChannel.INIT
|
||||
self.state = self.State.INIT
|
||||
|
||||
def change_state(self, new_state: int) -> None:
|
||||
logger.debug(
|
||||
f'{self} state change -> {color(self.state_name(new_state), "cyan")}'
|
||||
)
|
||||
def _change_state(self, new_state: State) -> None:
|
||||
logger.debug(f'{self} state change -> {color(new_state.name, "cyan")}')
|
||||
self.state = new_state
|
||||
|
||||
if new_state == self.CONNECTED:
|
||||
if new_state == self.State.CONNECTED:
|
||||
self.emit('open')
|
||||
elif new_state == self.DISCONNECTED:
|
||||
elif new_state == self.State.DISCONNECTED:
|
||||
self.emit('close')
|
||||
|
||||
def send_pdu(self, pdu: Union[SupportsBytes, bytes]) -> None:
|
||||
@@ -1106,7 +1073,7 @@ class LeConnectionOrientedChannel(EventEmitter):
|
||||
|
||||
async def connect(self) -> LeConnectionOrientedChannel:
|
||||
# Check that we're in the right state
|
||||
if self.state != self.INIT:
|
||||
if self.state != self.State.INIT:
|
||||
raise InvalidStateError('not in a connectable state')
|
||||
|
||||
# Check that we can start a new connection
|
||||
@@ -1114,7 +1081,7 @@ class LeConnectionOrientedChannel(EventEmitter):
|
||||
if identifier in self.manager.le_coc_requests:
|
||||
raise RuntimeError('too many concurrent connection requests')
|
||||
|
||||
self.change_state(self.CONNECTING)
|
||||
self._change_state(self.State.CONNECTING)
|
||||
request = L2CAP_LE_Credit_Based_Connection_Request(
|
||||
identifier=identifier,
|
||||
le_psm=self.le_psm,
|
||||
@@ -1134,10 +1101,10 @@ class LeConnectionOrientedChannel(EventEmitter):
|
||||
|
||||
async def disconnect(self) -> None:
|
||||
# Check that we're connected
|
||||
if self.state != self.CONNECTED:
|
||||
if self.state != self.State.CONNECTED:
|
||||
raise InvalidStateError('not connected')
|
||||
|
||||
self.change_state(self.DISCONNECTING)
|
||||
self._change_state(self.State.DISCONNECTING)
|
||||
self.flush_output()
|
||||
self.send_control_frame(
|
||||
L2CAP_Disconnection_Request(
|
||||
@@ -1153,15 +1120,15 @@ class LeConnectionOrientedChannel(EventEmitter):
|
||||
return await self.disconnection_result
|
||||
|
||||
def abort(self) -> None:
|
||||
if self.state == self.CONNECTED:
|
||||
self.change_state(self.DISCONNECTED)
|
||||
if self.state == self.State.CONNECTED:
|
||||
self._change_state(self.State.DISCONNECTED)
|
||||
|
||||
def on_pdu(self, pdu: bytes) -> None:
|
||||
if self.sink is None:
|
||||
logger.warning('received pdu without a sink')
|
||||
return
|
||||
|
||||
if self.state != self.CONNECTED:
|
||||
if self.state != self.State.CONNECTED:
|
||||
logger.warning('received PDU while not connected, dropping')
|
||||
|
||||
# Manage the peer credits
|
||||
@@ -1240,7 +1207,7 @@ class LeConnectionOrientedChannel(EventEmitter):
|
||||
self.credits = response.initial_credits
|
||||
self.connected = True
|
||||
self.connection_result.set_result(self)
|
||||
self.change_state(self.CONNECTED)
|
||||
self._change_state(self.State.CONNECTED)
|
||||
else:
|
||||
self.connection_result.set_exception(
|
||||
ProtocolError(
|
||||
@@ -1251,7 +1218,7 @@ class LeConnectionOrientedChannel(EventEmitter):
|
||||
),
|
||||
)
|
||||
)
|
||||
self.change_state(self.CONNECTION_ERROR)
|
||||
self._change_state(self.State.CONNECTION_ERROR)
|
||||
|
||||
# Cleanup
|
||||
self.connection_result = None
|
||||
@@ -1271,11 +1238,11 @@ class LeConnectionOrientedChannel(EventEmitter):
|
||||
source_cid=request.source_cid,
|
||||
)
|
||||
)
|
||||
self.change_state(self.DISCONNECTED)
|
||||
self._change_state(self.State.DISCONNECTED)
|
||||
self.flush_output()
|
||||
|
||||
def on_disconnection_response(self, response) -> None:
|
||||
if self.state != self.DISCONNECTING:
|
||||
if self.state != self.State.DISCONNECTING:
|
||||
logger.warning(color('invalid state', 'red'))
|
||||
return
|
||||
|
||||
@@ -1286,7 +1253,7 @@ class LeConnectionOrientedChannel(EventEmitter):
|
||||
logger.warning('unexpected source or destination CID')
|
||||
return
|
||||
|
||||
self.change_state(self.DISCONNECTED)
|
||||
self._change_state(self.State.DISCONNECTED)
|
||||
if self.disconnection_result:
|
||||
self.disconnection_result.set_result(None)
|
||||
self.disconnection_result = None
|
||||
@@ -1339,7 +1306,7 @@ class LeConnectionOrientedChannel(EventEmitter):
|
||||
return
|
||||
|
||||
def write(self, data: bytes) -> None:
|
||||
if self.state != self.CONNECTED:
|
||||
if self.state != self.State.CONNECTED:
|
||||
logger.warning('not connected, dropping data')
|
||||
return
|
||||
|
||||
@@ -1367,7 +1334,7 @@ class LeConnectionOrientedChannel(EventEmitter):
|
||||
def __str__(self) -> str:
|
||||
return (
|
||||
f'CoC({self.source_cid}->{self.destination_cid}, '
|
||||
f'State={self.state_name(self.state)}, '
|
||||
f'State={self.state.name}, '
|
||||
f'PSM={self.le_psm}, '
|
||||
f'MTU={self.mtu}/{self.peer_mtu}, '
|
||||
f'MPS={self.mps}/{self.peer_mps}, '
|
||||
|
||||
Reference in New Issue
Block a user