forked from auracaster/bumble_mirror
Handle SMP_Security_Request
This commit is contained in:
@@ -13,6 +13,7 @@
|
|||||||
# limitations under the License.
|
# limitations under the License.
|
||||||
|
|
||||||
import asyncio
|
import asyncio
|
||||||
|
import contextlib
|
||||||
import grpc
|
import grpc
|
||||||
import logging
|
import logging
|
||||||
|
|
||||||
@@ -27,8 +28,8 @@ from bumble.core import (
|
|||||||
)
|
)
|
||||||
from bumble.device import Connection as BumbleConnection, Device
|
from bumble.device import Connection as BumbleConnection, Device
|
||||||
from bumble.hci import HCI_Error
|
from bumble.hci import HCI_Error
|
||||||
|
from bumble.utils import EventWatcher
|
||||||
from bumble.pairing import PairingConfig, PairingDelegate as BasePairingDelegate
|
from bumble.pairing import PairingConfig, PairingDelegate as BasePairingDelegate
|
||||||
from contextlib import suppress
|
|
||||||
from google.protobuf import any_pb2 # pytype: disable=pyi-error
|
from google.protobuf import any_pb2 # pytype: disable=pyi-error
|
||||||
from google.protobuf import empty_pb2 # pytype: disable=pyi-error
|
from google.protobuf import empty_pb2 # pytype: disable=pyi-error
|
||||||
from google.protobuf import wrappers_pb2 # pytype: disable=pyi-error
|
from google.protobuf import wrappers_pb2 # pytype: disable=pyi-error
|
||||||
@@ -294,23 +295,35 @@ class SecurityService(SecurityServicer):
|
|||||||
try:
|
try:
|
||||||
self.log.debug('Pair...')
|
self.log.debug('Pair...')
|
||||||
|
|
||||||
if (
|
security_result = asyncio.get_running_loop().create_future()
|
||||||
connection.transport == BT_LE_TRANSPORT
|
|
||||||
and connection.role == BT_PERIPHERAL_ROLE
|
|
||||||
):
|
|
||||||
wait_for_security: asyncio.Future[
|
|
||||||
bool
|
|
||||||
] = asyncio.get_running_loop().create_future()
|
|
||||||
connection.on("pairing", lambda *_: wait_for_security.set_result(True)) # type: ignore
|
|
||||||
connection.on("pairing_failure", wait_for_security.set_exception)
|
|
||||||
|
|
||||||
connection.request_pairing()
|
with contextlib.closing(EventWatcher()) as watcher:
|
||||||
|
|
||||||
await wait_for_security
|
@watcher.on(connection, 'pairing')
|
||||||
else:
|
def on_pairing(*_: Any) -> None:
|
||||||
await connection.pair()
|
security_result.set_result('success')
|
||||||
|
|
||||||
self.log.debug('Paired')
|
@watcher.on(connection, 'pairing_failure')
|
||||||
|
def on_pairing_failure(*_: Any) -> None:
|
||||||
|
security_result.set_result('pairing_failure')
|
||||||
|
|
||||||
|
@watcher.on(connection, 'disconnection')
|
||||||
|
def on_disconnection(*_: Any) -> None:
|
||||||
|
security_result.set_result('connection_died')
|
||||||
|
|
||||||
|
if (
|
||||||
|
connection.transport == BT_LE_TRANSPORT
|
||||||
|
and connection.role == BT_PERIPHERAL_ROLE
|
||||||
|
):
|
||||||
|
connection.request_pairing()
|
||||||
|
else:
|
||||||
|
await connection.pair()
|
||||||
|
|
||||||
|
result = await security_result
|
||||||
|
|
||||||
|
self.log.debug(f'Pairing session complete, status={result}')
|
||||||
|
if result != 'success':
|
||||||
|
return SecureResponse(**{result: empty_pb2.Empty()})
|
||||||
except asyncio.CancelledError:
|
except asyncio.CancelledError:
|
||||||
self.log.warning("Connection died during encryption")
|
self.log.warning("Connection died during encryption")
|
||||||
return SecureResponse(connection_died=empty_pb2.Empty())
|
return SecureResponse(connection_died=empty_pb2.Empty())
|
||||||
@@ -369,6 +382,7 @@ class SecurityService(SecurityServicer):
|
|||||||
str
|
str
|
||||||
] = asyncio.get_running_loop().create_future()
|
] = asyncio.get_running_loop().create_future()
|
||||||
authenticate_task: Optional[asyncio.Future[None]] = None
|
authenticate_task: Optional[asyncio.Future[None]] = None
|
||||||
|
pair_task: Optional[asyncio.Future[None]] = None
|
||||||
|
|
||||||
async def authenticate() -> None:
|
async def authenticate() -> None:
|
||||||
assert connection
|
assert connection
|
||||||
@@ -415,6 +429,10 @@ class SecurityService(SecurityServicer):
|
|||||||
if authenticate_task is None:
|
if authenticate_task is None:
|
||||||
authenticate_task = asyncio.create_task(authenticate())
|
authenticate_task = asyncio.create_task(authenticate())
|
||||||
|
|
||||||
|
def pair(*_: Any) -> None:
|
||||||
|
if self.need_pairing(connection, level):
|
||||||
|
pair_task = asyncio.create_task(connection.pair())
|
||||||
|
|
||||||
listeners: Dict[str, Callable[..., None]] = {
|
listeners: Dict[str, Callable[..., None]] = {
|
||||||
'disconnection': set_failure('connection_died'),
|
'disconnection': set_failure('connection_died'),
|
||||||
'pairing_failure': set_failure('pairing_failure'),
|
'pairing_failure': set_failure('pairing_failure'),
|
||||||
@@ -425,6 +443,7 @@ class SecurityService(SecurityServicer):
|
|||||||
'connection_encryption_change': on_encryption_change,
|
'connection_encryption_change': on_encryption_change,
|
||||||
'classic_pairing': try_set_success,
|
'classic_pairing': try_set_success,
|
||||||
'classic_pairing_failure': set_failure('pairing_failure'),
|
'classic_pairing_failure': set_failure('pairing_failure'),
|
||||||
|
'security_request': pair,
|
||||||
}
|
}
|
||||||
|
|
||||||
# register event handlers
|
# register event handlers
|
||||||
@@ -452,6 +471,15 @@ class SecurityService(SecurityServicer):
|
|||||||
pass
|
pass
|
||||||
self.log.debug('Authenticated')
|
self.log.debug('Authenticated')
|
||||||
|
|
||||||
|
# wait for `pair` to finish if any
|
||||||
|
if pair_task is not None:
|
||||||
|
self.log.debug('Wait for authentication...')
|
||||||
|
try:
|
||||||
|
await pair_task # type: ignore
|
||||||
|
except:
|
||||||
|
pass
|
||||||
|
self.log.debug('paired')
|
||||||
|
|
||||||
return WaitSecurityResponse(**kwargs)
|
return WaitSecurityResponse(**kwargs)
|
||||||
|
|
||||||
def reached_security_level(
|
def reached_security_level(
|
||||||
@@ -523,7 +551,7 @@ class SecurityStorageService(SecurityStorageServicer):
|
|||||||
self.log.debug(f"DeleteBond: {address}")
|
self.log.debug(f"DeleteBond: {address}")
|
||||||
|
|
||||||
if self.device.keystore is not None:
|
if self.device.keystore is not None:
|
||||||
with suppress(KeyError):
|
with contextlib.suppress(KeyError):
|
||||||
await self.device.keystore.delete(str(address))
|
await self.device.keystore.delete(str(address))
|
||||||
|
|
||||||
return empty_pb2.Empty()
|
return empty_pb2.Empty()
|
||||||
|
|||||||
@@ -37,6 +37,7 @@ from typing import (
|
|||||||
Optional,
|
Optional,
|
||||||
Tuple,
|
Tuple,
|
||||||
Type,
|
Type,
|
||||||
|
cast,
|
||||||
)
|
)
|
||||||
|
|
||||||
from pyee import EventEmitter
|
from pyee import EventEmitter
|
||||||
@@ -1771,7 +1772,26 @@ class Manager(EventEmitter):
|
|||||||
cid = SMP_BR_CID if connection.transport == BT_BR_EDR_TRANSPORT else SMP_CID
|
cid = SMP_BR_CID if connection.transport == BT_BR_EDR_TRANSPORT else SMP_CID
|
||||||
connection.send_l2cap_pdu(cid, command.to_bytes())
|
connection.send_l2cap_pdu(cid, command.to_bytes())
|
||||||
|
|
||||||
|
def on_smp_security_request_command(
|
||||||
|
self, connection: Connection, request: SMP_Security_Request_Command
|
||||||
|
) -> None:
|
||||||
|
connection.emit('security_request', request.auth_req)
|
||||||
|
|
||||||
def on_smp_pdu(self, connection: Connection, pdu: bytes) -> None:
|
def on_smp_pdu(self, connection: Connection, pdu: bytes) -> None:
|
||||||
|
# Parse the L2CAP payload into an SMP Command object
|
||||||
|
command = SMP_Command.from_bytes(pdu)
|
||||||
|
logger.debug(
|
||||||
|
f'<<< Received SMP Command on connection [0x{connection.handle:04X}] '
|
||||||
|
f'{connection.peer_address}: {command}'
|
||||||
|
)
|
||||||
|
|
||||||
|
# Security request is more than just pairing, so let applications handle them
|
||||||
|
if command.code == SMP_SECURITY_REQUEST_COMMAND:
|
||||||
|
self.on_smp_security_request_command(
|
||||||
|
connection, cast(SMP_Security_Request_Command, command)
|
||||||
|
)
|
||||||
|
return
|
||||||
|
|
||||||
# Look for a session with this connection, and create one if none exists
|
# Look for a session with this connection, and create one if none exists
|
||||||
if not (session := self.sessions.get(connection.handle)):
|
if not (session := self.sessions.get(connection.handle)):
|
||||||
if connection.role == BT_CENTRAL_ROLE:
|
if connection.role == BT_CENTRAL_ROLE:
|
||||||
@@ -1782,13 +1802,6 @@ class Manager(EventEmitter):
|
|||||||
)
|
)
|
||||||
self.sessions[connection.handle] = session
|
self.sessions[connection.handle] = session
|
||||||
|
|
||||||
# Parse the L2CAP payload into an SMP Command object
|
|
||||||
command = SMP_Command.from_bytes(pdu)
|
|
||||||
logger.debug(
|
|
||||||
f'<<< Received SMP Command on connection [0x{connection.handle:04X}] '
|
|
||||||
f'{connection.peer_address}: {command}'
|
|
||||||
)
|
|
||||||
|
|
||||||
# Delegate the handling of the command to the session
|
# Delegate the handling of the command to the session
|
||||||
session.on_smp_command(command)
|
session.on_smp_command(command)
|
||||||
|
|
||||||
|
|||||||
Reference in New Issue
Block a user