Replace liblc3 wasm library

This commit is contained in:
Josh Wu
2024-12-18 20:45:39 +08:00
parent 99695bb264
commit a27f55a588
3 changed files with 70 additions and 198 deletions

View File

@@ -16,23 +16,22 @@
# Imports
# -----------------------------------------------------------------------------
from __future__ import annotations
import asyncio
import datetime
import enum
import functools
from importlib import resources
import json
import os
import logging
import pathlib
from typing import Optional, List, cast
import weakref
import struct
import wave
import ctypes
import wasmtime
import wasmtime.loader
import liblc3 # type: ignore
try:
import lc3 # type: ignore # pylint: disable=E0401
except ImportError as e:
raise ImportError("Try `python -m pip install \".[lc3]\"`.") from e
import click
import aiohttp.web
@@ -45,6 +44,7 @@ from bumble.transport import open_transport
from bumble.profiles import ascs, bap, pacs
from bumble.hci import Address, CodecID, CodingFormat, HCI_IsoDataPacket
# -----------------------------------------------------------------------------
# Logging
# -----------------------------------------------------------------------------
@@ -54,6 +54,7 @@ logger = logging.getLogger(__name__)
# Constants
# -----------------------------------------------------------------------------
DEFAULT_UI_PORT = 7654
DEFAULT_PCM_BYTES_PER_SAMPLE = 2
def _sink_pac_record() -> pacs.PacRecord:
@@ -100,153 +101,8 @@ def _source_pac_record() -> pacs.PacRecord:
)
# -----------------------------------------------------------------------------
# WASM - liblc3
# -----------------------------------------------------------------------------
store = wasmtime.loader.store
_memory = cast(wasmtime.Memory, liblc3.memory)
STACK_POINTER = _memory.data_len(store)
_memory.grow(store, 1)
# Mapping wasmtime memory to linear address
memory = (ctypes.c_ubyte * _memory.data_len(store)).from_address(
ctypes.addressof(_memory.data_ptr(store).contents) # type: ignore
)
class Liblc3PcmFormat(enum.IntEnum):
S16 = 0
S24 = 1
S24_3LE = 2
FLOAT = 3
MAX_DECODER_SIZE = liblc3.lc3_decoder_size(10000, 48000)
MAX_ENCODER_SIZE = liblc3.lc3_encoder_size(10000, 48000)
DECODER_STACK_POINTER = STACK_POINTER
ENCODER_STACK_POINTER = DECODER_STACK_POINTER + MAX_DECODER_SIZE * 2
DECODE_BUFFER_STACK_POINTER = ENCODER_STACK_POINTER + MAX_ENCODER_SIZE * 2
ENCODE_BUFFER_STACK_POINTER = DECODE_BUFFER_STACK_POINTER + 8192
DEFAULT_PCM_SAMPLE_RATE = 48000
DEFAULT_PCM_FORMAT = Liblc3PcmFormat.S16
DEFAULT_PCM_BYTES_PER_SAMPLE = 2
encoders: List[int] = []
decoders: List[int] = []
def setup_encoders(
sample_rate_hz: int, frame_duration_us: int, num_channels: int
) -> None:
logger.info(
f"setup_encoders {sample_rate_hz}Hz {frame_duration_us}us {num_channels}channels"
)
encoders[:num_channels] = [
liblc3.lc3_setup_encoder(
frame_duration_us,
sample_rate_hz,
DEFAULT_PCM_SAMPLE_RATE, # Input sample rate
ENCODER_STACK_POINTER + MAX_ENCODER_SIZE * i,
)
for i in range(num_channels)
]
def setup_decoders(
sample_rate_hz: int, frame_duration_us: int, num_channels: int
) -> None:
logger.info(
f"setup_decoders {sample_rate_hz}Hz {frame_duration_us}us {num_channels}channels"
)
decoders[:num_channels] = [
liblc3.lc3_setup_decoder(
frame_duration_us,
sample_rate_hz,
DEFAULT_PCM_SAMPLE_RATE, # Output sample rate
DECODER_STACK_POINTER + MAX_DECODER_SIZE * i,
)
for i in range(num_channels)
]
def decode(
frame_duration_us: int,
num_channels: int,
input_bytes: bytes,
) -> bytes:
if not input_bytes:
return b''
input_buffer_offset = DECODE_BUFFER_STACK_POINTER
input_buffer_size = len(input_bytes)
input_bytes_per_frame = input_buffer_size // num_channels
# Copy into wasm
memory[input_buffer_offset : input_buffer_offset + input_buffer_size] = input_bytes # type: ignore
output_buffer_offset = input_buffer_offset + input_buffer_size
output_buffer_size = (
liblc3.lc3_frame_samples(frame_duration_us, DEFAULT_PCM_SAMPLE_RATE)
* DEFAULT_PCM_BYTES_PER_SAMPLE
* num_channels
)
for i in range(num_channels):
res = liblc3.lc3_decode(
decoders[i],
input_buffer_offset + input_bytes_per_frame * i,
input_bytes_per_frame,
DEFAULT_PCM_FORMAT,
output_buffer_offset + i * DEFAULT_PCM_BYTES_PER_SAMPLE,
num_channels, # Stride
)
if res != 0:
logging.error(f"Parsing failed, res={res}")
# Extract decoded data from the output buffer
return bytes(
memory[output_buffer_offset : output_buffer_offset + output_buffer_size]
)
def encode(
sdu_length: int,
num_channels: int,
stride: int,
input_bytes: bytes,
) -> bytes:
if not input_bytes:
return b''
input_buffer_offset = ENCODE_BUFFER_STACK_POINTER
input_buffer_size = len(input_bytes)
# Copy into wasm
memory[input_buffer_offset : input_buffer_offset + input_buffer_size] = input_bytes # type: ignore
output_buffer_offset = input_buffer_offset + input_buffer_size
output_buffer_size = sdu_length
output_frame_size = output_buffer_size // num_channels
for i in range(num_channels):
res = liblc3.lc3_encode(
encoders[i],
DEFAULT_PCM_FORMAT,
input_buffer_offset + DEFAULT_PCM_BYTES_PER_SAMPLE * i,
stride,
output_frame_size,
output_buffer_offset + output_frame_size * i,
)
if res != 0:
logging.error(f"Parsing failed, res={res}")
# Extract decoded data from the output buffer
return bytes(
memory[output_buffer_offset : output_buffer_offset + output_buffer_size]
)
decoder: lc3.Decoder | None = None
encoding_config: bap.CodecSpecificConfiguration | None = None
async def lc3_source_task(
@@ -256,42 +112,58 @@ async def lc3_source_task(
device: Device,
cis_handle: int,
) -> None:
with open(filename, 'rb') as f:
header = f.read(44)
assert header[8:12] == b'WAVE'
pcm_num_channel, pcm_sample_rate, _byte_rate, _block_align, bits_per_sample = (
struct.unpack("<HIIHH", header[22:36])
)
assert pcm_sample_rate == DEFAULT_PCM_SAMPLE_RATE
assert bits_per_sample == DEFAULT_PCM_BYTES_PER_SAMPLE * 8
frame_bytes = (
liblc3.lc3_frame_samples(frame_duration_us, DEFAULT_PCM_SAMPLE_RATE)
* DEFAULT_PCM_BYTES_PER_SAMPLE
)
logger.info(
"lc3_source_task filename=%s, sdu_length=%d, frame_duration=%.1f",
filename,
sdu_length,
frame_duration_us / 1000,
)
with wave.open(filename, 'rb') as wav:
bits_per_sample = wav.getsampwidth() * 8
packet_sequence_number = 0
encoder: lc3.Encoder | None = None
while True:
next_round = datetime.datetime.now() + datetime.timedelta(
microseconds=frame_duration_us
)
pcm_data = f.read(frame_bytes)
sdu = encode(sdu_length, pcm_num_channel, pcm_num_channel, pcm_data)
if not encoder:
if (
encoding_config
and (frame_duration := encoding_config.frame_duration)
and (sampling_frequency := encoding_config.sampling_frequency)
and (
audio_channel_allocation := encoding_config.audio_channel_allocation
)
):
logger.info("Use %s", encoding_config)
encoder = lc3.Encoder(
frame_duration_us=frame_duration.us,
sample_rate_hz=sampling_frequency.hz,
num_channels=audio_channel_allocation.channel_count,
input_sample_rate_hz=wav.getframerate(),
)
else:
sdu = encoder.encode(
pcm=wav.readframes(encoder.get_frame_samples()),
num_bytes=sdu_length,
bit_depth=bits_per_sample,
)
iso_packet = HCI_IsoDataPacket(
connection_handle=cis_handle,
data_total_length=sdu_length + 4,
packet_sequence_number=packet_sequence_number,
pb_flag=0b10,
packet_status_flag=0,
iso_sdu_length=sdu_length,
iso_sdu_fragment=sdu,
)
device.host.send_hci_packet(iso_packet)
packet_sequence_number += 1
iso_packet = HCI_IsoDataPacket(
connection_handle=cis_handle,
data_total_length=sdu_length + 4,
packet_sequence_number=packet_sequence_number,
pb_flag=0b10,
packet_status_flag=0,
iso_sdu_length=sdu_length,
iso_sdu_fragment=sdu,
)
device.host.send_hci_packet(iso_packet)
packet_sequence_number += 1
sleep_time = next_round - datetime.datetime.now()
await asyncio.sleep(sleep_time.total_seconds())
await asyncio.sleep(sleep_time.total_seconds() * 0.9)
# -----------------------------------------------------------------------------
@@ -410,7 +282,7 @@ class Speaker:
def __init__(
self,
device_config_path: Optional[str],
device_config_path: str | None,
ui_port: int,
transport: str,
lc3_input_file_path: str,
@@ -490,12 +362,12 @@ class Speaker:
not isinstance(codec_config, bap.CodecSpecificConfiguration)
or codec_config.frame_duration is None
or codec_config.audio_channel_allocation is None
or decoder is None
or not pdu.iso_sdu_fragment
):
return
pcm = decode(
codec_config.frame_duration.us,
codec_config.audio_channel_allocation.channel_count,
pdu.iso_sdu_fragment,
pcm = decoder.decode(
pdu.iso_sdu_fragment, bit_depth=DEFAULT_PCM_BYTES_PER_SAMPLE * 8
)
self.device.abort_on('disconnection', self.ui_server.send_audio(pcm))
@@ -537,16 +409,14 @@ class Speaker:
):
return
if ase.role == ascs.AudioRole.SOURCE:
setup_encoders(
codec_config.sampling_frequency.hz,
codec_config.frame_duration.us,
codec_config.audio_channel_allocation.channel_count,
)
global encoding_config
encoding_config = codec_config
else:
setup_decoders(
codec_config.sampling_frequency.hz,
codec_config.frame_duration.us,
codec_config.audio_channel_allocation.channel_count,
global decoder
decoder = lc3.Decoder(
frame_duration_us=codec_config.frame_duration.us,
sample_rate_hz=codec_config.sampling_frequency.hz,
num_channels=codec_config.audio_channel_allocation.channel_count,
)
for ase in ascs_service.ase_state_machines.values():
@@ -585,7 +455,7 @@ def speaker(ui_port: int, device_config: str, transport: str, lc3_file: str) ->
# -----------------------------------------------------------------------------
def main():
logging.basicConfig(level=os.environ.get('BUMBLE_LOGLEVEL', 'WARNING').upper())
logging.basicConfig(level=os.environ.get('BUMBLE_LOGLEVEL', 'INFO').upper())
speaker()

Binary file not shown.

View File

@@ -54,7 +54,6 @@ development = [
"types-appdirs >= 1.4.3",
"types-invoke >= 1.7.3",
"types-protobuf >= 4.21.0",
"wasmtime == 20.0.0",
]
avatar = [
"pandora-avatar == 0.0.10",
@@ -66,6 +65,9 @@ documentation = [
"mkdocs-material >= 8.5.6",
"mkdocstrings[python] >= 0.19.0",
]
lc3 = [
"lc3 @ git+https://github.com/google/liblc3.git",
]
[project.scripts]
bumble-ble-rpa-tool = "bumble.apps.ble_rpa_tool:main"