# gwoe-antragspruefer/app/embeddings.py
"""Semantic search for Wahlprogramme and Parteiprogramme using Qwen embeddings."""
import json
import sqlite3
from pathlib import Path
from typing import Optional
import fitz # PyMuPDF
from openai import OpenAI
from .config import settings
# Embedding model
EMBEDDING_MODEL = "text-embedding-v3"
EMBEDDING_DIMENSIONS = 1024
# Database path
EMBEDDINGS_DB = settings.data_dir / "embeddings.db"
# Programme definitions
PROGRAMME = {
    # Wahlprogramme NRW 2022
    "spd-nrw-2022": {
        "name": "SPD NRW Wahlprogramm 2022",
        "typ": "wahlprogramm",
        "partei": "SPD",
        "bundesland": "NRW",
        "pdf": "spd-nrw-2022.pdf",
    },
    "cdu-nrw-2022": {
        "name": "CDU NRW Wahlprogramm 2022",
        "typ": "wahlprogramm",
        "partei": "CDU",
        "bundesland": "NRW",
        "pdf": "cdu-nrw-2022.pdf",
    },
    "gruene-nrw-2022": {
        "name": "Grüne NRW Wahlprogramm 2022",
        "typ": "wahlprogramm",
        "partei": "GRÜNE",
        "bundesland": "NRW",
        "pdf": "gruene-nrw-2022.pdf",
    },
    "fdp-nrw-2022": {
        "name": "FDP NRW Wahlprogramm 2022",
        "typ": "wahlprogramm",
        "partei": "FDP",
        "bundesland": "NRW",
        "pdf": "fdp-nrw-2022.pdf",
    },
    "afd-nrw-2022": {
        "name": "AfD NRW Wahlprogramm 2022",
        "typ": "wahlprogramm",
        "partei": "AfD",
        "bundesland": "NRW",
        "pdf": "afd-nrw-2022.pdf",
    },
    # Grundsatzprogramme (Bund)
    "spd-grundsatz": {
        "name": "SPD Grundsatzprogramm 2007",
        "typ": "parteiprogramm",
        "partei": "SPD",
        "pdf": "spd-grundsatzprogramm.pdf",
    },
    "cdu-grundsatz": {
        "name": "CDU Grundsatzprogramm 2007",
        "typ": "parteiprogramm",
        "partei": "CDU",
        "pdf": "cdu-grundsatzprogramm.pdf",
    },
    "gruene-grundsatz": {
        "name": "Grüne Grundsatzprogramm 2020",
        "typ": "parteiprogramm",
        "partei": "GRÜNE",
        "pdf": "gruene-grundsatzprogramm.pdf",
    },
    "fdp-grundsatz": {
        "name": "FDP Grundsatzprogramm 2012",
        "typ": "parteiprogramm",
        "partei": "FDP",
        "pdf": "fdp-grundsatzprogramm.pdf",
    },
}


def init_embeddings_db():
    """Initialize the embeddings database."""
    conn = sqlite3.connect(EMBEDDINGS_DB)
    conn.execute("""
        CREATE TABLE IF NOT EXISTS chunks (
            id INTEGER PRIMARY KEY,
            programm_id TEXT NOT NULL,
            partei TEXT NOT NULL,
            typ TEXT NOT NULL,
            seite INTEGER,
            text TEXT NOT NULL,
            embedding BLOB NOT NULL,
            created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP
        )
    """)
    conn.execute("CREATE INDEX IF NOT EXISTS idx_chunks_partei ON chunks(partei)")
    conn.execute("CREATE INDEX IF NOT EXISTS idx_chunks_typ ON chunks(typ)")
    conn.commit()
    conn.close()


def get_client() -> OpenAI:
    """Get DashScope client."""
    return OpenAI(
        api_key=settings.dashscope_api_key,
        base_url=settings.dashscope_base_url,
    )


def create_embedding(text: str) -> list[float]:
    """Create embedding for text using Qwen."""
    client = get_client()
    response = client.embeddings.create(
        model=EMBEDDING_MODEL,
        input=text,
        dimensions=EMBEDDING_DIMENSIONS,
    )
    return response.data[0].embedding
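

def create_embeddings_batch(texts: list[str]) -> list[list[float]]:
    """Sketch: embed several chunks per request to cut per-call overhead.

    This is an assumption, not part of the original module: it relies on the
    DashScope OpenAI-compatible endpoint accepting a list input the way the
    OpenAI API does, and the provider caps the batch size, so check the
    DashScope docs before relying on it. Not wired into index_programm.
    """
    client = get_client()
    response = client.embeddings.create(
        model=EMBEDDING_MODEL,
        input=texts,
        dimensions=EMBEDDING_DIMENSIONS,
    )
    # Each result carries an index; sort so outputs line up with input order.
    return [item.embedding for item in sorted(response.data, key=lambda d: d.index)]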


def chunk_text(text: str, chunk_size: int = 500, overlap: int = 50) -> list[str]:
    """Split text into overlapping chunks by words."""
    if overlap >= chunk_size:
        # Guard: otherwise the step below would be <= 0 and the loop would
        # never advance.
        raise ValueError("overlap must be smaller than chunk_size")
    words = text.split()
    chunks = []
    i = 0
    while i < len(words):
        chunk_words = words[i:i + chunk_size]
        chunk = " ".join(chunk_words)
        if chunk.strip():
            chunks.append(chunk)
        i += chunk_size - overlap
    return chunks
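

# Worked example of the chunking arithmetic with the defaults: a 1000-word
# text with chunk_size=500 and overlap=50 steps by 450 words per iteration,
# so chunks start at word 0, 450, and 900. That yields three chunks, each
# sharing 50 words with its predecessor (the last one is shorter).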


def extract_text_with_pages(pdf_path: Path) -> list[tuple[int, str]]:
    """Extract text from PDF with page numbers."""
    doc = fitz.open(pdf_path)
    pages = []
    for page_num in range(len(doc)):
        page = doc[page_num]
        text = page.get_text()
        if text.strip():
            pages.append((page_num + 1, text))
    doc.close()
    return pages


def index_programm(programm_id: str, pdf_dir: Path) -> int:
    """Index a single program PDF into the embeddings database."""
    if programm_id not in PROGRAMME:
        raise ValueError(f"Unknown program: {programm_id}")
    info = PROGRAMME[programm_id]
    pdf_path = pdf_dir / info["pdf"]
    if not pdf_path.exists():
        print(f"PDF not found: {pdf_path}")
        return 0

    conn = sqlite3.connect(EMBEDDINGS_DB)
    # Remove existing chunks for this program
    conn.execute("DELETE FROM chunks WHERE programm_id = ?", (programm_id,))

    # Extract and chunk
    pages = extract_text_with_pages(pdf_path)
    total_chunks = 0
    for page_num, page_text in pages:
        chunks = chunk_text(page_text, chunk_size=400, overlap=50)
        for chunk_text_content in chunks:
            if len(chunk_text_content.split()) < 20:  # Skip tiny chunks
                continue
            try:
                embedding = create_embedding(chunk_text_content)
                embedding_blob = json.dumps(embedding).encode()
                conn.execute("""
                    INSERT INTO chunks (programm_id, partei, typ, seite, text, embedding)
                    VALUES (?, ?, ?, ?, ?, ?)
                """, (
                    programm_id,
                    info["partei"],
                    info["typ"],
                    page_num,
                    chunk_text_content,
                    embedding_blob,
                ))
                total_chunks += 1
            except Exception as e:
                print(f"Error embedding chunk: {e}")
                continue
    conn.commit()
    conn.close()
    print(f"Indexed {total_chunks} chunks from {programm_id}")
    return total_chunks


def cosine_similarity(a: list[float], b: list[float]) -> float:
    """Calculate cosine similarity between two vectors."""
    dot = sum(x * y for x, y in zip(a, b))
    norm_a = sum(x * x for x in a) ** 0.5
    norm_b = sum(x * x for x in b) ** 0.5
    if norm_a == 0 or norm_b == 0:
        return 0.0
    return dot / (norm_a * norm_b)
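

# The pure-Python similarity above scans every stored vector one at a time.
# A vectorized variant (a sketch, assuming numpy is installed, which this
# module does not otherwise require) scores all chunks in one matrix product:
def rank_chunks_vectorized(query_emb: list[float], chunk_embs: list[list[float]]) -> list[float]:
    import numpy as np  # local import so the module still works without numpy

    q = np.asarray(query_emb, dtype=np.float32)
    m = np.asarray(chunk_embs, dtype=np.float32)
    q_norm = np.linalg.norm(q) or 1.0
    m_norms = np.linalg.norm(m, axis=1)
    # Rows with zero norm get divisor 1.0; their dot product is 0, so the
    # resulting score is 0.0, matching cosine_similarity above.
    m_norms[m_norms == 0] = 1.0
    return (m @ q / (m_norms * q_norm)).tolist()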


def find_relevant_chunks(
    query: str,
    parteien: Optional[list[str]] = None,
    typ: Optional[str] = None,
    top_k: int = 3,
    min_similarity: float = 0.5,
) -> list[dict]:
    """Find the most relevant chunks for a query."""
    query_embedding = create_embedding(query)
    conn = sqlite3.connect(EMBEDDINGS_DB)
    conn.row_factory = sqlite3.Row

    # Build query
    sql = "SELECT * FROM chunks WHERE 1=1"
    params = []
    if parteien:
        placeholders = ",".join("?" * len(parteien))
        sql += f" AND partei IN ({placeholders})"
        params.extend(parteien)
    if typ:
        sql += " AND typ = ?"
        params.append(typ)
    rows = conn.execute(sql, params).fetchall()
    conn.close()

    # Calculate similarities
    results = []
    for row in rows:
        chunk_embedding = json.loads(row["embedding"])
        similarity = cosine_similarity(query_embedding, chunk_embedding)
        if similarity >= min_similarity:
            results.append({
                "programm_id": row["programm_id"],
                "partei": row["partei"],
                "typ": row["typ"],
                "seite": row["seite"],
                "text": row["text"],
                "similarity": similarity,
            })

    # Sort by similarity and return top_k
    results.sort(key=lambda x: x["similarity"], reverse=True)
    return results[:top_k]
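

# Example call (illustrative query): the three most similar Wahlprogramm
# passages from SPD and GRÜNE for a question about cycling infrastructure:
#
#   hits = find_relevant_chunks(
#       "Ausbau von Radwegen", parteien=["SPD", "GRÜNE"], typ="wahlprogramm"
#   )
#   for hit in hits:
#       print(hit["partei"], hit["seite"], round(hit["similarity"], 2))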


def get_relevant_quotes_for_antrag(
    antrag_text: str,
    fraktionen: list[str],
    top_k_per_partei: int = 2,
) -> dict[str, dict[str, list[dict]]]:
    """Get relevant quotes from Wahl- and Parteiprogramme for an Antrag."""
    results = {}
    # Include the Regierungsfraktionen; dict.fromkeys deduplicates while
    # preserving order, so a Fraktion already in the list is not embedded
    # twice. str.upper() maps "ü" to "Ü", so "Grüne" becomes "GRÜNE".
    parteien = list(dict.fromkeys(p.upper() for p in fraktionen + ["CDU", "GRÜNE"]))
    for partei_upper in parteien:
        # Wahlprogramm
        wahl_chunks = find_relevant_chunks(
            antrag_text,
            parteien=[partei_upper],
            typ="wahlprogramm",
            top_k=top_k_per_partei,
            min_similarity=0.45,
        )
        # Parteiprogramm
        partei_chunks = find_relevant_chunks(
            antrag_text,
            parteien=[partei_upper],
            typ="parteiprogramm",
            top_k=top_k_per_partei,
            min_similarity=0.45,
        )
        if wahl_chunks or partei_chunks:
            results[partei_upper] = {
                "wahlprogramm": wahl_chunks,
                "parteiprogramm": partei_chunks,
            }
    return results


def format_quotes_for_prompt(quotes: dict) -> str:
    """Format quotes for inclusion in an LLM prompt."""
    if not quotes:
        return ""
    lines = ["\n## Relevante Passagen aus Wahl- und Parteiprogrammen\n"]
    for partei, data in quotes.items():
        lines.append(f"\n### {partei}\n")
        if data.get("wahlprogramm"):
            lines.append("**Wahlprogramm NRW 2022:**")
            for chunk in data["wahlprogramm"]:
                text = chunk["text"][:500] + "..." if len(chunk["text"]) > 500 else chunk["text"]
                lines.append(f'- S. {chunk["seite"]}: "{text}"')
        if data.get("parteiprogramm"):
            lines.append("\n**Grundsatzprogramm:**")
            for chunk in data["parteiprogramm"]:
                text = chunk["text"][:500] + "..." if len(chunk["text"]) > 500 else chunk["text"]
                lines.append(f'- S. {chunk["seite"]}: "{text}"')
    return "\n".join(lines)


def get_programme_info() -> list[dict]:
    """Get a list of all configured programmes with metadata."""
    info_list = []
    for prog_id, info in PROGRAMME.items():
        info_list.append({
            "id": prog_id,
            "name": info["name"],
            "typ": info["typ"],
            "partei": info["partei"],
            "bundesland": info.get("bundesland"),
            "pdf": info["pdf"],
            "pdf_url": f"/static/referenzen/{info['pdf']}",
        })
    return info_list


def get_indexing_status() -> dict:
    """Get status of indexed programmes."""
    if not EMBEDDINGS_DB.exists():
        # Keep the same shape as the populated return value below.
        return {"indexed": 0, "total": len(PROGRAMME), "programmes": []}
    conn = sqlite3.connect(EMBEDDINGS_DB)
    # Count chunks per program
    rows = conn.execute("""
        SELECT programm_id, COUNT(*) as chunks
        FROM chunks
        GROUP BY programm_id
    """).fetchall()
    conn.close()
    indexed = {row[0]: row[1] for row in rows}
    programmes = []
    for prog_id, info in PROGRAMME.items():
        programmes.append({
            "id": prog_id,
            "name": info["name"],
            "partei": info["partei"],
            "chunks": indexed.get(prog_id, 0),
            "indexed": prog_id in indexed,
        })
    return {
        "indexed": len(indexed),
        "total": len(PROGRAMME),
        "programmes": programmes,
    }
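

if __name__ == "__main__":
    # Minimal end-to-end sketch: create the database, index every programme,
    # then print the status. The PDF directory is an assumption: the module
    # only shows that PDFs are served from /static/referenzen/, so point this
    # at wherever the files actually live on disk.
    init_embeddings_db()
    pdf_dir = Path("static/referenzen")  # hypothetical on-disk location
    for prog_id in PROGRAMME:
        index_programm(prog_id, pdf_dir)
    print(get_indexing_status())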