podcast-mindmap/backend/precompute.py

92 lines
3.0 KiB
Python
Raw Normal View History

#!/usr/bin/env python3
"""Precompute similarity links between paragraphs and store in DB."""
import os
import sys
import json
import numpy as np
from database import get_db, get_all_embeddings, init_db
MIN_SCORE = float(os.environ.get("SIMILARITY_THRESHOLD", "0.55"))
MAX_LINKS_PER_PARA = 10
def precompute_similarities(podcast_id=None):
    """Compute top-N similar paragraphs for each paragraph and store them.

    Loads all paragraph embeddings (optionally scoped to one podcast),
    computes the full pairwise dot-product similarity matrix, and stores
    the best cross-episode links above MIN_SCORE in the `semantic_links`
    table, replacing any previously stored links for the same scope.

    Args:
        podcast_id: Optional podcast identifier. When given, only that
            podcast's embeddings are loaded and only its existing links
            are deleted; otherwise all podcasts are (re)processed.
    """
    vectors, meta = get_all_embeddings(podcast_id)
    if vectors is None or len(meta) == 0:
        print("No embeddings found.")
        return

    # Zero out NaN vectors so they yield 0 similarity everywhere instead
    # of propagating NaN through the whole matrix.
    nan_mask = np.isnan(vectors).any(axis=1)
    if nan_mask.any():
        print(f"Warning: {nan_mask.sum()} NaN vectors found, zeroing them out")
        vectors[nan_mask] = 0

    n = len(meta)
    print(f"Computing similarity matrix for {n} paragraphs...")
    # Full n x n similarity matrix. NOTE(review): this is cosine similarity
    # only if the embeddings are pre-normalized — confirm upstream.
    sim_matrix = vectors @ vectors.T
    np.fill_diagonal(sim_matrix, 0)  # No self-links

    db = get_db()
    # try/finally guarantees the connection is closed even if any DB call
    # or the similarity loop raises (original leaked the handle on error).
    try:
        db.execute("""
        CREATE TABLE IF NOT EXISTS semantic_links (
            id INTEGER PRIMARY KEY AUTOINCREMENT,
            podcast_id TEXT NOT NULL,
            source_episode TEXT NOT NULL,
            source_idx INTEGER NOT NULL,
            target_podcast TEXT NOT NULL,
            target_episode TEXT NOT NULL,
            target_idx INTEGER NOT NULL,
            score REAL NOT NULL
        )
        """)
        db.execute("CREATE INDEX IF NOT EXISTS idx_semantic_source ON semantic_links(podcast_id, source_episode, source_idx)")

        # Clear existing links for the scope being recomputed.
        if podcast_id:
            db.execute("DELETE FROM semantic_links WHERE podcast_id = ?", (podcast_id,))
        else:
            db.execute("DELETE FROM semantic_links")

        # Collect all rows first, then insert in one executemany batch —
        # far fewer round-trips than one execute() per link.
        rows = []
        for i in range(n):
            scores = sim_matrix[i]
            # Candidates in descending similarity order; we stop early once
            # scores fall below threshold, so full argsort cost is acceptable.
            top_indices = np.argsort(scores)[::-1]
            links_added = 0
            for j in top_indices:
                if links_added >= MAX_LINKS_PER_PARA:
                    break
                if scores[j] < MIN_SCORE:
                    break  # sorted descending: nothing further can qualify
                # Skip links within the same episode of the same podcast.
                if meta[i]["episode_id"] == meta[j]["episode_id"] and meta[i]["podcast_id"] == meta[j]["podcast_id"]:
                    continue
                rows.append((
                    meta[i]["podcast_id"], meta[i]["episode_id"], meta[i]["idx"],
                    meta[j]["podcast_id"], meta[j]["episode_id"], meta[j]["idx"],
                    float(scores[j]),
                ))
                links_added += 1

        db.executemany(
            "INSERT INTO semantic_links (podcast_id, source_episode, source_idx, target_podcast, target_episode, target_idx, score) VALUES (?, ?, ?, ?, ?, ?, ?)",
            rows,
        )
        total_links = len(rows)
        db.commit()
    finally:
        db.close()
    print(f"Stored {total_links} semantic links (threshold: {MIN_SCORE})")
if __name__ == "__main__":
init_db()
podcast_id = sys.argv[1] if len(sys.argv) > 1 else None
precompute_similarities(podcast_id)