241 lines
7.4 KiB
Python
241 lines
7.4 KiB
Python
import asyncio
import logging
import time
import uuid
from typing import final

from fastapi import FastAPI, WebSocket, WebSocketDisconnect
from fastapi.middleware.cors import CORSMiddleware
from pydantic import BaseModel
|
|
|
|
# Application object that the WebSocket routes in this module attach to.
app = FastAPI()

# Wide-open CORS policy intended for local development: any origin, method
# and header is accepted, and credentials are allowed.
# NOTE(review): wildcard origins together with allow_credentials=True is
# broader than a deployed service should permit — confirm before shipping.
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)
|
|
|
|
|
|
class Question(BaseModel):
    """A single multiple-choice question of a quiz.

    Attributes:
        text: The prompt sent to clients.
        choices: The answer options sent to clients.
        answer: The correct option; a player's submission is compared to it
            for equality when the question is scored.
        seconds: Time limit for the question (defaults to 20).
    """

    text: str
    choices: list[str]
    answer: str
    seconds: int = 20
|
|
|
|
|
|
class QuizConfig(BaseModel):
    """A complete quiz definition.

    Attributes:
        title: Name of the quiz (defaults to "Trivia").
        questions: The questions that make up the quiz.
    """

    title: str = "Trivia"
    questions: list[Question]
|
|
|
|
|
|
@final
class Participant:
    """A player taking part in a game session.

    Attributes:
        name: Name the player joined with.
        ws: The player's WebSocket connection.
        score: Points earned so far; starts at 0.
        current_answer: The choice submitted for the question in progress,
            or None while the player has not answered.
    """

    def __init__(self, name: str, ws: WebSocket):
        # Identity of the player and the connection used to reach them.
        self.name, self.ws = name, ws
        # Per-game mutable state.
        self.score = 0
        self.current_answer: str | None = None
|
|
|
|
|
|
@final
class GameSession:
    """In-memory state of one hosted quiz game.

    Attributes:
        id: Four-character upper-case join code.
        host_ws: WebSocket connection of the host.
        questions: The questions played in this game, in order.
        participants: Players in the game, keyed by ``id()`` of each
            player's WebSocket object.
        current_q_index: Index of the active question; -1 before the game
            has started.
        state: One of "LOBBY", "QUESTION", "ANSWER" or "END".
        timer_task: Slot for a question-timer task handle; starts as None.
        end_time: ``time.time()``-based deadline of the active question;
            0.0 until a question has been started.
    """

    def __init__(self, host_ws: WebSocket, questions: list[Question]):
        # Join code: the first four hex digits of a random UUID, upper-cased.
        self.id = uuid.uuid4().hex[:4].upper()

        self.host_ws = host_ws
        self.questions = questions
        self.participants: dict[int, Participant] = {}

        # Game progress.
        self.state = "LOBBY"
        self.current_q_index = -1
        self.end_time = 0.0
        self.timer_task = None
|
|
|
|
|
|
# Registry of live games, keyed by each session's join code.
sessions: dict[str, GameSession] = dict()
|
|
|
|
|
|
async def _send_best_effort(ws: WebSocket, message: dict) -> None:
    """Send *message* as JSON over *ws*, tolerating delivery failures."""
    try:
        await ws.send_json(message)
    except Exception:
        # Deliberately `Exception` rather than a bare `except`: a closed or
        # broken socket must not abort the broadcast, but task cancellation
        # (a BaseException) still has to propagate.
        logging.getLogger(__name__).debug(
            "Could not deliver state update", exc_info=True
        )


async def broadcast_state(session_id: str):
    """Send the current public state of a session to its host and players.

    The message is a dict with the keys ``type`` ("STATE_UPDATE"), ``state``,
    ``players``, ``question`` and ``end_time``. The correct answer and the
    per-choice answer counts are only filled in while the session is in the
    "ANSWER" state. Delivery is best effort: a socket that cannot be written
    to is skipped.

    Args:
        session_id: Join code of the session to broadcast. Unknown codes are
            ignored.
    """
    session = sessions.get(session_id)
    if session is None:
        return

    # Work on a snapshot of the roster. The sends below yield to the event
    # loop, and a player joining or leaving meanwhile would otherwise raise
    # "dictionary changed size during iteration".
    participants = list(session.participants.values())

    # Public per-player info.
    player_list = [
        {"name": p.name, "score": p.score, "answered": p.current_answer is not None}
        for p in participants
    ]

    # Present a leaderboard (highest score first) once the game is over.
    if session.state == "END":
        player_list.sort(key=lambda entry: entry["score"], reverse=True)

    # Info about the active question, if there is one.
    q_data = None
    if 0 <= session.current_q_index < len(session.questions):
        q = session.questions[session.current_q_index]
        revealing = session.state == "ANSWER"

        answer_stats = {}
        if revealing:
            answer_stats = {choice: 0 for choice in q.choices}
            for p in participants:
                submitted = p.current_answer
                # The isinstance check keeps an unhashable submission
                # (e.g. a JSON list) from raising TypeError in the lookup.
                if isinstance(submitted, str) and submitted in answer_stats:
                    answer_stats[submitted] += 1

        q_data = {
            "text": q.text,
            "choices": q.choices,
            "seconds": q.seconds,
            # The answer is only revealed in the ANSWER state.
            "correct_answer": q.answer if revealing else None,
            "answer_stats": answer_stats,
        }

    state_msg = {
        "type": "STATE_UPDATE",
        "state": session.state,
        "players": player_list,
        "question": q_data,
        "end_time": session.end_time,  # lets clients sync their timers
    }

    # Host first, then every player.
    await _send_best_effort(session.host_ws, state_msg)
    for p in participants:
        await _send_best_effort(p.ws, state_msg)
|
|
|
|
|
|
async def end_question_timer(session_id: str, q_index: int):
    """Reveal the answer once the time limit of question *q_index* runs out.

    Sleeps for the question's ``seconds`` and then moves the session to the
    answer phase, unless the game has moved on in the meantime.

    Args:
        session_id: Join code of the session the timer belongs to.
        q_index: Index of the question being timed. Out-of-range values
            (including negative ones) make the call a no-op.
    """
    session = sessions.get(session_id)
    if session is None or not 0 <= q_index < len(session.questions):
        return

    await asyncio.sleep(session.questions[q_index].seconds)

    # Re-read the registry after sleeping: the session may have been removed
    # (or its code taken by a different session) while this task was waiting.
    if sessions.get(session_id) is not session:
        return

    # Only force the transition if the same question is still being asked.
    if session.state == "QUESTION" and session.current_q_index == q_index:
        await transition_to_answer(session_id)
|
|
|
|
|
|
async def transition_to_answer(session_id: str):
    """Close the active question, score it and broadcast the result.

    Every participant whose submitted answer equals the question's correct
    answer gains 100 points. The call is a no-op when the session no longer
    exists or is not in the "QUESTION" state, so overlapping callers (the
    question timer and the "everyone answered" path) cannot award the points
    for one question twice.

    Args:
        session_id: Join code of the session to transition.
    """
    session = sessions.get(session_id)
    if session is None or session.state != "QUESTION":
        return

    # Flip the state before the first await so a concurrent caller sees
    # "ANSWER" and backs off.
    session.state = "ANSWER"

    current_q = session.questions[session.current_q_index]
    for p in session.participants.values():
        if p.current_answer == current_q.answer:
            p.score += 100  # flat score per correct answer

    await broadcast_state(session_id)
|
|
|
|
|
|
# --- Routes ---
|
|
def _cancel_timer(session: GameSession) -> None:
    """Cancel the session's pending question timer, if there is one."""
    task = session.timer_task
    session.timer_task = None
    if task is not None and not task.done():
        task.cancel()


def _drop_session(session: GameSession) -> None:
    """Stop the session's timer and remove it from the registry."""
    _cancel_timer(session)
    # Identity check so a different session that happens to use the same
    # join code is never removed by mistake.
    if sessions.get(session.id) is session:
        del sessions[session.id]


async def _begin_question(session: GameSession, index: int) -> None:
    """Switch the session to question *index*, notify everyone, arm the timer."""
    # A timer left over from an earlier question (or an earlier run of this
    # same index after a restart) must not cut the new question short.
    _cancel_timer(session)

    session.current_q_index = index
    session.state = "QUESTION"
    session.end_time = time.time() + session.questions[index].seconds
    for p in session.participants.values():
        p.current_answer = None

    await broadcast_state(session.id)

    # Keep a reference to the task: the event loop only holds weak
    # references, so an unreferenced task may be garbage-collected mid-run.
    session.timer_task = asyncio.create_task(
        end_question_timer(session.id, index)
    )


@app.websocket("/ws/host")
async def websocket_host(websocket: WebSocket):
    """WebSocket endpoint for the quiz host.

    Handles JSON messages with an ``action`` key:

    * ``CREATE_SESSION`` - builds a session from ``quiz["questions"]`` and
      replies with ``{"type": "SESSION_CREATED", "code": ...}``. A malformed
      quiz is answered with ``{"type": "ERROR", ...}`` instead of ending the
      connection. Creating a new session discards the host's previous one.
    * ``START_GAME`` - starts question 0.
    * ``NEXT_QUESTION`` - starts the following question, or moves the
      session to the "END" state after the last one.

    Payloads that are not JSON objects and unknown actions are ignored. The
    host's session is removed from the registry when the handler exits for
    any reason.
    """
    await websocket.accept()
    session = None

    try:
        while True:
            data = await websocket.receive_json()
            if not isinstance(data, dict):
                continue  # not a JSON object; nothing to dispatch on
            action = data.get("action")

            if action == "CREATE_SESSION":
                quiz_data = data.get("quiz")  # Parsed YAML
                try:
                    questions = [Question(**q) for q in quiz_data["questions"]]
                except (KeyError, TypeError, ValueError):
                    # Missing/ill-typed quiz, or a question that fails model
                    # validation (pydantic's ValidationError is a ValueError).
                    await websocket.send_json(
                        {"type": "ERROR", "message": "Invalid quiz payload"}
                    )
                    continue

                # One live session per host connection.
                if session is not None:
                    _drop_session(session)

                session = GameSession(websocket, questions)
                # Join codes are only four characters long; regenerate on a
                # clash rather than overwrite someone else's game.
                while session.id in sessions:
                    session = GameSession(websocket, questions)
                sessions[session.id] = session

                # Send code back to host
                await websocket.send_json(
                    {"type": "SESSION_CREATED", "code": session.id}
                )
                await broadcast_state(session.id)

            elif action == "START_GAME":
                if session is None:
                    continue
                if not session.questions:
                    await websocket.send_json(
                        {"type": "ERROR", "message": "Quiz has no questions"}
                    )
                    continue
                await _begin_question(session, 0)

            elif action == "NEXT_QUESTION":
                if session is None:
                    continue
                next_index = session.current_q_index + 1
                if next_index < len(session.questions):
                    await _begin_question(session, next_index)
                else:
                    _cancel_timer(session)
                    session.state = "END"
                    await broadcast_state(session.id)

    except WebSocketDisconnect:
        pass
    finally:
        # Runs on disconnect and on unexpected errors alike, so a crashed
        # handler cannot leave a dead session (and its timer) behind.
        if session is not None:
            _drop_session(session)
|
|
|
|
|
|
@app.websocket("/ws/play/{code}/{name}")
async def websocket_player(websocket: WebSocket, code: str, name: str):
    """WebSocket endpoint for a player joining the session with code *code*.

    The connection is closed with code 4000 when no such session exists.
    Otherwise the player is added to the session and may send
    ``{"action": "SUBMIT_ANSWER", "choice": <str>}`` while a question is
    open. Once every participant has answered, the session moves to the
    answer phase. Payloads that are not JSON objects, other actions and
    non-string choices are ignored. The player is removed from the session
    when the handler exits for any reason.

    Args:
        websocket: The player's connection.
        code: Join code of the session (case-insensitive).
        name: Name the player joins with.
    """
    code = code.upper()
    # Single lookup instead of "check, await, index": the session could be
    # removed while the accept below is awaited.
    session = sessions.get(code)
    if session is None:
        await websocket.close(code=4000)
        return

    await websocket.accept()
    if sessions.get(code) is not session:
        # The game ended while the handshake was in flight.
        await websocket.close(code=4000)
        return

    # Add participant
    key = id(websocket)
    player = Participant(name, websocket)
    session.participants[key] = player

    try:
        await broadcast_state(code)

        while True:
            data = await websocket.receive_json()
            if not isinstance(data, dict):
                continue
            if data.get("action") != "SUBMIT_ANSWER":
                continue
            if session.state != "QUESTION":
                continue

            choice = data.get("choice")
            if not isinstance(choice, str):
                continue  # answers are compared against string choices

            player.current_answer = choice
            await broadcast_state(code)  # Update "Answered" status on host

            # The broadcast yields to the event loop, so the question may
            # already have been closed (timer or another player's answer);
            # re-check the state to avoid scoring it a second time.
            if session.state == "QUESTION" and all(
                p.current_answer is not None
                for p in session.participants.values()
            ):
                await transition_to_answer(code)

    except WebSocketDisconnect:
        pass
    finally:
        # Also runs on unexpected errors, so a failed handler does not
        # leave a stale participant in the session.
        if session.participants.pop(key, None) is not None:
            await broadcast_state(code)
|