| 1 | import os
|
|---|
| 2 | import psycopg2
|
|---|
| 3 | from dotenv import load_dotenv
|
|---|
| 4 | from openai import OpenAI
|
|---|
| 5 |
|
|---|
| 6 |
|
|---|
# Populate os.environ (via python-dotenv) before any os.getenv() lookup below.
load_dotenv()

# Embedding request settings; each can be overridden through the environment.
OPENAI_EMBEDDING_MODEL = os.getenv("OPENAI_EMBEDDING_MODEL", "text-embedding-3-small")
OPENAI_EMBEDDING_DIMENSIONS = int(os.getenv("OPENAI_EMBEDDING_DIMENSIONS", "384"))
# Number of texts sent to OpenAI in one embeddings request.
OPENAI_REQUEST_BATCH_SIZE = int(os.getenv("OPENAI_REQUEST_BATCH_SIZE", "512"))
# Number of embedding requests processed between database commits.
SYNC_EVERY_REQUESTS = int(os.getenv("SYNC_EVERY_REQUESTS", "10"))

# Fail fast on nonsensical settings instead of failing mid-run.
if OPENAI_EMBEDDING_DIMENSIONS <= 0:
    raise ValueError("OPENAI_EMBEDDING_DIMENSIONS must be greater than zero.")

if OPENAI_REQUEST_BATCH_SIZE <= 0:
    raise ValueError("OPENAI_REQUEST_BATCH_SIZE must be greater than zero.")

if SYNC_EVERY_REQUESTS <= 0:
    raise ValueError("SYNC_EVERY_REQUESTS must be greater than zero.")

client = OpenAI(api_key=os.getenv("OPENAI_API_KEY"))

# Connection settings come from DB_* environment variables.
conn = psycopg2.connect(
    host=os.getenv("DB_HOST"),
    port=os.getenv("DB_PORT"),
    dbname=os.getenv("DB_NAME"),
    user=os.getenv("DB_USER"),
    password=os.getenv("DB_PASSWORD"),
)
| 30 |
|
|---|
| 31 |
|
|---|
def vector_to_pgvector(vector):
    """Serialize an iterable of numbers as a bracketed, comma-separated string.

    Each element is rendered with ``str()`` and the pieces are joined with
    commas and no spaces, e.g. ``[0.1, 0.2]`` becomes ``"[0.1,0.2]"``. An
    empty iterable produces ``"[]"``.

    Args:
        vector: Iterable of values to serialize.

    Returns:
        The text literal, suitable for a ``%s::vector`` SQL parameter.
    """
    return "[" + ",".join(str(x) for x in vector) + "]"
| 34 |
|
|---|
| 35 |
|
|---|
def chunks(items, size):
    """Yield consecutive slices of ``items`` holding at most ``size`` elements.

    The final slice is shorter when ``len(items)`` is not a multiple of
    ``size``. An empty ``items`` yields nothing.

    Args:
        items: Sliceable sequence supporting ``len()``.
        size: Maximum number of elements per slice; must be positive.

    Yields:
        Slices of ``items`` in order.

    Raises:
        ValueError: If ``size`` is not greater than zero. Because this is a
            generator, the error surfaces when iteration starts.
    """
    # A negative step would make range() empty and silently drop every item.
    if size <= 0:
        raise ValueError("size must be greater than zero.")

    for start in range(0, len(items), size):
        yield items[start:start + size]
| 39 |
|
|---|
| 40 |
|
|---|
def get_embeddings(texts):
    """Request one embedding per text from OpenAI and return them in input order.

    Args:
        texts: List of strings to embed in a single request.

    Returns:
        A list of embeddings where element ``i`` corresponds to ``texts[i]``.
        An empty ``texts`` returns ``[]`` without contacting the API.

    Raises:
        RuntimeError: If the response does not contain exactly one embedding
            per input text.
    """
    if not texts:
        # Nothing to embed, so skip the network round-trip entirely.
        return []

    response = client.embeddings.create(
        model=OPENAI_EMBEDDING_MODEL,
        input=texts,
        dimensions=OPENAI_EMBEDDING_DIMENSIONS,
        encoding_format="float",
    )

    # Order by the index reported with each item so results line up with texts.
    ordered = sorted(response.data, key=lambda item: item.index)

    if len(ordered) != len(texts):
        raise RuntimeError(
            f"Expected {len(texts)} embeddings from OpenAI, got {len(ordered)}."
        )

    return [item.embedding for item in ordered]
| 53 |
|
|---|
| 54 |
|
|---|
def update_embeddings(cur, batch):
    """Embed a batch of property texts and store each vector on its property row.

    Args:
        cur: Open database cursor used to run the UPDATE statements. Nothing
            is committed here; the caller owns the transaction.
        batch: Sequence of ``(property_id, embedding_text)`` pairs.

    Returns:
        The number of pairs in ``batch`` (``0`` for an empty batch).

    Raises:
        ValueError: If any embedding text is missing or blank after stripping.
            This is raised before any embedding request is made.
        RuntimeError: If the number of embeddings returned does not match the
            number of properties in the batch.
    """
    if not batch:
        return 0

    property_ids = []
    texts = []

    for property_id, embedding_text in batch:
        text = str(embedding_text or "").strip()
        if not text:
            raise ValueError(f"Property {property_id} has empty embedding text.")
        property_ids.append(property_id)
        texts.append(text)

    embeddings = get_embeddings(texts)

    # zip() would silently truncate on a mismatch and leave rows without an
    # embedding while the batch is still reported as fully synced.
    if len(embeddings) != len(property_ids):
        raise RuntimeError(
            f"Expected {len(property_ids)} embeddings, got {len(embeddings)}."
        )

    for property_id, embedding in zip(property_ids, embeddings):
        cur.execute("""
            UPDATE properties
            SET embedding = %s::vector,
                updated_at = CURRENT_TIMESTAMP
            WHERE property_id = %s
        """, (vector_to_pgvector(embedding), property_id))

    return len(batch)
| 76 |
|
|---|
| 77 |
|
|---|
# Main run: embed every property whose embedding column is NULL, committing
# after every SYNC_EVERY_REQUESTS embedding requests.
try:
    with conn.cursor() as cur:
        cur.execute("""
            SELECT property_id, embedding_text
            FROM vw_property_embedding_text
            WHERE property_id IN (
                SELECT property_id
                FROM properties
                WHERE embedding IS NULL
            )
        """)

        # Materialize the result set first; the same cursor is reused for the
        # UPDATE statements issued below.
        rows = cur.fetchall()

        print(f"Found {len(rows)} properties without embeddings.")
        print(
            "Generating OpenAI embeddings with "
            f"{OPENAI_EMBEDDING_MODEL} in batches of "
            f"{OPENAI_REQUEST_BATCH_SIZE} texts..."
        )

        # Embeddings written and requests made since the last commit.
        pending_sync_count = 0
        pending_request_count = 0

        for batch in chunks(rows, OPENAI_REQUEST_BATCH_SIZE):
            print(f"Requesting {len(batch)} embeddings from OpenAI...")
            pending_sync_count += update_embeddings(cur, batch)
            pending_request_count += 1

            if pending_request_count == SYNC_EVERY_REQUESTS:
                conn.commit()
                print(f"Synced {pending_sync_count} embeddings.")
                pending_sync_count = 0
                pending_request_count = 0

        # Flush whatever is left after the last full group of requests.
        if pending_request_count > 0:
            conn.commit()
            print(f"Synced {pending_sync_count} embeddings.")

        print("Done. Embeddings saved successfully.")

finally:
    # Always release the connection, including when a batch raised.
    conn.close()