# gwoe-antragspruefer/app/reindex_embeddings.py
"""Reindex-Script für die Embedding-Modell-Migration v3 → v4 (Issue #123).
Läuft im Container:
docker exec gwoe-antragspruefer python -m app.reindex_embeddings
Was es macht:
1. Alle Wahlprogramme + Grundsatzprogramme mit dem aktuellen EMBEDDING_MODEL
(aus settings.embedding_model_write, default 'text-embedding-v4') neu
indexieren. Schreibt neue Rows in chunks mit model='text-embedding-v4',
die bestehenden v3-Rows bleiben unberührt.
2. Alle Assessments backfillen: summary_embedding erzeugen wo NULL oder wo
embedding_model vom aktuellen abweicht.
3. Rate-Limit: 100ms zwischen Calls (= max 10 req/sec).
4. Fortschritts-Logging pro Programm/Assessment.
Nach erfolgreichem Lauf:
- settings.embedding_model_read auf 'text-embedding-v4' flippen (via ENV),
Container neu starten
- Script `cleanup_v3_rows.py` läuft DELETE FROM chunks WHERE model='text-embedding-v3'
"""
import asyncio
import json
import logging
import sqlite3
import time
from pathlib import Path
import aiosqlite
from .config import settings
from .embeddings import (
EMBEDDING_BATCH_SIZE,
EMBEDDING_MODEL,
EMBEDDINGS_DB,
PROGRAMME,
create_embedding,
create_embeddings_batch,
init_embeddings_db,
)
# Timestamped INFO logging so container logs show migration progress.
logging.basicConfig(level=logging.INFO, format="%(asctime)s %(levelname)s %(message)s")
logger = logging.getLogger(__name__)
# Pause inserted between embedding API calls to respect the provider rate limit.
RATE_LIMIT_SLEEP = 0.1 # 100ms = 10 req/sec
def reindex_programme(pdf_dir: Path) -> dict:
    """Re-index all programs with the current WRITE model.

    Programs that already have chunks stored under EMBEDDING_MODEL are
    skipped, so the script is safe to re-run after a partial failure.

    Args:
        pdf_dir: Directory containing the program PDFs listed in PROGRAMME.

    Returns:
        Stats dict with keys 'reindexed', 'skipped', 'failed', 'total_chunks'.
    """
    init_embeddings_db()
    # Which programs are already indexed with the current model?
    conn = sqlite3.connect(EMBEDDINGS_DB)
    try:
        conn.row_factory = sqlite3.Row
        rows = conn.execute(
            "SELECT programm_id, COUNT(*) AS n FROM chunks WHERE model = ? GROUP BY programm_id",
            (EMBEDDING_MODEL,),
        ).fetchall()
        already_done = {r["programm_id"]: r["n"] for r in rows}
    finally:
        # BUGFIX: close the connection even if the query raises (was leaked before).
        conn.close()
    stats = {"reindexed": 0, "skipped": 0, "failed": 0, "total_chunks": 0}
    for prog_id, info in PROGRAMME.items():
        if prog_id in already_done:
            logger.info(
                "SKIP %s — bereits %d chunks mit %s",
                prog_id, already_done[prog_id], EMBEDDING_MODEL,
            )
            stats["skipped"] += 1
            continue
        pdf_path = pdf_dir / info["pdf"]
        if not pdf_path.exists():
            logger.warning("MISS %s — PDF fehlt: %s", prog_id, pdf_path)
            stats["failed"] += 1
            continue
        try:
            logger.info("INDEX %s (%s)", prog_id, info["pdf"])
            n = _index_programm_with_ratelimit(prog_id, pdf_dir)
            stats["reindexed"] += 1
            stats["total_chunks"] += n
            # BUGFIX: format string was "DONE %s%d chunks" — id and count were
            # concatenated without a separator in the log output.
            logger.info("DONE %s — %d chunks", prog_id, n)
        except Exception:
            # One failing program must not abort the whole migration run.
            logger.exception("FAIL %s", prog_id)
            stats["failed"] += 1
    return stats
def _index_programm_with_ratelimit(programm_id: str, pdf_dir: Path) -> int:
    """Batch re-index a single program.

    Collects all chunks first, then embeds them in batches of
    EMBEDDING_BATCH_SIZE (10) texts per API call — roughly 10x faster than
    a single-call loop.

    Args:
        programm_id: Key into PROGRAMME.
        pdf_dir: Directory containing the program PDFs.

    Returns:
        Number of chunks successfully embedded and inserted.
    """
    import fitz

    info = PROGRAMME[programm_id]
    pdf_path = pdf_dir / info["pdf"]
    # Sliding-window chunking parameters (loop-invariant, hoisted out of the page loop).
    chunk_size, overlap = 400, 50
    # Collect all chunks first, then embed in batches.
    doc = fitz.open(pdf_path)
    try:
        pending: list[tuple[int, str]] = []  # (page_num, chunk_text)
        for page_num in range(len(doc)):
            text = doc[page_num].get_text()
            if not text.strip():
                continue
            words = text.split()
            i = 0
            while i < len(words):
                chunk = " ".join(words[i : i + chunk_size])
                i += chunk_size - overlap
                # Drop tail fragments too short to be meaningful.
                if len(chunk.split()) < 20:
                    continue
                pending.append((page_num + 1, chunk))
    finally:
        doc.close()
    total = 0
    conn = sqlite3.connect(EMBEDDINGS_DB)
    try:
        # Only delete rows of the current model (migration-safe: v3 rows stay).
        conn.execute(
            "DELETE FROM chunks WHERE programm_id = ? AND model = ?",
            (programm_id, EMBEDDING_MODEL),
        )
        # BUGFIX: commit the DELETE explicitly. Previously it was only
        # committed together with the first batch, so for a program that
        # produced zero chunks the close() rolled the DELETE back silently.
        conn.commit()
        for start in range(0, len(pending), EMBEDDING_BATCH_SIZE):
            batch = pending[start : start + EMBEDDING_BATCH_SIZE]
            try:
                vecs = create_embeddings_batch(
                    [t for _, t in batch], model=EMBEDDING_MODEL
                )
            except Exception:
                logger.exception("batch failed (programm %s, start %d)", programm_id, start)
                continue
            finally:
                # BUGFIX: sleep on failure too. Before, the rate-limit pause
                # only ran on success, so a failing call retried the API
                # (next batch) with no delay at all.
                time.sleep(RATE_LIMIT_SLEEP)  # 100ms between batch calls
            for (page_num, chunk), vec in zip(batch, vecs):
                conn.execute(
                    "INSERT INTO chunks (programm_id, partei, typ, seite, text, embedding, bundesland, model) "
                    "VALUES (?, ?, ?, ?, ?, ?, ?, ?)",
                    (
                        programm_id,
                        info["partei"],
                        info["typ"],
                        page_num,
                        chunk,
                        json.dumps(vec).encode(),
                        info.get("bundesland"),
                        EMBEDDING_MODEL,
                    ),
                )
                total += 1
            # Commit per batch so a crash does not lose earlier batches.
            conn.commit()
    finally:
        # BUGFIX: close the connection even if an insert/commit raises (was leaked before).
        conn.close()
    return total
async def backfill_assessment_embeddings() -> dict:
    """Backfill summary embeddings for all assessments that lack one
    (or whose embedding was created with an outdated model).

    Returns:
        Stats dict with keys 'backfilled', 'skipped', 'failed'.
    """
    from .embeddings import create_assessment_embedding

    stats = {"backfilled": 0, "skipped": 0, "failed": 0}
    async with aiosqlite.connect(settings.db_path) as db:
        db.row_factory = aiosqlite.Row
        cur = await db.execute(
            "SELECT drucksache, title, antrag_zusammenfassung, themen, bundesland, embedding_model "
            "FROM assessments"
        )
        rows = await cur.fetchall()
        for row in rows:
            # Already embedded with the current model — nothing to do.
            if row["embedding_model"] == EMBEDDING_MODEL:
                stats["skipped"] += 1
                continue
            try:
                themen = json.loads(row["themen"] or "[]")
            except Exception:
                # Malformed JSON in the themen column — embed without topics.
                themen = []
            blob, model = create_assessment_embedding(
                title=row["title"] or "",
                zusammenfassung=row["antrag_zusammenfassung"],
                themen=themen,
                bundesland=row["bundesland"],
            )
            # BUGFIX: time.sleep() blocked the event loop inside a coroutine;
            # use the asyncio-native sleep for the rate limit instead.
            await asyncio.sleep(RATE_LIMIT_SLEEP)
            if blob is None:
                stats["failed"] += 1
                logger.warning("backfill FAIL %s", row["drucksache"])
                continue
            # BUGFIX: reuse the already-open connection. The original opened a
            # second aiosqlite connection per row, which also rebound the
            # `db` name to a closed connection after each iteration.
            await db.execute(
                "UPDATE assessments SET summary_embedding = ?, embedding_model = ? WHERE drucksache = ?",
                (blob, model, row["drucksache"]),
            )
            await db.commit()
            stats["backfilled"] += 1
            if stats["backfilled"] % 20 == 0:
                logger.info("backfill progress: %d", stats["backfilled"])
    return stats
async def main():
    """Entry point: re-index all programs, then backfill assessment embeddings."""
    banner = "=" * 60
    referenz_dir = Path(__file__).resolve().parent / "static" / "referenzen"
    logger.info(banner)
    logger.info("Reindex mit WRITE-Modell: %s", EMBEDDING_MODEL)
    logger.info("PDF-Verzeichnis: %s", referenz_dir)
    logger.info(banner)
    # Phase 1: program chunks.
    programm_stats = reindex_programme(referenz_dir)
    logger.info("Programme fertig: %s", programm_stats)
    # Phase 2: assessment summary embeddings.
    logger.info("Backfill Assessment-Embeddings …")
    assessment_stats = await backfill_assessment_embeddings()
    logger.info("Assessments fertig: %s", assessment_stats)
    # Final summary plus the manual follow-up step for the operator.
    logger.info(banner)
    logger.info("REINDEX KOMPLETT")
    logger.info("Programme: %s", programm_stats)
    logger.info("Assessments: %s", assessment_stats)
    logger.info("Nächster Schritt: settings.embedding_model_read auf %s setzen", EMBEDDING_MODEL)
    logger.info("(ENV: EMBEDDING_MODEL_READ=%s, Container neu starten)", EMBEDDING_MODEL)
if __name__ == "__main__":
asyncio.run(main())