"""Analysis job queue with configurable parallel workers (#95, #99).

Processes jobs via an asyncio.Queue drained by N concurrent worker tasks
(N = QUEUE_CONCURRENCY). Tracks per-job status for live UI visualization.
"""

import asyncio
import logging
import os
import time
from typing import Any, Callable, Coroutine, Optional

logger = logging.getLogger(__name__)

# Configuration
MAX_QUEUE_SIZE = 50
CONCURRENCY = int(os.environ.get("QUEUE_CONCURRENCY", "3"))
MIN_PAUSE_SECONDS = 3  # pause a worker takes after every job
_shutting_down = False  # set by graceful_shutdown(); rejects new jobs
BACKOFF_BASE = 15  # seconds; first backoff happens at the 2nd consecutive failure
BACKOFF_MAX = 300  # upper bound (seconds) for the exponential backoff

# In-memory queue + job tracking
_queue: asyncio.Queue = asyncio.Queue(maxsize=MAX_QUEUE_SIZE)
_worker_tasks: list[asyncio.Task] = []
_stats = {
    "processed": 0,
    "failed": 0,
    "started_at": None,
    "avg_duration": 60.0,  # moving average of job duration (s), seeded with 60s
}

# Live job tracking: job_id -> {status, drucksache, enqueued_at, started_at, duration, error}
_jobs: dict[str, dict] = {}
_MAX_TRACKED_JOBS = 100  # beyond this, the oldest *finished* jobs are evicted
# Statuses that mark a tracked job as finished and therefore safe to evict.
_FINISHED_STATUSES = ("completed", "failed")


class QueueFullError(Exception):
    """Raised by enqueue() when the queue is full or a shutdown is in progress."""


def _trim_tracked_jobs() -> None:
    """Evict the oldest finished jobs once more than _MAX_TRACKED_JOBS are tracked.

    Queued and processing entries are never evicted: graceful_shutdown() relies
    on seeing every 'processing' job, and the number of active entries is
    already bounded by the queue size plus the number of workers.
    """
    excess = len(_jobs) - _MAX_TRACKED_JOBS
    if excess <= 0:
        return
    finished = sorted(
        (jid for jid, job in _jobs.items() if job.get("status") in _FINISHED_STATUSES),
        key=lambda jid: _jobs[jid].get("enqueued_at", 0),
    )
    for jid in finished[:excess]:
        del _jobs[jid]


async def enqueue(
    job_id: str,
    callback: Callable[..., Coroutine],
    *args: Any,
    drucksache: str = "",
    **kwargs: Any,
) -> int:
    """Add a job to the queue and start tracking it.

    Args:
        job_id: Key under which the job is tracked for the status view.
        callback: Coroutine function a worker awaits as ``callback(*args, **kwargs)``.
        *args: Positional arguments forwarded to ``callback``.
        drucksache: Label shown for the job in the UI status list.
        **kwargs: Keyword arguments forwarded to ``callback``.

    Returns:
        Number of jobs waiting in the queue right after insertion (this one included).

    Raises:
        QueueFullError: if a graceful shutdown is in progress or the queue
            already holds MAX_QUEUE_SIZE jobs.
    """
    if _shutting_down:
        raise QueueFullError("Server wird neu gestartet. Bitte in Kürze erneut versuchen.")
    try:
        _queue.put_nowait((job_id, callback, args, kwargs))
    except asyncio.QueueFull:
        raise QueueFullError(f"Queue voll ({MAX_QUEUE_SIZE} Jobs).")

    _jobs[job_id] = {
        "status": "queued",
        "drucksache": drucksache,
        "enqueued_at": time.time(),
        "started_at": None,
        "duration": None,
        "error": None,
    }
    _trim_tracked_jobs()

    position = _queue.qsize()
    logger.info("Job %s enqueued at position %d (concurrency=%d)", job_id, position, CONCURRENCY)
    return position


def get_queue_status() -> dict:
    """Return queue metrics plus per-job details for the UI.

    ``jobs`` holds the last 30 entries of the tracking dict in insertion
    order (oldest first).
    """
    pending = _queue.qsize()
    avg = _stats["avg_duration"]
    # N workers share the backlog, so the expected wait shrinks with concurrency.
    estimated_wait = (pending / max(CONCURRENCY, 1)) * (avg + MIN_PAUSE_SECONDS)
    recent = list(_jobs.items())[-30:]

    return {
        "pending": pending,
        "max_size": MAX_QUEUE_SIZE,
        "concurrency": CONCURRENCY,
        "shutting_down": _shutting_down,
        "processed_total": _stats["processed"],
        "failed_total": _stats["failed"],
        "estimated_wait_seconds": round(estimated_wait),
        "avg_job_duration_seconds": round(avg, 1),
        "workers_running": sum(1 for t in _worker_tasks if not t.done()),
        "jobs": [
            {
                "job_id": jid,
                "drucksache": j.get("drucksache", ""),
                "status": j["status"],
                # `is not None` so a measured 0.0s duration is not reported as missing.
                "duration": round(j["duration"], 1) if j.get("duration") is not None else None,
                "error": j.get("error"),
            }
            for jid, j in recent
        ],
    }


async def _worker(worker_id: int):
    """Run jobs from the shared queue forever, one at a time.

    Parallelism comes from running CONCURRENCY of these coroutines side by
    side (see start_worker); there is no semaphore. After every job the worker
    pauses MIN_PAUSE_SECONDS. From the second consecutive failure on it also
    backs off for BACKOFF_BASE * 2**(failures - 2) seconds, capped at
    BACKOFF_MAX. ``Exception``s raised by a callback are logged and recorded
    in ``_jobs``; they do not stop the worker.
    """
    logger.info("Worker %d started", worker_id)
    consecutive_failures = 0
    while True:
        job_id, callback, args, kwargs = await _queue.get()
        t0 = time.time()
        if job_id in _jobs:
            _jobs[job_id]["status"] = "processing"
            _jobs[job_id]["started_at"] = t0
        try:
            logger.info("Worker %d processing %s (queue: %d)", worker_id, job_id, _queue.qsize())
            await callback(*args, **kwargs)
            duration = time.time() - t0
            _stats["processed"] += 1
            # Exponential moving average; feeds the wait estimate in get_queue_status().
            _stats["avg_duration"] = (_stats["avg_duration"] * 0.8) + (duration * 0.2)
            consecutive_failures = 0
            if job_id in _jobs:
                _jobs[job_id]["status"] = "completed"
                _jobs[job_id]["duration"] = duration
            logger.info("Worker %d completed %s in %.1fs", worker_id, job_id, duration)
        except Exception as e:
            _stats["failed"] += 1
            consecutive_failures += 1
            if job_id in _jobs:
                _jobs[job_id]["status"] = "failed"
                _jobs[job_id]["duration"] = time.time() - t0
                _jobs[job_id]["error"] = str(e)[:100]
            logger.exception("Worker %d failed %s", worker_id, job_id)
            if consecutive_failures > 1:
                backoff = min(BACKOFF_BASE * (2 ** (consecutive_failures - 2)), BACKOFF_MAX)
                logger.warning("Worker %d backoff %ds", worker_id, backoff)
                await asyncio.sleep(backoff)
        finally:
            _queue.task_done()
        await asyncio.sleep(MIN_PAUSE_SECONDS)


def start_worker() -> list[asyncio.Task]:
    """Start CONCURRENCY worker tasks, replacing only slots whose task has finished.

    Must be called while an event loop is running (uses asyncio.create_task).
    Returns the module-level list of worker tasks.
    """
    global _worker_tasks
    _stats["started_at"] = time.time()
    for i in range(CONCURRENCY):
        if i < len(_worker_tasks) and not _worker_tasks[i].done():
            continue  # this worker is still alive
        task = asyncio.create_task(_worker(i))
        if i < len(_worker_tasks):
            _worker_tasks[i] = task
        else:
            _worker_tasks.append(task)
    logger.info("Queue: %d workers started (QUEUE_CONCURRENCY=%d)", CONCURRENCY, CONCURRENCY)
    return _worker_tasks


def _count_processing() -> int:
    """Number of tracked jobs whose status is 'processing'."""
    return sum(1 for j in _jobs.values() if j.get("status") == "processing")


def _drain_pending() -> int:
    """Remove every job still waiting in the in-memory queue; return how many."""
    drained = 0
    while True:
        try:
            _queue.get_nowait()
        except asyncio.QueueEmpty:
            return drained
        _queue.task_done()
        drained += 1


async def graceful_shutdown(timeout: int = 900):
    """Graceful shutdown: let running jobs finish, reject and discard the rest.

    1. Rejects new jobs (``_shutting_down = True``).
    2. Discards jobs still waiting in the in-memory queue so no worker starts
       them while we wait. Their DB rows are expected to still be 'queued';
       re_enqueue_pending() marks such rows 'stale' on the next start, so
       users can re-trigger them after the restart.
    3. Waits until every job currently 'processing' is done (at most ``timeout`` s).

    The 15 min default (900 s) is generous: a single LLM call takes at most
    ~120 s, and since the workers run in parallel the real wait is ~120 s.
    """
    global _shutting_down
    _shutting_down = True

    processing = _count_processing()
    # Without draining, workers would keep picking up queued jobs and the
    # wait below might never observe zero running jobs.
    pending = _drain_pending()
    if processing == 0:
        logger.info("Graceful shutdown: keine laufenden Jobs, sofort beenden (%d queued verworfen)", pending)
        return

    logger.warning("Graceful shutdown: warte auf %d laufende Jobs (max %ds). %d queued werden beim Restart stale.",
                   processing, timeout, pending)

    # Wait only for the running jobs, not for the whole queue.
    start = time.time()
    while time.time() - start < timeout:
        if _count_processing() == 0:
            logger.info("Graceful shutdown: alle laufenden Jobs beendet nach %.0fs", time.time() - start)
            return
        await asyncio.sleep(2)

    logger.error("Graceful shutdown: Timeout nach %ds, %d Jobs noch aktiv", timeout, _count_processing())


async def re_enqueue_pending():
    """Mark jobs a previous run left in status 'queued' as 'stale'.

    Despite the name, nothing is put back into the queue: the rows are only
    flagged so users can re-trigger them. Runs as a single UPDATE on one
    connection, so it is atomic.
    """
    import aiosqlite
    from .config import settings

    async with aiosqlite.connect(settings.db_path) as db:
        cursor = await db.execute(
            "UPDATE jobs SET status = 'stale', updated_at = datetime('now') WHERE status = 'queued'"
        )
        marked = cursor.rowcount
        await db.commit()
    if marked > 0:
        logger.info("Marked %d stale jobs", marked)