# backend/main.py

import asyncio
import logging
import uuid
from typing import List, Set

from aiortc import MediaStreamTrack, RTCPeerConnection, RTCSessionDescription
from aiortc.mediastreams import MediaStreamError
from fastapi import FastAPI, WebSocket
from fastapi.middleware.cors import CORSMiddleware
from pydantic import BaseModel


# Opus packetization time in milliseconds, advertised as ``a=ptime`` in the
# SDP answer. Edit it here to change it for both backend and frontend.
PTIME = 40

app = FastAPI()

# CORS is left wide open so a frontend served from another localhost port can
# call the API; narrow allow_origins (e.g. ["http://localhost:8501"]) to
# restrict it.
# NOTE(review): allow_credentials=True together with a wildcard origin makes
# Starlette answer with the caller's own Origin, i.e. any site may send
# credentialed requests -- confirm this is intended before exposing the server
# beyond localhost.
app.add_middleware(
    CORSMiddleware,
    allow_headers=["*"],
    allow_methods=["*"],
    allow_credentials=True,
    allow_origins=["*"],
)

# Strong references to live peer connections, so they are not garbage-collected
# mid-session and can be closed by the shutdown hook.
pcs: Set[RTCPeerConnection] = set()

class Offer(BaseModel):
    """JSON body of ``POST /offer``: the client's WebRTC session description.

    ``sdp`` is the session description text and ``type`` its description
    type; the endpoint hands both to ``RTCSessionDescription``.
    """

    sdp: str
    type: str

def _insert_ptime(sdp: str, ptime: int) -> str:
    """Return *sdp* with ``a=ptime:<ptime>`` added to each audio media section.

    A section is opened by an ``m=audio`` line and closed by the next ``m=``
    line or the end of the text; the attribute is appended at the section's
    end, where ``a=`` lines belong. Sections that already contain an
    ``a=ptime:`` line are left unchanged, so calling this twice is harmless.
    Blank lines are dropped and the result is re-joined with CRLF, the line
    terminator SDP requires. An empty *sdp* is returned as-is.
    """
    out: List[str] = []
    needs_ptime = False  # inside an m=audio section that has no a=ptime yet
    for line in sdp.splitlines():
        if not line:
            continue
        if line.startswith("m="):
            if needs_ptime:
                out.append(f"a=ptime:{ptime}")
            needs_ptime = line.startswith("m=audio")
        elif line.startswith("a=ptime:"):
            needs_ptime = False
        out.append(line)
    if needs_ptime:
        out.append(f"a=ptime:{ptime}")
    return "\r\n".join(out) + "\r\n" if out else sdp


@app.post("/offer")
async def offer(offer: Offer):
    """Answer a WebRTC SDP offer and start consuming the client's media.

    Creates an ``RTCPeerConnection``, keeps it in ``pcs`` until it fails or
    closes, applies the client's offer, and returns the answer as
    ``{"sdp": ..., "type": ...}`` with ``a=ptime:{PTIME}`` in its audio
    section(s). If negotiation raises, the connection is closed and dropped
    before the exception propagates.
    """
    logging.info("/offer endpoint called")

    pc = RTCPeerConnection()  # No STUN needed for localhost
    pcs.add(pc)
    id_ = uuid.uuid4().hex[:8]
    logging.info(f"{id_}: new PeerConnection")

    @pc.on("connectionstatechange")
    async def on_connectionstatechange():
        # Release connections that can no longer carry media; otherwise they
        # stay referenced by ``pcs`` until the process shuts down.
        logging.info(f"{id_}: connection state is {pc.connectionState}")
        if pc.connectionState in ("failed", "closed"):
            await pc.close()  # no-op if already closed
            pcs.discard(pc)

    @pc.on("track")
    async def on_track(track: MediaStreamTrack):
        logging.info(f"{id_}: track {track.kind} received")
        first = True
        try:
            while True:
                # aiortc hands out decoded frames here, not RTP packets.
                frame = await track.recv()
                frame_bytes = bytes(frame.planes[0])
                if first:
                    logging.info(
                        f"{id_}: frame sample_rate={frame.sample_rate}, channels={frame.layout.channels}, format={frame.format.name}"
                    )
                    logging.info(f"{id_}: received audio frame of len {len(frame_bytes)} bytes")
                    first = False
                # TODO: write to file, pipe to ASR, etc.
        except MediaStreamError:
            # recv() raises this once the remote track has ended: normal exit.
            logging.info(f"{id_}: track {track.kind} ended")
        except Exception as e:
            logging.exception(f"{id_}: Exception in on_track: {e}")

    # --- SDP negotiation ---
    try:
        logging.info(f"{id_}: setting remote description")
        # Explicit fields work on both Pydantic v1 and v2 (``.dict()`` is
        # deprecated in v2).
        await pc.setRemoteDescription(
            RTCSessionDescription(sdp=offer.sdp, type=offer.type)
        )

        logging.info(f"{id_}: creating answer")
        answer = await pc.createAnswer()
        await pc.setLocalDescription(answer)
    except Exception:
        # Don't leave a half-negotiated connection behind in ``pcs``.
        await pc.close()
        pcs.discard(pc)
        raise

    # aiortc rebuilds ``localDescription.sdp`` from its parsed session model,
    # which does not appear to keep a=ptime, so munging the answer before
    # setLocalDescription() would be lost. Insert it into the SDP actually
    # returned instead (skipped if it is somehow already present).
    ptime_line = f"a=ptime:{PTIME}"
    local = pc.localDescription
    sdp = _insert_ptime(local.sdp, PTIME)
    logging.info(f"{id_}: sending answer with {ptime_line}")
    return {"sdp": sdp, "type": local.type}

@app.on_event("shutdown")
async def cleanup():
    """Close every tracked peer connection when the application shuts down.

    The closes run concurrently. A failure closing one connection is logged
    instead of propagating out of the shutdown hook, and ``pcs`` is emptied so
    no references to the closed connections remain.
    """
    to_close = list(pcs)
    pcs.clear()
    results = await asyncio.gather(
        *(pc.close() for pc in to_close), return_exceptions=True
    )
    for result in results:
        if isinstance(result, BaseException):
            logging.error("error closing peer connection: %r", result)

if __name__ == "__main__":
    import uvicorn

    logging.basicConfig(level=logging.INFO)
    # Pass the app object rather than an import string: "backend:app" only
    # resolves if an importable module named ``backend`` exposes ``app``,
    # which is generally not the case when this file (backend/main.py) is
    # run directly as a script.
    uvicorn.run(app, host="0.0.0.0", port=8000)