"""FastAPI backend for podcast-mindmap."""
|
||
|
|
|
||
|
|
import json
|
||
|
|
import os
|
||
|
|
import numpy as np
|
||
|
|
from pathlib import Path
|
||
|
|
from typing import Optional
|
||
|
|
from fastapi import FastAPI, Query, HTTPException
|
||
|
|
from fastapi.staticfiles import StaticFiles
|
||
|
|
from fastapi.responses import FileResponse
|
||
|
|
|
||
|
|
from database import get_db, init_db, get_all_embeddings
|
||
|
|
|
||
|
|
# The ASGI application object; all routes below are registered on it.
app = FastAPI(title="Podcast Mindmap API")

# Filesystem locations. Each one can be overridden through the environment
# variable of the same name; the literal is the fallback.
DATA_DIR = os.getenv("DATA_DIR", "/data")
AUDIO_DIR = os.getenv("AUDIO_DIR", "/audio")
STATIC_DIR = os.getenv("STATIC_DIR", "/static")

# Process-wide embedding cache, filled lazily by _load_embeddings().
# Keys are a podcast id or "__all__"; values are (vectors, meta) tuples.
_embeddings_cache = dict()
|
||
|
|
|
||
|
|
|
||
|
|
def _load_embeddings(podcast_id: Optional[str] = None):
    """Return the ``(vectors, meta)`` pair for a podcast, loading it once.

    The pair comes from ``get_all_embeddings(podcast_id)`` and is memoised in
    the module-level ``_embeddings_cache``. A falsy ``podcast_id`` (``None``
    or an empty string) is cached under the shared key ``"__all__"``.
    """
    cache_key = podcast_id if podcast_id else "__all__"
    entry = _embeddings_cache.get(cache_key)
    if entry is None:
        # Cached values are always tuples, so None can only mean "not cached".
        vectors, meta = get_all_embeddings(podcast_id)
        entry = (vectors, meta)
        _embeddings_cache[cache_key] = entry
    return entry
|
||
|
|
|
||
|
|
|
||
|
|
def _invalidate_cache():
    """Drop every cached ``(vectors, meta)`` entry.

    The next ``_load_embeddings`` call for any key reloads from the database.
    """
    _embeddings_cache.clear()
|
||
|
|
|
||
|
|
|
||
|
|
# ── API Routes ──
|
||
|
|
|
||
|
|
@app.get("/api/podcasts")
def list_podcasts():
    """Return every row of the ``podcasts`` table as a list of dicts."""
    db = get_db()
    try:
        rows = db.execute("SELECT * FROM podcasts").fetchall()
    finally:
        # Release the connection even when the query raises.
        db.close()
    return [dict(r) for r in rows]
|
||
|
|
|
||
|
|
|
||
|
|
@app.get("/api/podcasts/{podcast_id}")
def get_podcast(podcast_id: str):
    """Return one podcast in the mindmap_data-compatible layout.

    The response bundles the podcast's name/host/description with its
    ``staffeln``, ``themes``, ``episodes`` and ``quotes`` rows. JSON columns
    (``episodes_json``, ``themes_json``) are decoded, and every quote carries
    the ``audioFile`` of the episode it belongs to (``None`` when no episode
    matches).

    Raises:
        HTTPException: 404 when no podcast with ``podcast_id`` exists.
    """
    db = get_db()
    try:
        podcast = db.execute("SELECT * FROM podcasts WHERE id = ?", (podcast_id,)).fetchone()
        if not podcast:
            raise HTTPException(404, "Podcast not found")

        staffeln = db.execute(
            "SELECT * FROM staffeln WHERE podcast_id = ? ORDER BY id", (podcast_id,)
        ).fetchall()
        themes = db.execute(
            "SELECT * FROM themes WHERE podcast_id = ?", (podcast_id,)
        ).fetchall()
        episodes = db.execute(
            "SELECT * FROM episodes WHERE podcast_id = ? ORDER BY id", (podcast_id,)
        ).fetchall()
        quotes = db.execute(
            "SELECT * FROM quotes WHERE podcast_id = ?", (podcast_id,)
        ).fetchall()
    finally:
        # Runs on the 404 path as well, so the connection is never leaked.
        db.close()

    # Episode id -> audio file, built once instead of scanning all episodes
    # for every quote. setdefault keeps the first match.
    audio_by_episode = {}
    for e in episodes:
        audio_by_episode.setdefault(e["id"], e["audio_file"])

    # Build mindmap_data compatible format
    return {
        "name": podcast["name"],
        "host": podcast["host"],
        "description": podcast["description"],
        "staffeln": [dict(s) for s in staffeln],
        "themes": [{**dict(t), "episodes": json.loads(t["episodes_json"])} for t in themes],
        "episodes": [
            {
                "id": e["id"],
                "title": e["title"],
                "guest": e["guest"],
                "staffel": e["staffel"],
                "audioFile": e["audio_file"],
            }
            for e in episodes
        ],
        "quotes": [
            {
                "id": q["id"],
                "text": q["text"],
                "verbatim": q["verbatim"],
                "speaker": q["speaker"],
                "episode": q["episode_id"],
                "startTime": q["start_time"],
                "endTime": q["end_time"],
                "isTopQuote": bool(q["is_top_quote"]),
                "themes": json.loads(q["themes_json"]),
                "audioFile": audio_by_episode.get(q["episode_id"]),
            }
            for q in quotes
        ],
    }
|
||
|
|
|
||
|
|
|
||
|
|
@app.get("/api/podcasts/{podcast_id}/transcript/{episode_id}")
def get_transcript(podcast_id: str, episode_id: str):
    """Return the paragraphs of one episode, ordered by paragraph index.

    Returns:
        ``{"paragraphs": [{"start", "end", "text"}, ...]}``; the list is
        empty when the episode has no stored paragraphs.
    """
    db = get_db()
    try:
        paras = db.execute(
            "SELECT idx, start_time, end_time, text FROM paragraphs WHERE podcast_id = ? AND episode_id = ? ORDER BY idx",
            (podcast_id, episode_id)
        ).fetchall()
    finally:
        # Release the connection even when the query raises.
        db.close()
    return {"paragraphs": [{"start": p["start_time"], "end": p["end_time"], "text": p["text"]} for p in paras]}
|
||
|
|
|
||
|
|
|
||
|
|
@app.get("/api/search")
def search(q: str = Query(..., min_length=2), podcast_id: Optional[str] = None, limit: int = 50):
    """Full-text search across all transcripts.

    Matches paragraphs whose text contains ``q`` (SQL ``LIKE '%q%'``),
    optionally restricted to one podcast, and returns at most ``limit`` rows
    joined with the episode title and guest.

    NOTE: ``%`` and ``_`` inside ``q`` are not escaped, so they act as LIKE
    wildcards.
    """
    # One statement for both cases; only the WHERE clause and the bound
    # parameters differ. All user input is passed as parameters.
    where = "p.text LIKE ?"
    params = [f"%{q}%", limit]
    if podcast_id:
        where = "p.podcast_id = ? AND " + where
        params.insert(0, podcast_id)

    sql = (
        "SELECT p.podcast_id, p.episode_id, p.idx, p.start_time, p.text, e.title, e.guest "
        "FROM paragraphs p JOIN episodes e ON p.podcast_id = e.podcast_id AND p.episode_id = e.id "
        f"WHERE {where} LIMIT ?"
    )

    db = get_db()
    try:
        rows = db.execute(sql, tuple(params)).fetchall()
    finally:
        # Release the connection even when the query raises.
        db.close()
    return [dict(r) for r in rows]
|
||
|
|
|
||
|
|
|
||
|
|
@app.get("/api/similar/{podcast_id}/{episode_id}/{para_idx}")
def find_similar(podcast_id: str, episode_id: str, para_idx: int,
                 limit: int = 10, cross_podcast: bool = False):
    """Find semantically similar paragraphs using embeddings.

    The stored embedding of the addressed paragraph is compared against the
    cached embeddings of the same podcast, or of all podcasts when
    ``cross_podcast`` is true. The paragraph itself is always excluded; in
    the single-podcast mode, paragraphs of the same episode are excluded too.

    Returns:
        Up to ``limit`` dicts ordered by descending score, each with
        ``podcast_id``, ``episode_id``, ``paragraph_idx`` and ``score``, plus
        ``text_preview``/``start_time`` and ``episode_title``/``guest`` when
        the corresponding rows exist. Empty when ``limit`` is not positive,
        when there is nothing to compare against, or when the stored
        embedding has no usable norm.

    Raises:
        HTTPException: 404 when the paragraph is missing or has no embedding.
    """
    db = get_db()
    try:
        row = db.execute(
            "SELECT id, embedding FROM paragraphs WHERE podcast_id = ? AND episode_id = ? AND idx = ?",
            (podcast_id, episode_id, para_idx)
        ).fetchone()
    finally:
        db.close()

    if not row or not row["embedding"]:
        raise HTTPException(404, "Paragraph not found or not embedded")

    # A non-positive limit asks for nothing. Without this guard the loop
    # below would still return one hit, because it checks after appending.
    if limit <= 0:
        return []

    query_vec = np.frombuffer(row["embedding"], dtype=np.float32)
    norm = float(np.linalg.norm(query_vec))
    if not (0.0 < norm < float("inf")):
        # A zero or non-finite vector cannot be normalised. Dividing anyway
        # would yield NaN scores, which are not valid JSON.
        return []
    query_vec = query_vec / norm

    # Load all embeddings
    search_podcast = None if cross_podcast else podcast_id
    vectors, meta = _load_embeddings(search_podcast)

    if vectors is None or len(meta) == 0:
        return []

    # Dot product as cosine similarity.
    # NOTE(review): assumes get_all_embeddings returns unit-normalised
    # vectors - confirm.
    scores = vectors @ query_vec

    # Walk candidates from best to worst score.
    results = []
    for idx in np.argsort(scores)[::-1]:
        m = meta[idx]
        # Skip self
        if m["podcast_id"] == podcast_id and m["episode_id"] == episode_id and m["idx"] == para_idx:
            continue
        # Skip same episode unless cross_podcast
        if not cross_podcast and m["episode_id"] == episode_id:
            continue

        results.append({
            "podcast_id": m["podcast_id"],
            "episode_id": m["episode_id"],
            "paragraph_idx": m["idx"],
            "score": float(scores[idx])
        })

        if len(results) >= limit:
            break

    # Enrich with text previews
    db = get_db()
    try:
        for r in results:
            p = db.execute(
                "SELECT text, start_time FROM paragraphs WHERE podcast_id = ? AND episode_id = ? AND idx = ?",
                (r["podcast_id"], r["episode_id"], r["paragraph_idx"])
            ).fetchone()
            if p:
                r["text_preview"] = p["text"][:150]
                r["start_time"] = p["start_time"]

            ep = db.execute(
                "SELECT title, guest FROM episodes WHERE podcast_id = ? AND id = ?",
                (r["podcast_id"], r["episode_id"])
            ).fetchone()
            if ep:
                r["episode_title"] = ep["title"]
                r["guest"] = ep["guest"]
    finally:
        # Release the connection even when a lookup raises.
        db.close()

    return results
|
||
|
|
|
||
|
|
|
||
|
|
@app.get("/api/semantic-search")
def semantic_search(q: str = Query(..., min_length=3), podcast_id: Optional[str] = None, limit: int = 20):
    """Semantic search using query embedding.

    ``q`` is embedded with ``embeddings.embed_texts`` and normalised, then
    scored against the cached paragraph embeddings of ``podcast_id``, or of
    all podcasts when it is omitted. The best ``limit`` candidates are
    returned in descending score order; the list stops at the first
    candidate scoring below 0.3.

    Returns:
        A list of dicts with ``podcast_id``, ``episode_id``,
        ``paragraph_idx``, ``score``, ``text_preview``, ``start_time``,
        ``episode_title`` and ``guest``. Empty when ``limit`` is not
        positive, when no embeddings are available, or when the query
        embedding has no usable norm.

    Raises:
        HTTPException: 500 when the query cannot be embedded.
    """
    # Imported at call time, as in the original, so loading this module does
    # not require the embeddings module.
    from embeddings import embed_texts

    # A negative limit would otherwise slice from the end of the ranking.
    if limit <= 0:
        return []

    # Candidates are sorted, so the first score below this ends the list.
    min_score = 0.3

    try:
        query_vec = np.array(embed_texts([q])[0], dtype=np.float32)
        norm = float(np.linalg.norm(query_vec))
    except Exception as e:
        # API boundary: report any embedding failure as a 500 and keep the
        # original exception chained for the server log.
        raise HTTPException(500, f"Embedding failed: {e}") from e

    if not (0.0 < norm < float("inf")):
        # A zero or non-finite vector cannot be normalised. Dividing anyway
        # would yield NaN scores, which are not valid JSON.
        return []
    query_vec = query_vec / norm

    vectors, meta = _load_embeddings(podcast_id)
    if vectors is None or len(meta) == 0:
        return []

    scores = vectors @ query_vec
    indices = np.argsort(scores)[::-1][:limit]

    results = []
    db = get_db()
    try:
        for idx in indices:
            m = meta[idx]
            score = float(scores[idx])
            if score < min_score:
                break

            p = db.execute(
                "SELECT text, start_time FROM paragraphs WHERE id = ?", (m["id"],)
            ).fetchone()
            ep = db.execute(
                "SELECT title, guest FROM episodes WHERE podcast_id = ? AND id = ?",
                (m["podcast_id"], m["episode_id"])
            ).fetchone()

            results.append({
                "podcast_id": m["podcast_id"],
                "episode_id": m["episode_id"],
                "paragraph_idx": m["idx"],
                "score": score,
                "text_preview": p["text"][:200] if p else "",
                "start_time": p["start_time"] if p else None,
                "episode_title": ep["title"] if ep else "",
                "guest": ep["guest"] if ep else "",
            })
    finally:
        # Release the connection even when a lookup raises.
        db.close()

    return results
|
||
|
|
|
||
|
|
|
||
|
|
# ── Startup ──
|
||
|
|
|
||
|
|
@app.on_event("startup")
def startup():
    """Initialise the database and import podcasts that are not in it yet.

    Every sub-directory of ``DATA_DIR`` that contains both
    ``mindmap_data.json`` and ``srt_index.json`` is treated as a podcast
    whose id is the directory name. A podcast is imported only when no row
    with that id exists in ``podcasts``. Nothing is imported when
    ``DATA_DIR`` is missing or is not a directory.
    """
    init_db()

    # Auto-import podcasts from data directory
    data_path = Path(DATA_DIR)
    if not data_path.is_dir():
        # Covers a missing path and a regular file, which iterdir() rejects.
        return

    for podcast_dir in data_path.iterdir():
        if not podcast_dir.is_dir():
            continue

        mindmap_file = podcast_dir / "mindmap_data.json"
        srt_file = podcast_dir / "srt_index.json"
        if not (mindmap_file.exists() and srt_file.exists()):
            continue

        podcast_id = podcast_dir.name
        db = get_db()
        try:
            existing = db.execute("SELECT id FROM podcasts WHERE id = ?", (podcast_id,)).fetchone()
        finally:
            db.close()
        if existing:
            continue

        print(f"Importing podcast: {podcast_id}")
        # Decode as UTF-8 explicitly rather than with the locale default, so
        # non-ASCII text reads the same on every host.
        with open(mindmap_file, encoding="utf-8") as f:
            mindmap_data = json.load(f)
        with open(srt_file, encoding="utf-8") as f:
            srt_index = json.load(f)

        from database import import_podcast
        import_podcast(podcast_id, mindmap_data, srt_index)
|
||
|
|
|
||
|
|
|
||
|
|
# ── Static Files + Audio ──
|
||
|
|
|
||
|
|
# Audio files are served under /audio, straight from AUDIO_DIR (one
# sub-directory per podcast). The mount is skipped when the directory is
# absent.
if os.path.isdir(AUDIO_DIR):
    app.mount(
        "/audio",
        StaticFiles(directory=AUDIO_DIR),
        name="audio",
    )

# The web app is served from the root path as a fallback. html=True makes
# StaticFiles serve index.html for directory requests.
if os.path.isdir(STATIC_DIR):
    app.mount(
        "/",
        StaticFiles(directory=STATIC_DIR, html=True),
        name="static",
    )
|