#!/usr/bin/env python3
|
"""#8 Standardisiertes Themen-Tagging: cluster Themes ueber Podcasts hinweg.

Berechnet Embeddings fuer alle Themes (label + description) und ordnet aehnliche
Themes ueber Podcast-Grenzen hinweg einem gemeinsamen Cluster zu.

Output: data/theme_clusters.json
  {
    "clusters": [
      {
        "id": "klima",
        "label": "Klima",
        "members": [
          {"podcast_id": "ldn", "theme_id": "klima-verkehr", "label": "..."},
          {"podcast_id": "neu-denken", "theme_id": "klimakrise", "label": "..."}
        ]
      }
    ]
  }
"""

import json
import os
import sqlite3
import sys
from pathlib import Path

import numpy as np
from openai import OpenAI

DB_PATH = sys.argv[1] if len(sys.argv) > 1 else "data/db.sqlite"
|
||
|
|
OUT_PATH = sys.argv[2] if len(sys.argv) > 2 else "data/theme_clusters.json"
|
||
|
|
THRESHOLD = float(os.environ.get("THEME_CLUSTER_THRESHOLD", "0.65"))
|
||
|
|
|
||
|
|
API_KEY = os.environ.get("DASHSCOPE_API_KEY", "")
|
||
|
|
BASE_URL = "https://dashscope-intl.aliyuncs.com/compatible-mode/v1"
|
||
|
|
EMBED_MODEL = "text-embedding-v3"
|
||
|
|
|
||
|
|
|
||
|
|
def embed(client, texts):
|
||
|
|
resp = client.embeddings.create(model=EMBED_MODEL, input=texts, dimensions=1024)
|
||
|
|
return [item.embedding for item in resp.data]
|
||
|
|
|
||
|
|
|
||
|
|
def normalize(v):
|
||
|
|
n = np.linalg.norm(v)
|
||
|
|
return v / n if n else v
|
||
|
|
|
||
|
|
|
||
|
|
def main():
|
||
|
|
if not API_KEY:
|
||
|
|
print("DASHSCOPE_API_KEY nicht gesetzt.")
|
||
|
|
sys.exit(1)
|
||
|
|
client = OpenAI(api_key=API_KEY, base_url=BASE_URL)
|
||
|
|
db = sqlite3.connect(DB_PATH)
|
||
|
|
db.row_factory = sqlite3.Row
|
||
|
|
|
||
|
|
rows = db.execute("""
|
||
|
|
SELECT podcast_id, id, label, description, color
|
||
|
|
FROM themes ORDER BY podcast_id, id
|
||
|
|
""").fetchall()
|
||
|
|
print(f"Themes gefunden: {len(rows)}")
|
||
|
|
if not rows:
|
||
|
|
print("Nichts zu tun.")
|
||
|
|
return
|
||
|
|
|
||
|
|
texts = []
|
||
|
|
for r in rows:
|
||
|
|
snippet = r["label"]
|
||
|
|
if r["description"]:
|
||
|
|
snippet += " — " + r["description"]
|
||
|
|
texts.append(snippet[:500])
|
||
|
|
|
||
|
|
# Batch limit der DashScope-API ist 10 Texte je Call
|
||
|
|
embs = []
|
||
|
|
for i in range(0, len(texts), 8):
|
||
|
|
embs.extend(embed(client, texts[i:i + 8]))
|
||
|
|
vectors = np.array([normalize(np.array(e, dtype=np.float32)) for e in embs])
|
||
|
|
|
||
|
|
# Single-Linkage Clustering ueber THRESHOLD
|
||
|
|
n = len(rows)
|
||
|
|
parent = list(range(n))
|
||
|
|
|
||
|
|
def find(x):
|
||
|
|
while parent[x] != x:
|
||
|
|
parent[x] = parent[parent[x]]
|
||
|
|
x = parent[x]
|
||
|
|
return x
|
||
|
|
|
||
|
|
def union(a, b):
|
||
|
|
ra, rb = find(a), find(b)
|
||
|
|
if ra != rb:
|
||
|
|
parent[rb] = ra
|
||
|
|
|
||
|
|
sim = vectors @ vectors.T
|
||
|
|
for i in range(n):
|
||
|
|
for j in range(i + 1, n):
|
||
|
|
# Nur ueber Podcast-Grenzen oder bei sehr hoher Aehnlichkeit clustern
|
||
|
|
if rows[i]["podcast_id"] != rows[j]["podcast_id"] and sim[i, j] >= THRESHOLD:
|
||
|
|
union(i, j)
|
||
|
|
|
||
|
|
# Cluster bilden
|
||
|
|
cluster_map = {}
|
||
|
|
for i in range(n):
|
||
|
|
cluster_map.setdefault(find(i), []).append(i)
|
||
|
|
|
||
|
|
clusters = []
|
||
|
|
for cid, idxs in cluster_map.items():
|
||
|
|
members = []
|
||
|
|
for i in idxs:
|
||
|
|
r = rows[i]
|
||
|
|
members.append({
|
||
|
|
"podcast_id": r["podcast_id"],
|
||
|
|
"theme_id": r["id"],
|
||
|
|
"label": r["label"],
|
||
|
|
"color": r["color"],
|
||
|
|
})
|
||
|
|
# Cluster-Label: kuerzeste Member-Bezeichnung
|
||
|
|
cluster_label = min((m["label"] for m in members), key=len)
|
||
|
|
if len(members) > 1:
|
||
|
|
# Cluster-ID aus erstem Theme-ID
|
||
|
|
cid_str = members[0]["theme_id"]
|
||
|
|
else:
|
||
|
|
cid_str = members[0]["theme_id"]
|
||
|
|
clusters.append({
|
||
|
|
"id": cid_str,
|
||
|
|
"label": cluster_label,
|
||
|
|
"n_members": len(members),
|
||
|
|
"is_cross": len(set(m["podcast_id"] for m in members)) > 1,
|
||
|
|
"members": members,
|
||
|
|
})
|
||
|
|
|
||
|
|
# Sortieren: Cross-Cluster zuerst, dann nach Mitgliederzahl
|
||
|
|
clusters.sort(key=lambda c: (-c["is_cross"], -c["n_members"]))
|
||
|
|
|
||
|
|
out = {"threshold": THRESHOLD, "clusters": clusters}
|
||
|
|
Path(OUT_PATH).parent.mkdir(parents=True, exist_ok=True)
|
||
|
|
with open(OUT_PATH, "w", encoding="utf-8") as f:
|
||
|
|
json.dump(out, f, ensure_ascii=False, indent=2)
|
||
|
|
print(f"\n{len(clusters)} Cluster geschrieben nach {OUT_PATH}")
|
||
|
|
cross = [c for c in clusters if c["is_cross"]]
|
||
|
|
print(f" davon cross-podcast: {len(cross)}")
|
||
|
|
for c in cross:
|
||
|
|
members = ", ".join(f"{m['podcast_id']}/{m['theme_id']}" for m in c["members"])
|
||
|
|
print(f" [{c['label']}] {members}")
|
||
|
|
db.close()
|
||
|
|
|
||
|
|
|
||
|
|
if __name__ == "__main__":
|
||
|
|
main()
|