Files
network_audio_streaming/src/mic_rtc_streaming/backend/backend.py

86 lines
2.9 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
# backend/backend.py
import asyncio, logging, uuid
from typing import List, Set
from fastapi import FastAPI, WebSocket
from fastapi.middleware.cors import CORSMiddleware
from pydantic import BaseModel
from aiortc import RTCPeerConnection, RTCSessionDescription, MediaStreamTrack
# Packetization time, in milliseconds, advertised for Opus through the
# ``a=ptime`` line that the /offer endpoint adds to its SDP answer.
PTIME = 40

app = FastAPI()

# CORS is left wide open so a frontend served from another origin can call
# this API; allow_origins could be narrowed to e.g. ["http://localhost:8501"].
# NOTE(review): wildcard origins together with allow_credentials=True is
# discouraged by the FastAPI CORS documentation - confirm whether credentials
# are needed at all before tightening either setting.
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)

# Strong references to every peer connection created by /offer, so they are
# not garbage-collected early and can be closed on shutdown.
pcs: Set[RTCPeerConnection] = set()
class Offer(BaseModel):
    """Request body of ``POST /offer``: the session description sent by the client."""

    # Session description text.
    sdp: str
    # Description type; passed unchanged to RTCSessionDescription.
    type: str
def _with_ptime(sdp: str, ptime: int) -> str:
    """Return *sdp* with an ``a=ptime:<ptime>`` attribute line added.

    The line is inserted directly after every ``a=sendrecv`` attribute. When
    the SDP contains none, it is appended as the last line instead. The line
    terminator already used by *sdp* (CRLF, else bare LF) is reused so the
    result keeps one consistent line ending and contains no blank line.
    """
    eol = "\r\n" if "\r\n" in sdp else "\n"
    ptime_line = f"a=ptime:{ptime}"
    if "a=sendrecv" in sdp:
        return sdp.replace("a=sendrecv", f"a=sendrecv{eol}{ptime_line}")
    if sdp and not sdp.endswith("\n"):
        sdp += eol
    return f"{sdp}{ptime_line}{eol}"


@app.post("/offer")
async def offer(offer: Offer):
    """Answer a client's WebRTC offer and start consuming its media tracks.

    A new RTCPeerConnection is created per request and tracked in ``pcs``
    until it fails or is closed. Frames of every incoming track are read in a
    loop; the first frame's properties are logged.

    Args:
        offer: Remote session description (SDP text and description type).

    Returns:
        A dict with the ``sdp`` and ``type`` of the local description.

    Raises:
        Exception: Whatever SDP negotiation raises; the peer connection is
            closed and untracked before the error is re-raised.
    """
    # Function-scope import: keeps this block independent of the file-level
    # import list.
    from aiortc.mediastreams import MediaStreamError

    logging.info("/offer endpoint called")
    pc = RTCPeerConnection()  # No STUN needed for localhost
    pcs.add(pc)
    id_ = uuid.uuid4().hex[:8]
    logging.info(f"{id_}: new PeerConnection")

    @pc.on("connectionstatechange")
    async def on_connectionstatechange():
        # Without this handler every connection would stay in `pcs` (and so
        # stay alive) until the application shuts down.
        state = pc.connectionState
        logging.info(f"{id_}: connection state is {state}")
        if state == "failed":
            await pc.close()
        if state in ("failed", "closed"):
            pcs.discard(pc)

    @pc.on("track")
    async def on_track(track: MediaStreamTrack):
        logging.info(f"{id_}: track {track.kind} received")
        try:
            first = True
            while True:
                pkt = await track.recv()  # next frame delivered by the track
                pkt_bytes = bytes(pkt.planes[0])
                if first:
                    logging.info(
                        f"{id_}: frame sample_rate={pkt.sample_rate}, channels={pkt.layout.channels}, format={pkt.format.name}"
                    )
                    # e.g. "received audio frame of len 23040 bytes"
                    logging.info(f"{id_}: received audio frame of len {len(pkt_bytes)} bytes")
                    first = False
                # TODO: write to file, pipe to ASR, etc.
        except MediaStreamError:
            # recv() signals the end of the track this way; it is not a fault.
            logging.info(f"{id_}: track {track.kind} ended")
        except Exception as e:
            logging.error(f"{id_}: Exception in on_track: {e}")

    # --- SDP negotiation ---
    try:
        logging.info(f"{id_}: setting remote description")
        # Fields are passed explicitly instead of via offer.dict(), which is
        # deprecated in pydantic v2.
        await pc.setRemoteDescription(
            RTCSessionDescription(sdp=offer.sdp, type=offer.type)
        )
        logging.info(f"{id_}: creating answer")
        answer = await pc.createAnswer()
        # Insert a=ptime using the global PTIME variable
        ptime_line = f"a=ptime:{PTIME}"
        new_answer = RTCSessionDescription(
            sdp=_with_ptime(answer.sdp, PTIME), type=answer.type
        )
        await pc.setLocalDescription(new_answer)
    except Exception:
        # Do not leave a half-negotiated connection tracked and open.
        pcs.discard(pc)
        await pc.close()
        raise

    logging.info(f"{id_}: sending answer with {ptime_line}")
    # NOTE(review): the SDP returned is the one the peer connection reports as
    # its local description - confirm it still carries the a=ptime line.
    return {"sdp": pc.localDescription.sdp,
            "type": pc.localDescription.type}
@app.on_event("shutdown")
async def cleanup():
    """Close every tracked peer connection when the application shuts down.

    All connections are closed concurrently. A failure while closing one of
    them is logged rather than raised, so it cannot abort the shutdown hook
    while the remaining connections are still being closed.
    """
    # Snapshot and empty the registry first so nothing closed here stays
    # referenced afterwards.
    open_pcs = list(pcs)
    pcs.clear()
    results = await asyncio.gather(
        *(pc.close() for pc in open_pcs), return_exceptions=True
    )
    for result in results:
        if isinstance(result, Exception):
            logging.error(f"error while closing PeerConnection: {result}")
if __name__ == "__main__":
    # Script entry point: uvicorn is only needed when the file is run directly.
    import uvicorn

    # Show the INFO-level messages emitted by the endpoint handlers.
    logging.basicConfig(level=logging.INFO)
    # The app is given as an import string, so uvicorn must be able to import
    # a module named "backend"; it listens on all interfaces, port 8000.
    uvicorn.run("backend:app", host="0.0.0.0", port=8000)