podcast-mindmap/backend/app.py

401 lines
15 KiB
Python
Raw Normal View History

"""FastAPI backend for podcast-mindmap."""
import json
import os
import numpy as np
from pathlib import Path
from typing import Optional
from fastapi import FastAPI, Query, HTTPException
from fastapi.staticfiles import StaticFiles
from fastapi.responses import FileResponse
from database import get_db, init_db, get_all_embeddings
app = FastAPI(title="Podcast Mindmap API")

# Filesystem locations; each can be overridden by an environment variable
# of the same name, otherwise the container-style default path is used.
DATA_DIR = os.getenv("DATA_DIR", "/data")
AUDIO_DIR = os.getenv("AUDIO_DIR", "/audio")
STATIC_DIR = os.getenv("STATIC_DIR", "/static")

# In-process cache of (vectors, meta) pairs, filled lazily by
# _load_embeddings() and emptied by _invalidate_cache().
_embeddings_cache: dict = {}
def _load_embeddings(podcast_id: Optional[str] = None):
    """Return the ``(vectors, meta)`` pair for ``podcast_id``, loading it at most once.

    ``None`` selects the embeddings of all podcasts (this is how
    ``find_similar`` requests a cross-podcast search). The result of
    ``get_all_embeddings(podcast_id)`` is cached in ``_embeddings_cache`` until
    ``_invalidate_cache()`` is called.

    The argument itself (``None`` included, which is a valid dict key) is used
    as the cache key. A sentinel string such as ``"__all__"`` would collide
    with a podcast of that id. It would also store a ``get_all_embeddings("")``
    result under the all-podcasts slot.
    """
    if podcast_id not in _embeddings_cache:
        vectors, meta = get_all_embeddings(podcast_id)
        _embeddings_cache[podcast_id] = (vectors, meta)
    return _embeddings_cache[podcast_id]
def _invalidate_cache():
    """Forget every cached embedding pair.

    After this, the next ``_load_embeddings`` call for any scope re-reads
    the embeddings from the database.
    """
    _embeddings_cache.clear()
# ── API Routes ──
@app.get("/api/podcasts")
def list_podcasts():
    """Return every row of the ``podcasts`` table, each converted to a plain dict."""
    conn = get_db()
    podcast_rows = conn.execute("SELECT * FROM podcasts").fetchall()
    conn.close()
    return list(map(dict, podcast_rows))
@app.get("/api/podcasts/{podcast_id}")
def get_podcast(podcast_id: str):
    """Return one podcast with its staffeln, themes, episodes and quotes.

    The payload is built in the ``mindmap_data``-compatible shape (camelCase
    keys such as ``audioFile``, ``startTime``, ``isTopQuote``). Each quote is
    annotated with the audio file of the episode it belongs to.

    Args:
        podcast_id: Id of the ``podcasts`` row to load.

    Raises:
        HTTPException: 404 if no podcast with ``podcast_id`` exists.
    """
    db = get_db()
    try:
        podcast = db.execute("SELECT * FROM podcasts WHERE id = ?", (podcast_id,)).fetchone()
        if not podcast:
            raise HTTPException(404, "Podcast not found")
        staffeln = db.execute("SELECT * FROM staffeln WHERE podcast_id = ? ORDER BY id", (podcast_id,)).fetchall()
        themes = db.execute("SELECT * FROM themes WHERE podcast_id = ?", (podcast_id,)).fetchall()
        episodes = db.execute("SELECT * FROM episodes WHERE podcast_id = ? ORDER BY id", (podcast_id,)).fetchall()
        quotes = db.execute("SELECT * FROM quotes WHERE podcast_id = ?", (podcast_id,)).fetchall()
    finally:
        # Close on every exit path, including the 404 raised above.
        db.close()

    # Episode id -> audio file, built once, instead of scanning every episode
    # for every quote. setdefault keeps the first match for a duplicated id.
    audio_by_episode = {}
    for e in episodes:
        audio_by_episode.setdefault(e["id"], e["audio_file"])

    # Build mindmap_data compatible format
    return {
        "name": podcast["name"],
        "host": podcast["host"],
        "description": podcast["description"],
        "staffeln": [dict(s) for s in staffeln],
        "themes": [{**dict(t), "episodes": json.loads(t["episodes_json"])} for t in themes],
        "episodes": [{"id": e["id"], "title": e["title"], "guest": e["guest"],
                      "staffel": e["staffel"], "audioFile": e["audio_file"]} for e in episodes],
        "quotes": [{
            "id": q["id"], "text": q["text"], "verbatim": q["verbatim"],
            "speaker": q["speaker"], "episode": q["episode_id"],
            "startTime": q["start_time"], "endTime": q["end_time"],
            "isTopQuote": bool(q["is_top_quote"]),
            "themes": json.loads(q["themes_json"]),
            "audioFile": audio_by_episode.get(q["episode_id"]),
        } for q in quotes],
    }
@app.get("/api/podcasts/{podcast_id}/transcript/{episode_id}")
def get_transcript(podcast_id: str, episode_id: str):
    """Return one episode's transcript paragraphs ordered by ``idx``.

    Each paragraph is ``{"start", "end", "text"}``. An unknown podcast/episode
    produces an empty ``paragraphs`` list rather than an error.
    """
    sql = "SELECT idx, start_time, end_time, text FROM paragraphs WHERE podcast_id = ? AND episode_id = ? ORDER BY idx"
    conn = get_db()
    rows = conn.execute(sql, (podcast_id, episode_id)).fetchall()
    conn.close()
    paragraphs = [
        {"start": row["start_time"], "end": row["end_time"], "text": row["text"]}
        for row in rows
    ]
    return {"paragraphs": paragraphs}
@app.get("/api/podcasts/{podcast_id}/transcript/{episode_id}/words")
def get_words(podcast_id: str, episode_id: str):
    """Get word-level timestamps for an episode.

    Returns ``{"available": True, "words": [...]}`` with one
    ``{"seg", "idx", "word", "start", "end"}`` entry per word, ordered by
    segment index and then word index. If the query fails (the ``words`` table
    may not exist) or finds nothing, returns ``{"words": [], "available": False}``.
    """
    conn = get_db()
    try:
        word_rows = conn.execute(
            "SELECT segment_idx, word_idx, word, start_time, end_time FROM words "
            "WHERE podcast_id = ? AND episode_id = ? ORDER BY segment_idx, word_idx",
            (podcast_id, episode_id)
        ).fetchall()
    except Exception:
        # Any query failure is reported as "no word timings available".
        word_rows = []
    conn.close()

    if not word_rows:
        return {"words": [], "available": False}

    words = [
        {
            "seg": row["segment_idx"],
            "idx": row["word_idx"],
            "word": row["word"],
            "start": row["start_time"],
            "end": row["end_time"],
        }
        for row in word_rows
    ]
    return {"available": True, "words": words}
@app.get("/api/search")
def search(q: str = Query(..., min_length=2), podcast_id: Optional[str] = None, limit: int = 50):
    """Full-text substring search across all transcripts.

    ``q`` is matched literally. The LIKE wildcards ``%`` and ``_``, and the
    backslash used as the escape character, are escaped. A search for
    ``"100%"`` therefore only matches paragraphs that really contain
    ``100%``, instead of everything starting with ``100``.

    Args:
        q: Search text (at least 2 characters).
        podcast_id: Restrict the search to one podcast; falsy searches all.
        limit: Value passed to the SQL ``LIMIT``.

    Returns:
        Matching paragraphs as dicts, joined with their episode's ``title``
        and ``guest``.
    """
    escaped = q.replace("\\", "\\\\").replace("%", "\\%").replace("_", "\\_")

    # Only constant SQL fragments are concatenated; all values are bound.
    conditions = ["p.text LIKE ? ESCAPE '\\'"]
    params = [f"%{escaped}%"]
    if podcast_id:
        conditions.insert(0, "p.podcast_id = ?")
        params.insert(0, podcast_id)
    params.append(limit)

    sql = (
        "SELECT p.podcast_id, p.episode_id, p.idx, p.start_time, p.text, e.title, e.guest "
        "FROM paragraphs p JOIN episodes e ON p.podcast_id = e.podcast_id AND p.episode_id = e.id "
        "WHERE " + " AND ".join(conditions) + " LIMIT ?"
    )

    db = get_db()
    try:
        rows = db.execute(sql, params).fetchall()
    finally:
        db.close()
    return [dict(r) for r in rows]
@app.get("/api/similar/{podcast_id}/{episode_id}/{para_idx}")
def find_similar(podcast_id: str, episode_id: str, para_idx: int,
                 limit: int = 10, cross_podcast: bool = False):
    """Find paragraphs semantically similar to one paragraph using stored embeddings.

    The source paragraph's embedding (read as float32 bytes) is compared
    against the cached embeddings of its own podcast. If ``cross_podcast`` is
    true, it is compared against all podcasts.

    Args:
        podcast_id, episode_id, para_idx: Identify the source paragraph.
        limit: Maximum number of results; ``<= 0`` returns an empty list.
        cross_podcast: Search every podcast instead of only ``podcast_id``.
            When false, every paragraph of the source episode is excluded.
            When true, only the source paragraph itself is excluded.

    Returns:
        Up to ``limit`` dicts in descending score order. Each dict is enriched
        with a 150-character ``text_preview``, ``start_time``, ``episode_title``
        and ``guest`` when the corresponding rows exist.

    Raises:
        HTTPException: 404 if the paragraph is missing or has no embedding.
    """
    db = get_db()
    try:
        row = db.execute(
            "SELECT id, embedding FROM paragraphs WHERE podcast_id = ? AND episode_id = ? AND idx = ?",
            (podcast_id, episode_id, para_idx)
        ).fetchone()
    finally:
        db.close()
    if not row or not row["embedding"]:
        raise HTTPException(404, "Paragraph not found or not embedded")

    query_vec = np.frombuffer(row["embedding"], dtype=np.float32)
    norm = np.linalg.norm(query_vec)
    # A zero vector has no direction. Dividing by its norm would give NaN
    # scores, which the JSON response cannot encode, so nothing can be "similar".
    if limit <= 0 or norm == 0:
        return []
    query_vec = query_vec / norm

    search_podcast = None if cross_podcast else podcast_id
    vectors, meta = _load_embeddings(search_podcast)
    if vectors is None or len(meta) == 0:
        return []

    # Cosine similarity as a plain dot product. This assumes the cached rows
    # are already L2-normalised. NOTE(review): confirm in get_all_embeddings.
    scores = vectors @ query_vec

    results = []
    for idx in np.argsort(scores)[::-1]:
        m = meta[idx]
        if m["podcast_id"] == podcast_id and m["episode_id"] == episode_id and m["idx"] == para_idx:
            continue  # the source paragraph itself
        if not cross_podcast and m["episode_id"] == episode_id:
            continue  # same episode: excluded only for single-podcast searches
        results.append({
            "podcast_id": m["podcast_id"],
            "episode_id": m["episode_id"],
            "paragraph_idx": m["idx"],
            "score": float(scores[idx])
        })
        if len(results) >= limit:
            break

    # Enrich with text previews and episode info.
    db = get_db()
    try:
        for r in results:
            p = db.execute(
                "SELECT text, start_time FROM paragraphs WHERE podcast_id = ? AND episode_id = ? AND idx = ?",
                (r["podcast_id"], r["episode_id"], r["paragraph_idx"])
            ).fetchone()
            if p:
                r["text_preview"] = p["text"][:150]
                r["start_time"] = p["start_time"]
            ep = db.execute(
                "SELECT title, guest FROM episodes WHERE podcast_id = ? AND id = ?",
                (r["podcast_id"], r["episode_id"])
            ).fetchone()
            if ep:
                r["episode_title"] = ep["title"]
                r["guest"] = ep["guest"]
    finally:
        db.close()
    return results
@app.get("/api/similar-precomputed/{podcast_id}/{episode_id}/{para_idx}")
def get_precomputed_similar(podcast_id: str, episode_id: str, para_idx: int, limit: int = 10):
    """Get precomputed similar paragraphs (fast, no embedding computation).

    Reads the stored ``semantic_links`` rows of one source paragraph, highest
    ``score`` first, capped at ``limit``. The inner joins drop links whose
    target paragraph or episode row is missing. Each result carries a
    150-character ``text_preview`` of the target paragraph.
    """
    sql = (
        "SELECT sl.target_podcast, sl.target_episode, sl.target_idx, sl.score, "
        "p.text, p.start_time, e.title, e.guest "
        "FROM semantic_links sl "
        "JOIN paragraphs p ON sl.target_podcast = p.podcast_id AND sl.target_episode = p.episode_id AND sl.target_idx = p.idx "
        "JOIN episodes e ON sl.target_podcast = e.podcast_id AND sl.target_episode = e.id "
        "WHERE sl.podcast_id = ? AND sl.source_episode = ? AND sl.source_idx = ? "
        "ORDER BY sl.score DESC LIMIT ?"
    )
    conn = get_db()
    links = conn.execute(sql, (podcast_id, episode_id, para_idx, limit)).fetchall()
    conn.close()

    def to_payload(link):
        # Rename the target_* columns to the same keys find_similar returns.
        return {
            "podcast_id": link["target_podcast"],
            "episode_id": link["target_episode"],
            "paragraph_idx": link["target_idx"],
            "score": link["score"],
            "text_preview": link["text"][:150],
            "start_time": link["start_time"],
            "episode_title": link["title"],
            "guest": link["guest"],
        }

    return [to_payload(link) for link in links]
@app.get("/api/compare")
def compare_podcasts(a: str = Query(...), b: str = Query(...)):
    """Compare two podcasts: shared topics, stats, cross-links.

    Args:
        a, b: Ids of the podcasts to compare.

    Returns:
        A dict with these entries:

        - ``stats``: name plus episode/quote/paragraph counts per id.
        - ``shared_topics`` and ``only_in``: sorted topic tags common to both
          or exclusive to one.
        - ``cross_links_count``: number of ``semantic_links`` rows between
          the two podcasts in either direction.
        - ``top_cross_links``: the 20 highest-scoring of those links, with
          150-character text excerpts.

    Raises:
        HTTPException: 404 if either podcast does not exist.
    """
    topic_sql = (
        "SELECT DISTINCT t.tag FROM topics t JOIN paragraphs p ON t.paragraph_id = p.id WHERE p.podcast_id = ?"
    )

    db = get_db()
    try:
        # Basic stats
        stats = {}
        for pid in (a, b):
            podcast = db.execute("SELECT * FROM podcasts WHERE id = ?", (pid,)).fetchone()
            if not podcast:
                raise HTTPException(404, f"Podcast '{pid}' not found")
            ep_count = db.execute("SELECT COUNT(*) as c FROM episodes WHERE podcast_id = ?", (pid,)).fetchone()["c"]
            q_count = db.execute("SELECT COUNT(*) as c FROM quotes WHERE podcast_id = ?", (pid,)).fetchone()["c"]
            p_count = db.execute("SELECT COUNT(*) as c FROM paragraphs WHERE podcast_id = ?", (pid,)).fetchone()["c"]
            stats[pid] = {"name": podcast["name"], "episodes": ep_count, "quotes": q_count, "paragraphs": p_count}

        # Shared topics via topic tags
        set_a = {r["tag"] for r in db.execute(topic_sql, (a,)).fetchall()}
        set_b = {r["tag"] for r in db.execute(topic_sql, (b,)).fetchall()}

        # Cross-podcast semantic links, counted in both directions
        cross_links = 0
        top_links = []
        try:
            cross_links = db.execute(
                "SELECT COUNT(*) as c FROM semantic_links WHERE "
                "(podcast_id = ? AND target_podcast = ?) OR (podcast_id = ? AND target_podcast = ?)",
                (a, b, b, a)
            ).fetchone()["c"]
            top_links = db.execute(
                "SELECT sl.*, p1.text as source_text, p2.text as target_text, "
                "e1.title as source_title, e2.title as target_title "
                "FROM semantic_links sl "
                "JOIN paragraphs p1 ON sl.podcast_id = p1.podcast_id AND sl.source_episode = p1.episode_id AND sl.source_idx = p1.idx "
                "JOIN paragraphs p2 ON sl.target_podcast = p2.podcast_id AND sl.target_episode = p2.episode_id AND sl.target_idx = p2.idx "
                "JOIN episodes e1 ON sl.podcast_id = e1.podcast_id AND sl.source_episode = e1.id "
                "JOIN episodes e2 ON sl.target_podcast = e2.podcast_id AND sl.target_episode = e2.id "
                "WHERE (sl.podcast_id = ? AND sl.target_podcast = ?) OR (sl.podcast_id = ? AND sl.target_podcast = ?) "
                "ORDER BY sl.score DESC LIMIT 20",
                (a, b, b, a)
            ).fetchall()
        except Exception:
            # Deliberate best effort: the semantic_links table may not exist
            # yet, in which case the comparison is returned without links.
            pass
    finally:
        # Close on every exit path, including the 404 raised in the stats loop.
        db.close()

    return {
        "stats": stats,
        "shared_topics": sorted(set_a & set_b),
        "only_in": {a: sorted(set_a - set_b), b: sorted(set_b - set_a)},
        "cross_links_count": cross_links,
        "top_cross_links": [{
            "source_podcast": r["podcast_id"], "source_episode": r["source_episode"],
            "source_text": r["source_text"][:150], "source_title": r["source_title"],
            "target_podcast": r["target_podcast"], "target_episode": r["target_episode"],
            "target_text": r["target_text"][:150], "target_title": r["target_title"],
            "score": r["score"]
        } for r in top_links]
    }
@app.get("/api/semantic-search")
def semantic_search(q: str = Query(..., min_length=3), podcast_id: Optional[str] = None, limit: int = 20):
    """Semantic search using query embedding.

    Embeds ``q`` with ``embeddings.embed_texts`` and ranks the cached
    paragraph embeddings by dot product against the normalised query.

    Args:
        q: Free-text query (at least 3 characters).
        podcast_id: Restrict to one podcast; ``None`` searches all podcasts.
        limit: Maximum number of results; ``<= 0`` returns an empty list.

    Returns:
        Up to ``limit`` hits in descending score order. The scan stops at the
        first score below ``min_score`` (0.3). Each hit carries a
        200-character ``text_preview``, ``start_time``, ``episode_title`` and
        ``guest``.

    Raises:
        HTTPException: 500 if embedding the query fails.
    """
    from embeddings import embed_texts

    min_score = 0.3

    try:
        query_vec = np.array(embed_texts([q])[0], dtype=np.float32)
    except Exception as e:
        raise HTTPException(500, f"Embedding failed: {e}") from e

    norm = np.linalg.norm(query_vec)
    # A zero query vector gives NaN scores. These would slip past the
    # min_score cutoff (NaN < x is False) and break JSON encoding.
    # A negative limit would also turn [:limit] into "all but the last N".
    if limit <= 0 or norm == 0:
        return []
    query_vec = query_vec / norm

    vectors, meta = _load_embeddings(podcast_id)
    if vectors is None or len(meta) == 0:
        return []

    scores = vectors @ query_vec
    indices = np.argsort(scores)[::-1][:limit]

    db = get_db()
    results = []
    try:
        for idx in indices:
            m = meta[idx]
            score = float(scores[idx])
            if score < min_score:
                break  # scores are sorted descending; the rest are lower
            p = db.execute(
                "SELECT text, start_time FROM paragraphs WHERE id = ?", (m["id"],)
            ).fetchone()
            ep = db.execute(
                "SELECT title, guest FROM episodes WHERE podcast_id = ? AND id = ?",
                (m["podcast_id"], m["episode_id"])
            ).fetchone()
            results.append({
                "podcast_id": m["podcast_id"],
                "episode_id": m["episode_id"],
                "paragraph_idx": m["idx"],
                "score": score,
                "text_preview": p["text"][:200] if p else "",
                "start_time": p["start_time"] if p else None,
                "episode_title": ep["title"] if ep else "",
                "guest": ep["guest"] if ep else "",
            })
    finally:
        db.close()
    return results
# ── Startup ──
@app.on_event("startup")
def startup():
    """Initialise the database and auto-import podcasts from ``DATA_DIR``.

    Every subdirectory containing both ``mindmap_data.json`` and
    ``srt_index.json`` is imported under its directory name as podcast id.
    Subdirectories already present in the ``podcasts`` table are skipped.
    Directories are processed in sorted order so imports are deterministic.
    """
    init_db()

    data_path = Path(DATA_DIR)
    if not data_path.exists():
        return

    for podcast_dir in sorted(data_path.iterdir()):
        if not podcast_dir.is_dir():
            continue
        mindmap_file = podcast_dir / "mindmap_data.json"
        srt_file = podcast_dir / "srt_index.json"
        if not (mindmap_file.exists() and srt_file.exists()):
            continue

        podcast_id = podcast_dir.name
        db = get_db()
        try:
            existing = db.execute("SELECT id FROM podcasts WHERE id = ?", (podcast_id,)).fetchone()
        finally:
            db.close()
        if existing:
            continue

        print(f"Importing podcast: {podcast_id}")
        # JSON text is UTF-8 (RFC 8259); don't depend on the locale's default encoding.
        with open(mindmap_file, encoding="utf-8") as f:
            mindmap_data = json.load(f)
        with open(srt_file, encoding="utf-8") as f:
            srt_index = json.load(f)
        from database import import_podcast
        import_podcast(podcast_id, mindmap_data, srt_index)
# ── Static Files + Audio ──
def _mount_if_present(path: str, directory: str, name: str, **options):
    """Serve ``directory`` as static files at ``path``, but only if it is a directory."""
    if os.path.isdir(directory):
        app.mount(path, StaticFiles(directory=directory, **options), name=name)


# Audio files (per-podcast subdirs) under /audio.
_mount_if_present("/audio", AUDIO_DIR, "audio")
# Webapp as static files, mounted last so "/" only acts as a fallback for
# paths not matched by the API routes or /audio (routes match in order).
_mount_if_present("/", STATIC_DIR, "static", html=True)