# podcast-mindmap/backend/embeddings.py
"""Generate embeddings via DashScope (Qwen text-embedding-v3)."""
import os
import time
from openai import OpenAI
from database import get_db, store_embedding
# DashScope credentials and model selection, injected via environment variables.
DASHSCOPE_API_KEY = os.environ.get("DASHSCOPE_API_KEY", "")
EMBEDDING_MODEL = os.environ.get("EMBEDDING_MODEL", "text-embedding-v3")
# Texts per embeddings request. The API accepts up to 10 per call, but we stay
# lower so batches of long (truncated-to-2000-char) paragraphs fit comfortably.
BATCH_SIZE = 6
def get_client() -> OpenAI:
    """Build an OpenAI-compatible client pointed at DashScope's intl endpoint."""
    client = OpenAI(
        base_url="https://dashscope-intl.aliyuncs.com/compatible-mode/v1",
        api_key=DASHSCOPE_API_KEY,
    )
    return client
def embed_texts(texts: list[str]) -> list[list[float]]:
    """Return one 1024-dimensional embedding vector per input text."""
    response = get_client().embeddings.create(
        model=EMBEDDING_MODEL,
        input=texts,
        dimensions=1024,
    )
    return [entry.embedding for entry in response.data]
def embed_all_paragraphs(podcast_id: str | None = None) -> None:
    """Embed all paragraphs that don't have embeddings yet.

    Fetches un-embedded paragraph rows (optionally scoped to one podcast),
    embeds them in batches of BATCH_SIZE, and persists each vector via
    store_embedding. Failed batches are logged and skipped (best-effort).

    Args:
        podcast_id: Restrict work to this podcast's paragraphs; None means all.
    """
    db = get_db()
    try:
        if podcast_id:
            rows = db.execute(
                "SELECT id, text FROM paragraphs WHERE podcast_id = ? AND embedding IS NULL",
                (podcast_id,),
            ).fetchall()
        else:
            rows = db.execute(
                "SELECT id, text FROM paragraphs WHERE embedding IS NULL"
            ).fetchall()
    finally:
        # Close even if the query raises — the original leaked the handle on error.
        db.close()
    if not rows:
        print("No paragraphs to embed.")
        return
    # Hoist the invariant batch count out of the loop.
    total_batches = (len(rows) + BATCH_SIZE - 1) // BATCH_SIZE
    print(f"Embedding {len(rows)} paragraphs...")
    for i in range(0, len(rows), BATCH_SIZE):
        batch = rows[i:i + BATCH_SIZE]
        # Truncate very long paragraphs so the request stays within API limits.
        texts = [r["text"][:2000] for r in batch]
        try:
            embeddings = embed_texts(texts)
            for row, emb in zip(batch, embeddings):
                store_embedding(row["id"], emb)
            print(f" Batch {i // BATCH_SIZE + 1}/{total_batches}: {len(batch)} paragraphs")
        except Exception as e:
            # Deliberate best-effort: log, back off briefly, move to next batch.
            print(f" Error at batch {i // BATCH_SIZE + 1}: {e}")
            time.sleep(2)
            continue
    print("Done.")
if __name__ == "__main__":
    import sys

    # Optional CLI arg: limit embedding to a single podcast.
    target = sys.argv[1] if len(sys.argv) > 1 else None
    embed_all_paragraphs(target)