feat: implement websockets

This commit is contained in:
Radu C. Martin 2025-04-14 10:10:01 +02:00
parent c8abb8943e
commit bf3fceb833
3 changed files with 211 additions and 132 deletions

121
main.py
View file

@ -1,7 +1,9 @@
import asyncio
from enum import Enum
from fastapi import FastAPI
from fastapi import FastAPI, WebSocket, WebSocketDisconnect
from fastapi.responses import HTMLResponse
from pydantic import BaseModel
from download_service import DownloadService
from music_player import MusicPlayer, PlayerState
@ -14,8 +16,47 @@ class ChangePlayerState(Enum):
stop = "stop"
queue: list[str] = []
class WSConnectionType(Enum):
state = "state"
queue = "queue"
class ConnectionManager:
def __init__(self) -> None:
self.active_connections: dict[str, set[WebSocket]] = {
WSConnectionType.state.value: set(),
WSConnectionType.queue.value: set(),
}
async def connect(self, websocket: WebSocket, type: WSConnectionType):
await websocket.accept()
self.active_connections[type.value].add(websocket)
async def send(self, ws: WebSocket, message: BaseModel):
try:
await ws.send_json(message.model_dump())
except Exception:
self.disconnect(ws)
async def broadcast(self, ws_type: WSConnectionType, message: BaseModel):
broken = set()
for ws in self.active_connections[ws_type.value]:
try:
await ws.send_json(message.model_dump())
except Exception:
broken.add(ws)
for ws in broken:
self.disconnect(ws)
def disconnect(self, websocket: WebSocket):
if websocket in self.active_connections[WSConnectionType.state.value]:
self.active_connections[WSConnectionType.state.value].remove(websocket)
if websocket in self.active_connections[WSConnectionType.queue.value]:
self.active_connections[WSConnectionType.queue.value].remove(websocket)
# Setup
tags_metadata = [
{"name": "player", "description": "Interact with the Music Player"},
{"name": "experimental"},
@ -24,6 +65,7 @@ tags_metadata = [
app = FastAPI(openapi_tags=tags_metadata)
player = MusicPlayer()
dl_service = DownloadService()
ws_manager = ConnectionManager()
# Interface
@ -33,47 +75,80 @@ async def root():
return f.read()
# Experimental
@app.on_event("startup")
async def start_event_loop():
asyncio.create_task(state_broadcast_loop())
asyncio.create_task(queue_broadcast_loop())
async def state_broadcast_loop():
while True:
await player._state_event.wait()
await ws_manager.broadcast(WSConnectionType.state, player.get_state())
player._state_event.clear()
async def queue_broadcast_loop():
while True:
await player._queue_event.wait()
await ws_manager.broadcast(WSConnectionType.queue, player.get_queue())
player._queue_event.clear()
# Status updates
@app.websocket("/player")
async def websocket_player(websocket: WebSocket):
await ws_manager.connect(websocket, WSConnectionType.state)
try:
while True:
await websocket.receive_text()
await ws_manager.send(websocket, player.get_state())
except WebSocketDisconnect:
ws_manager.disconnect(websocket)
# Queue updates
@app.websocket("/queue")
async def websocket_queue(websocket: WebSocket):
await ws_manager.connect(websocket, WSConnectionType.queue)
try:
while True:
await websocket.receive_text()
await ws_manager.send(websocket, player.get_queue())
except WebSocketDisconnect:
ws_manager.disconnect(websocket)
@app.get("/queue", tags=["queue"])
def get_queue():
return player.get_queue()
@app.post("/queue", tags=["queue"])
def post_to_queue(url: str):
async def post_to_queue(url: str):
track = dl_service.download(url)
player.add_to_queue(track)
await player.add_to_queue(track)
@app.post("/player/play", tags=["player"])
def player_play():
player.play()
@app.post("/player/pause", tags=["player"])
def player_pause():
player.pause()
@app.post("/player/resume", tags=["player"])
def player_resume():
player.resume()
async def player_play():
await player.play()
@app.post("/player/stop", tags=["player"])
def player_stop():
player.stop()
async def player_stop():
await player.stop()
@app.post("/player/skip", tags=["player"])
def player_skip():
player.next()
async def player_skip():
await player.next()
# Player
@app.put("/player/volume", tags=["player"])
def set_volume(volume: float):
player.set_volume(volume)
async def set_volume(volume: float):
await player.set_volume(volume)
@app.get("/player/volume", tags=["player"])