"""Semantic search for Wahlprogramme and Parteiprogramme using Qwen embeddings."""

import json
import sqlite3
from pathlib import Path
from typing import Optional

import fitz  # PyMuPDF
from openai import OpenAI

from .config import settings

# Embedding model
EMBEDDING_MODEL = "text-embedding-v3"
EMBEDDING_DIMENSIONS = 1024

# Database path
EMBEDDINGS_DB = settings.data_dir / "embeddings.db"

# Programme definitions
PROGRAMME = {
    # Wahlprogramme NRW 2022
    "spd-nrw-2022": {
        "name": "SPD NRW Wahlprogramm 2022",
        "typ": "wahlprogramm",
        "partei": "SPD",
        "bundesland": "NRW",
        "pdf": "spd-nrw-2022.pdf",
    },
    "cdu-nrw-2022": {
        "name": "CDU NRW Wahlprogramm 2022",
        "typ": "wahlprogramm",
        "partei": "CDU",
        "bundesland": "NRW",
        "pdf": "cdu-nrw-2022.pdf",
    },
    "gruene-nrw-2022": {
        "name": "Grüne NRW Wahlprogramm 2022",
        "typ": "wahlprogramm",
        "partei": "GRÜNE",
        "bundesland": "NRW",
        "pdf": "gruene-nrw-2022.pdf",
    },
    "fdp-nrw-2022": {
        "name": "FDP NRW Wahlprogramm 2022",
        "typ": "wahlprogramm",
        "partei": "FDP",
        "bundesland": "NRW",
        "pdf": "fdp-nrw-2022.pdf",
    },
    "afd-nrw-2022": {
        "name": "AfD NRW Wahlprogramm 2022",
        "typ": "wahlprogramm",
        "partei": "AfD",
        "bundesland": "NRW",
        "pdf": "afd-nrw-2022.pdf",
    },
    # Grundsatzprogramme (Bund)
    "spd-grundsatz": {
        "name": "SPD Grundsatzprogramm 2007",
        "typ": "parteiprogramm",
        "partei": "SPD",
        "pdf": "spd-grundsatzprogramm.pdf",
    },
    "cdu-grundsatz": {
        "name": "CDU Grundsatzprogramm 2007",
        "typ": "parteiprogramm",
        "partei": "CDU",
        "pdf": "cdu-grundsatzprogramm.pdf",
    },
    "gruene-grundsatz": {
        "name": "Grüne Grundsatzprogramm 2020",
        "typ": "parteiprogramm",
        "partei": "GRÜNE",
        "pdf": "gruene-grundsatzprogramm.pdf",
    },
    "fdp-grundsatz": {
        "name": "FDP Grundsatzprogramm 2012",
        "typ": "parteiprogramm",
        "partei": "FDP",
        "pdf": "fdp-grundsatzprogramm.pdf",
    },
}


def init_embeddings_db():
    """Initialize the embeddings database."""
    conn = sqlite3.connect(EMBEDDINGS_DB)
    conn.execute("""
        CREATE TABLE IF NOT EXISTS chunks (
            id INTEGER PRIMARY KEY,
            programm_id TEXT NOT NULL,
            partei TEXT NOT NULL,
            typ TEXT NOT NULL,
            seite INTEGER,
            text TEXT NOT NULL,
            embedding BLOB NOT NULL,
            created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP
        )
    """)
    conn.execute("CREATE INDEX IF NOT EXISTS idx_chunks_partei ON chunks(partei)")
    conn.execute("CREATE INDEX IF NOT EXISTS idx_chunks_typ ON chunks(typ)")
    conn.commit()
    conn.close()


def get_client() -> OpenAI:
    """Get DashScope client."""
    return OpenAI(
        api_key=settings.dashscope_api_key,
        base_url=settings.dashscope_base_url,
    )


def create_embedding(text: str) -> list[float]:
    """Create embedding for text using Qwen."""
    client = get_client()
    response = client.embeddings.create(
        model=EMBEDDING_MODEL,
        input=text,
        dimensions=EMBEDDING_DIMENSIONS,
    )
    return response.data[0].embedding
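

# Example (sketch) — embedding a short query with create_embedding(); assumes
# valid DashScope credentials are configured in `settings`:
#
#     vec = create_embedding("Radwegeausbau in den Kommunen")
#     assert len(vec) == EMBEDDING_DIMENSIONS  # 1024 floats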


def chunk_text(text: str, chunk_size: int = 500, overlap: int = 50) -> list[str]:
    """Split text into overlapping chunks by words."""
    words = text.split()
    chunks = []

    i = 0
    while i < len(words):
        chunk_words = words[i:i + chunk_size]
        chunk = " ".join(chunk_words)
        if chunk.strip():
            chunks.append(chunk)
        i += chunk_size - overlap

    return chunks
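

# Example (sketch) of the overlap behaviour — each chunk after the first
# repeats the last `overlap` words of its predecessor:
#
#     >>> chunk_text("a b c d e f", chunk_size=5, overlap=2)
#     ['a b c d e', 'd e f']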


def extract_text_with_pages(pdf_path: Path) -> list[tuple[int, str]]:
    """Extract text from PDF with 1-based page numbers."""
    doc = fitz.open(pdf_path)
    pages = []

    for page_num in range(len(doc)):
        page = doc[page_num]
        text = page.get_text()
        if text.strip():
            # PyMuPDF pages are 0-indexed; store 1-based numbers for citations
            pages.append((page_num + 1, text))

    doc.close()
    return pages
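

# Example (sketch); the path is illustrative:
#
#     for seite, text in extract_text_with_pages(Path("spd-nrw-2022.pdf")):
#         print(seite, len(text))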


def index_programm(programm_id: str, pdf_dir: Path) -> int:
    """Index a single programme PDF into the embeddings database."""
    if programm_id not in PROGRAMME:
        raise ValueError(f"Unknown programme: {programm_id}")

    info = PROGRAMME[programm_id]
    pdf_path = pdf_dir / info["pdf"]

    if not pdf_path.exists():
        print(f"PDF not found: {pdf_path}")
        return 0

    conn = sqlite3.connect(EMBEDDINGS_DB)

    # Remove existing chunks for this programme so re-indexing is idempotent
    conn.execute("DELETE FROM chunks WHERE programm_id = ?", (programm_id,))

    # Extract and chunk
    pages = extract_text_with_pages(pdf_path)
    total_chunks = 0

    for page_num, page_text in pages:
        chunks = chunk_text(page_text, chunk_size=400, overlap=50)

        for chunk_content in chunks:
            if len(chunk_content.split()) < 20:  # Skip tiny chunks
                continue

            try:
                embedding = create_embedding(chunk_content)
                embedding_blob = json.dumps(embedding).encode()

                conn.execute("""
                    INSERT INTO chunks (programm_id, partei, typ, seite, text, embedding)
                    VALUES (?, ?, ?, ?, ?, ?)
                """, (
                    programm_id,
                    info["partei"],
                    info["typ"],
                    page_num,
                    chunk_content,
                    embedding_blob,
                ))
                total_chunks += 1
            except Exception as e:
                print(f"Error embedding chunk: {e}")
                continue

    conn.commit()
    conn.close()

    print(f"Indexed {total_chunks} chunks from {programm_id}")
    return total_chunks
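

# Example (sketch) — (re)indexing every known programme; the PDF directory
# shown here is an assumption, not a path this module defines:
#
#     init_embeddings_db()
#     for programm_id in PROGRAMME:
#         index_programm(programm_id, pdf_dir=settings.data_dir / "referenzen")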


def cosine_similarity(a: list[float], b: list[float]) -> float:
    """Calculate cosine similarity between two vectors."""
    dot = sum(x * y for x, y in zip(a, b))
    norm_a = sum(x * x for x in a) ** 0.5
    norm_b = sum(x * x for x in b) ** 0.5
    if norm_a == 0 or norm_b == 0:
        return 0.0
    return dot / (norm_a * norm_b)
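

# Quick sanity check (sketch):
#
#     >>> cosine_similarity([1.0, 0.0], [1.0, 0.0])
#     1.0
#     >>> cosine_similarity([1.0, 0.0], [0.0, 1.0])
#     0.0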


def find_relevant_chunks(
    query: str,
    parteien: Optional[list[str]] = None,
    typ: Optional[str] = None,
    top_k: int = 3,
    min_similarity: float = 0.5,
) -> list[dict]:
    """Find the most relevant chunks for a query."""
    query_embedding = create_embedding(query)

    conn = sqlite3.connect(EMBEDDINGS_DB)
    conn.row_factory = sqlite3.Row

    # Build query
    sql = "SELECT * FROM chunks WHERE 1=1"
    params = []

    if parteien:
        placeholders = ",".join("?" * len(parteien))
        sql += f" AND partei IN ({placeholders})"
        params.extend(parteien)

    if typ:
        sql += " AND typ = ?"
        params.append(typ)

    rows = conn.execute(sql, params).fetchall()
    conn.close()

    # Brute-force similarity scan; fine for a few thousand chunks
    results = []
    for row in rows:
        chunk_embedding = json.loads(row["embedding"])
        similarity = cosine_similarity(query_embedding, chunk_embedding)

        if similarity >= min_similarity:
            results.append({
                "programm_id": row["programm_id"],
                "partei": row["partei"],
                "typ": row["typ"],
                "seite": row["seite"],
                "text": row["text"],
                "similarity": similarity,
            })

    # Sort by similarity and return top_k
    results.sort(key=lambda x: x["similarity"], reverse=True)
    return results[:top_k]
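

# Example (sketch) — top matches in one party's Wahlprogramm; assumes the
# PDFs have already been indexed:
#
#     hits = find_relevant_chunks(
#         "Ausbau der Windenergie",
#         parteien=["GRÜNE"],
#         typ="wahlprogramm",
#     )
#     for hit in hits:
#         print(f'{hit["partei"]} S. {hit["seite"]}: {hit["similarity"]:.2f}')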


def get_relevant_quotes_for_antrag(
    antrag_text: str,
    fraktionen: list[str],
    top_k_per_partei: int = 2,
) -> dict[str, dict[str, list[dict]]]:
    """Get relevant quotes from Wahl- and Parteiprogramme for an Antrag."""
    results = {}

    # Include Regierungsfraktionen; normalize case and deduplicate, keeping order
    parteien = dict.fromkeys(partei.upper() for partei in fraktionen + ["CDU", "GRÜNE"])

    for partei_upper in parteien:
        # Wahlprogramm
        wahl_chunks = find_relevant_chunks(
            antrag_text,
            parteien=[partei_upper],
            typ="wahlprogramm",
            top_k=top_k_per_partei,
            min_similarity=0.45,
        )

        # Parteiprogramm
        partei_chunks = find_relevant_chunks(
            antrag_text,
            parteien=[partei_upper],
            typ="parteiprogramm",
            top_k=top_k_per_partei,
            min_similarity=0.45,
        )

        if wahl_chunks or partei_chunks:
            results[partei_upper] = {
                "wahlprogramm": wahl_chunks,
                "parteiprogramm": partei_chunks,
            }

    return results


def format_quotes_for_prompt(quotes: dict) -> str:
    """Format quotes for inclusion in an LLM prompt."""
    if not quotes:
        return ""

    lines = ["\n## Relevante Passagen aus Wahl- und Parteiprogrammen\n"]

    for partei, data in quotes.items():
        lines.append(f"\n### {partei}\n")

        if data.get("wahlprogramm"):
            lines.append("**Wahlprogramm NRW 2022:**")
            for chunk in data["wahlprogramm"]:
                text = chunk["text"][:500] + "..." if len(chunk["text"]) > 500 else chunk["text"]
                lines.append(f'- S. {chunk["seite"]}: "{text}"')

        if data.get("parteiprogramm"):
            lines.append("\n**Grundsatzprogramm:**")
            for chunk in data["parteiprogramm"]:
                text = chunk["text"][:500] + "..." if len(chunk["text"]) > 500 else chunk["text"]
                lines.append(f'- S. {chunk["seite"]}: "{text}"')

    return "\n".join(lines)
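

# Example (sketch) — wiring the two helpers together; `antrag_text` is
# illustrative input, not something this module provides:
#
#     quotes = get_relevant_quotes_for_antrag(antrag_text, fraktionen=["SPD", "FDP"])
#     prompt_section = format_quotes_for_prompt(quotes)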


def get_programme_info() -> list[dict]:
    """Get list of all known programmes with metadata."""
    info_list = []

    for prog_id, info in PROGRAMME.items():
        info_list.append({
            "id": prog_id,
            "name": info["name"],
            "typ": info["typ"],
            "partei": info["partei"],
            "bundesland": info.get("bundesland"),
            "pdf": info["pdf"],
            "pdf_url": f"/static/referenzen/{info['pdf']}",
        })

    return info_list


def get_indexing_status() -> dict:
    """Get status of indexed programmes."""
    if not EMBEDDINGS_DB.exists():
        return {"indexed": 0, "total": len(PROGRAMME), "programmes": []}

    conn = sqlite3.connect(EMBEDDINGS_DB)

    # Count chunks per programme
    rows = conn.execute("""
        SELECT programm_id, COUNT(*) as chunks
        FROM chunks
        GROUP BY programm_id
    """).fetchall()

    conn.close()

    indexed = {row[0]: row[1] for row in rows}

    programmes = []
    for prog_id, info in PROGRAMME.items():
        programmes.append({
            "id": prog_id,
            "name": info["name"],
            "partei": info["partei"],
            "chunks": indexed.get(prog_id, 0),
            "indexed": prog_id in indexed,
        })

    return {
        "indexed": len(indexed),
        "total": len(PROGRAMME),
        "programmes": programmes,
    }
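

# Example (sketch):
#
#     status = get_indexing_status()
#     print(f'{status["indexed"]}/{status["total"]} programmes indexed')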