# gwoe-antragspruefer/app/queue.py


"""Analysis job queue with configurable parallel workers (#95, #99).
Processes jobs via an asyncio.Queue with N concurrent workers (Semaphore).
Tracks per-job status for live UI visualization.
"""
import asyncio
import logging
import os
import time
from typing import Any, Callable, Coroutine

logger = logging.getLogger(__name__)

# Configuration
MAX_QUEUE_SIZE = 50
CONCURRENCY = int(os.environ.get("QUEUE_CONCURRENCY", "3"))
MIN_PAUSE_SECONDS = 3
_shutting_down = False  # Blocks new jobs during graceful shutdown
BACKOFF_BASE = 15
BACKOFF_MAX = 300
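
# Backoff schedule produced by the formula in _worker below
# (BACKOFF_BASE * 2 ** (consecutive_failures - 2), capped at BACKOFF_MAX):
#   2 consecutive failures → 15 s, 3 → 30 s, 4 → 60 s, 5 → 120 s, 6 → 240 s, 7+ → 300 s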

# In-memory queue + job tracking
_queue: asyncio.Queue = asyncio.Queue(maxsize=MAX_QUEUE_SIZE)
_worker_tasks: list[asyncio.Task] = []
_stats = {
    "processed": 0,
    "failed": 0,
    "started_at": None,
    "avg_duration": 60.0,  # seconds; initial estimate, refined by a moving average
}

# Live job tracking: job_id → {status, drucksache, started_at, duration, error}
_jobs: dict[str, dict] = {}
_MAX_TRACKED_JOBS = 100  # oldest entries are discarded


class QueueFullError(Exception):
    pass


async def enqueue(
    job_id: str,
    callback: Callable[..., Coroutine],
    *args: Any,
    drucksache: str = "",
    **kwargs: Any,
) -> int:
    """Add a job to the queue. Returns the queue position (queue size right after insertion)."""
    if _shutting_down:
        # Messages are user-facing (German UI)
        raise QueueFullError("Server wird neu gestartet. Bitte in Kürze erneut versuchen.")
    try:
        _queue.put_nowait((job_id, callback, args, kwargs))
    except asyncio.QueueFull:
        raise QueueFullError(f"Queue voll ({MAX_QUEUE_SIZE} Jobs).")
    _jobs[job_id] = {
        "status": "queued",
        "drucksache": drucksache,
        "enqueued_at": time.time(),
        "started_at": None,
        "duration": None,
        "error": None,
    }
    # Trim the oldest tracked jobs
    if len(_jobs) > _MAX_TRACKED_JOBS:
        oldest = sorted(_jobs, key=lambda k: _jobs[k].get("enqueued_at", 0))
        for k in oldest[:len(_jobs) - _MAX_TRACKED_JOBS]:
            del _jobs[k]
    position = _queue.qsize()
    logger.info("Job %s enqueued at position %d (concurrency=%d)", job_id, position, CONCURRENCY)
    return position
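
# Usage sketch (illustrative; the `run_analysis` callback named here is an
# assumption, not part of this module — see re_enqueue_pending below for the
# same call pattern):
#
#   position = await enqueue(
#       job_id,                 # e.g. a UUID generated by the route handler
#       run_analysis,           # async callback that performs the analysis
#       job_id, drucksache, text, bundesland, model, doc,
#       drucksache=drucksache,  # shown in the queue UI
#   )
#   # position: this job's place in line (queue size right after insertion)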


def get_queue_status() -> dict:
    """Queue status + per-job details for UI visualization."""
    pending = _queue.qsize()
    avg = _stats["avg_duration"]
    # With N workers, the expected wait is shared between them
    estimated_wait = (pending / max(CONCURRENCY, 1)) * (avg + MIN_PAUSE_SECONDS)
    # The 30 most recent jobs, newest first
    recent = sorted(_jobs.items(), key=lambda kv: kv[1].get("enqueued_at", 0), reverse=True)[:30]
    # Load stale jobs from the DB (left over after a container restart)
    stale_jobs = []
    try:
        import sqlite3
        from .config import settings
        conn = sqlite3.connect(settings.db_path)
        conn.row_factory = sqlite3.Row
        rows = conn.execute(
            "SELECT * FROM jobs "
            "WHERE status IN ('stale', 'queued', 'processing') ORDER BY created_at DESC LIMIT 20"
        ).fetchall()
        conn.close()
        stale_jobs = [{"job_id": r["id"], "bundesland": r["bundesland"] or "",
                       "status": "stale", "drucksache": r["drucksache"] if "drucksache" in r.keys() else "",
                       "duration": None, "error": "Container-Restart"} for r in rows]
    except Exception:
        pass  # best effort; the status endpoint must not fail on DB issues
    return {
        "pending": pending,
        "max_size": MAX_QUEUE_SIZE,
        "concurrency": CONCURRENCY,
        "shutting_down": _shutting_down,
        "processed_total": _stats["processed"],
        "failed_total": _stats["failed"],
        "estimated_wait_seconds": round(estimated_wait),
        "avg_job_duration_seconds": round(avg, 1),
        "workers_running": sum(1 for t in _worker_tasks if not t.done()),
        "jobs": [{
            "job_id": jid,
            "drucksache": j.get("drucksache", ""),
            "status": j["status"],
            "duration": round(j["duration"], 1) if j.get("duration") else None,
            "error": j.get("error"),
        } for jid, j in recent] + stale_jobs,
    }
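
# Shape of the returned payload (all values illustrative, including the
# drucksache number):
#   {
#       "pending": 2, "max_size": 50, "concurrency": 3, "shutting_down": False,
#       "processed_total": 17, "failed_total": 1,
#       "estimated_wait_seconds": 42, "avg_job_duration_seconds": 60.0,
#       "workers_running": 3,
#       "jobs": [{"job_id": "...", "drucksache": "18/1234", "status": "processing",
#                 "duration": None, "error": None}, ...],
#   }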


async def _worker(worker_id: int):
    """Worker coroutine: pulls jobs off the queue and processes them."""
    logger.info("Worker %d started", worker_id)
    consecutive_failures = 0
    while True:
        job_id, callback, args, kwargs = await _queue.get()
        t0 = time.time()
        if job_id in _jobs:
            _jobs[job_id]["status"] = "processing"
            _jobs[job_id]["started_at"] = t0
        try:
            logger.info("Worker %d processing %s (queue: %d)", worker_id, job_id, _queue.qsize())
            await callback(*args, **kwargs)
            duration = time.time() - t0
            _stats["processed"] += 1
            # Exponential moving average (weight 0.2 on the newest sample)
            _stats["avg_duration"] = (_stats["avg_duration"] * 0.8) + (duration * 0.2)
            consecutive_failures = 0
            if job_id in _jobs:
                _jobs[job_id]["status"] = "completed"
                _jobs[job_id]["duration"] = duration
            logger.info("Worker %d completed %s in %.1fs", worker_id, job_id, duration)
        except Exception as e:
            _stats["failed"] += 1
            consecutive_failures += 1
            if job_id in _jobs:
                _jobs[job_id]["status"] = "failed"
                _jobs[job_id]["duration"] = time.time() - t0
                _jobs[job_id]["error"] = str(e)[:100]
            logger.exception("Worker %d failed %s", worker_id, job_id)
            if consecutive_failures > 1:
                # Exponential backoff after repeated failures (schedule above)
                backoff = min(BACKOFF_BASE * (2 ** (consecutive_failures - 2)), BACKOFF_MAX)
                logger.warning("Worker %d backoff %ds", worker_id, backoff)
                await asyncio.sleep(backoff)
        finally:
            _queue.task_done()
        await asyncio.sleep(MIN_PAUSE_SECONDS)
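
# Job lifecycle as tracked in _jobs:
#   "queued" (enqueue) → "processing" (worker picks it up)
#       → "completed" (callback returned) or "failed" (callback raised)
# Jobs found in the DB after a container restart appear as "stale" instead.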


def start_worker() -> list[asyncio.Task]:
    """Start N worker tasks (idempotent: live workers are kept, finished ones replaced)."""
    global _worker_tasks
    _stats["started_at"] = time.time()
    for i in range(CONCURRENCY):
        if i < len(_worker_tasks) and not _worker_tasks[i].done():
            continue
        task = asyncio.create_task(_worker(i))
        if i < len(_worker_tasks):
            _worker_tasks[i] = task
        else:
            _worker_tasks.append(task)
    logger.info("Queue: %d workers started (QUEUE_CONCURRENCY=%d)", CONCURRENCY, CONCURRENCY)
    return _worker_tasks
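
# Wiring sketch (assumption: the app is an ASGI app such as FastAPI; this
# module itself only needs a running event loop):
#
#   @app.on_event("startup")
#   async def _on_startup():
#       start_worker()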


async def graceful_shutdown(timeout: int = 900):
    """Graceful shutdown: let running jobs finish, lock the queue.

    1. Blocks new jobs (_shutting_down = True).
    2. Waits until all currently PROCESSING jobs are done (up to `timeout`).
    3. Queued jobs stay in the DB as 'stale'; the user can trigger them
       again after the restart.

    Timeout 15 min (900 s): a single LLM call takes at most ~120 s, so even
    with 3 parallel workers the real wait is at most ~120 s.
    """
    global _shutting_down
    _shutting_down = True
    processing = sum(1 for j in _jobs.values() if j.get("status") == "processing")
    pending = _queue.qsize()
    if processing == 0:
        logger.info("Graceful shutdown: no running jobs, exiting immediately (%d queued discarded)", pending)
        return
    logger.warning("Graceful shutdown: waiting for %d running jobs (max %ds). %d queued will become stale on restart.",
                   processing, timeout, pending)
    # Wait only for the running jobs, not for the whole queue
    start = time.time()
    while time.time() - start < timeout:
        still_processing = sum(1 for j in _jobs.values() if j.get("status") == "processing")
        if still_processing == 0:
            logger.info("Graceful shutdown: all running jobs finished after %.0fs", time.time() - start)
            return
        await asyncio.sleep(2)
    logger.error("Graceful shutdown: timeout after %ds, %d jobs still active",
                 timeout, sum(1 for j in _jobs.values() if j.get("status") == "processing"))
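
# Wiring sketch (same ASGI assumption as above):
#
#   @app.on_event("shutdown")
#   async def _on_shutdown():
#       await graceful_shutdown()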


async def re_enqueue_pending(analysis_callback=None):
    """Re-enqueue jobs that were queued or processing when the container died.

    Jobs WITH a drucksache column get re-enqueued automatically (if a callback
    is provided). Jobs WITHOUT a drucksache (legacy) are marked stale and
    cleaned up.

    Args:
        analysis_callback: async function(job_id, drucksache, text, bundesland, model, doc)
    """
    import aiosqlite
    from .config import settings
    async with aiosqlite.connect(settings.db_path) as db:
        db.row_factory = aiosqlite.Row
        rows = await db.execute(
            "SELECT id, bundesland, drucksache, model FROM jobs "
            "WHERE status IN ('queued', 'processing') ORDER BY created_at"
        )
        pending = await rows.fetchall()
    if not pending:
        # Clean up legacy stale jobs that have no drucksache
        async with aiosqlite.connect(settings.db_path) as db:
            deleted = await db.execute(
                "DELETE FROM jobs WHERE status='stale' AND (drucksache IS NULL OR drucksache='')"
            )
            if deleted.rowcount > 0:
                logger.info("Cleaned up %d legacy stale jobs without drucksache", deleted.rowcount)
            await db.commit()
        return
    logger.info("Found %d pending jobs from previous run", len(pending))
    from .parlamente import get_adapter
    re_enqueued = 0
    marked_stale = 0
    for row in pending:
        job_id = row["id"]
        bundesland = row["bundesland"] or "NRW"
        drucksache = row["drucksache"]
        model = row["model"] or "qwen-plus"
        if not drucksache or not analysis_callback:
            # Legacy job without a drucksache, or no callback → mark stale
            async with aiosqlite.connect(settings.db_path) as db:
                await db.execute(
                    "UPDATE jobs SET status='stale', updated_at=datetime('now') WHERE id=?",
                    (job_id,),
                )
                await db.commit()
            marked_stale += 1
            continue
        # Job with a drucksache → re-enqueue
        try:
            adapter = get_adapter(bundesland)
            doc = await adapter.get_document(drucksache)
            if not doc:
                raise ValueError(f"Drucksache {drucksache} nicht gefunden")
            text = await adapter.download_text(drucksache)
            if not text:
                raise ValueError(f"PDF-Text für {drucksache} leer")
            position = await enqueue(
                job_id,
                analysis_callback,
                job_id, drucksache, text, bundesland, model, doc,
                drucksache=drucksache,
            )
            re_enqueued += 1
            logger.info("Re-enqueued %s (%s) at position %d", drucksache, bundesland, position)
        except Exception as e:
            logger.warning("Could not re-enqueue %s (%s): %s — marking stale", drucksache, bundesland, e)
            async with aiosqlite.connect(settings.db_path) as db:
                await db.execute(
                    "UPDATE jobs SET status='stale', error=?, updated_at=datetime('now') WHERE id=?",
                    (str(e)[:200], job_id),
                )
                await db.commit()
            marked_stale += 1
    logger.info("Re-enqueued %d jobs, marked %d stale", re_enqueued, marked_stale)