"""Generate paragraph embeddings via DashScope (Qwen text-embedding-v3 by default)."""
|
||
|
|
|
||
|
|
import os
|
||
|
|
import time
|
||
|
|
from openai import OpenAI
|
||
|
|
from database import get_db, store_embedding
|
||
|
|
|
||
|
|
# Credentials and model selection are read from the environment; the API key
# falls back to an empty string and the model to "text-embedding-v3".
DASHSCOPE_API_KEY = os.getenv("DASHSCOPE_API_KEY", "")
EMBEDDING_MODEL = os.getenv("EMBEDDING_MODEL", "text-embedding-v3")

# DashScope text-embedding-v3 limit: 10 texts, but long texts need smaller
# batches.
BATCH_SIZE = 6
|
||
|
|
|
||
|
|
|
||
|
|
def get_client():
    """Build an OpenAI SDK client aimed at DashScope's compatible-mode API.

    Returns:
        An ``OpenAI`` instance configured with ``DASHSCOPE_API_KEY`` and the
        ``dashscope-intl`` compatible-mode base URL.
    """
    endpoint = "https://dashscope-intl.aliyuncs.com/compatible-mode/v1"
    return OpenAI(api_key=DASHSCOPE_API_KEY, base_url=endpoint)
|
||
|
|
|
||
|
|
|
||
|
|
def embed_texts(texts: list[str]) -> list[list[float]]:
    """Embed a batch of texts with the configured embedding model.

    Args:
        texts: Strings to embed in a single embeddings request.

    Returns:
        The ``embedding`` of every item in the response, in the order the
        response lists them. An empty ``texts`` yields an empty list without
        creating a client or contacting the API.

    Raises:
        Whatever the OpenAI client raises for a failed request; nothing is
        caught here.
    """
    # Nothing to embed: skip client construction and the network round trip.
    if not texts:
        return []

    client = get_client()
    response = client.embeddings.create(
        model=EMBEDDING_MODEL,
        input=texts,
        dimensions=1024,  # requested vector size
    )
    return [item.embedding for item in response.data]
|
||
|
|
|
||
|
|
|
||
|
|
def embed_all_paragraphs(podcast_id: str = None):
    """Embed all paragraphs that don't have embeddings yet.

    Rows whose ``embedding`` column is NULL are fetched up front, the
    database handle is closed, and the rows are then embedded in batches of
    ``BATCH_SIZE`` and written back one by one via ``store_embedding``.

    A batch that fails (API error, storage error, unusable row text, or a
    response with the wrong number of vectors) is reported and skipped so the
    remaining batches still run. Progress and errors are printed to stdout.

    Args:
        podcast_id: When truthy, only paragraphs of that podcast are
            processed; otherwise every paragraph with a NULL embedding is.
    """
    db = get_db()
    try:
        if podcast_id:
            rows = db.execute(
                "SELECT id, text FROM paragraphs WHERE podcast_id = ? AND embedding IS NULL",
                (podcast_id,),
            ).fetchall()
        else:
            rows = db.execute(
                "SELECT id, text FROM paragraphs WHERE embedding IS NULL"
            ).fetchall()
    finally:
        # Close the handle even when the query raises, so it is never leaked.
        db.close()

    if not rows:
        print("No paragraphs to embed.")
        return

    print(f"Embedding {len(rows)} paragraphs...")

    # Loop-invariant: ceil(len(rows) / BATCH_SIZE).
    total_batches = (len(rows) + BATCH_SIZE - 1) // BATCH_SIZE
    failed_batches = 0

    for batch_no, start in enumerate(range(0, len(rows), BATCH_SIZE), start=1):
        batch = rows[start:start + BATCH_SIZE]

        try:
            # Built inside the try so a row with unusable text (e.g. NULL)
            # fails only this batch instead of aborting the whole run.
            texts = [r["text"][:2000] for r in batch]  # Truncate long texts

            embeddings = embed_texts(texts)
            # zip() would silently drop paragraphs on a short response.
            if len(embeddings) != len(batch):
                raise ValueError(
                    f"expected {len(batch)} embeddings, got {len(embeddings)}"
                )

            for row, emb in zip(batch, embeddings):
                store_embedding(row["id"], emb)
            print(f" Batch {batch_no}/{total_batches}: {len(batch)} paragraphs")
        except Exception as e:
            # Best-effort: report, pause briefly, and move on to the next batch.
            print(f" Error at batch {batch_no}: {e}")
            failed_batches += 1
            time.sleep(2)
            continue

    if failed_batches:
        # Make partial failure visible instead of ending on a bare "Done.".
        print(f" {failed_batches}/{total_batches} batches failed.")
    print("Done.")
|
||
|
|
|
||
|
|
|
||
|
|
if __name__ == "__main__":
    import sys

    # The first command-line argument, when given, limits the run to that
    # podcast; with no argument every podcast is processed.
    cli_args = sys.argv[1:]
    podcast_id = cli_args[0] if cli_args else None
    embed_all_paragraphs(podcast_id)
|