import os
import psycopg2
from dotenv import load_dotenv
from openai import OpenAI


# Load settings via python-dotenv before any environment lookups below.
load_dotenv()

# Embedding request settings; each one can be overridden through the
# environment variable of the same name.
OPENAI_EMBEDDING_MODEL = os.environ.get(
    "OPENAI_EMBEDDING_MODEL",
    "text-embedding-3-small",
)
OPENAI_EMBEDDING_DIMENSIONS = int(
    os.environ.get("OPENAI_EMBEDDING_DIMENSIONS", "384")
)
# Number of texts sent to OpenAI in a single embeddings request.
OPENAI_REQUEST_BATCH_SIZE = int(
    os.environ.get("OPENAI_REQUEST_BATCH_SIZE", "512")
)
# Number of embedding requests processed between database commits.
SYNC_EVERY_REQUESTS = int(os.environ.get("SYNC_EVERY_REQUESTS", "10"))

# Fail fast on non-positive values: both settings drive the batching loop.
if OPENAI_REQUEST_BATCH_SIZE < 1:
    raise ValueError("OPENAI_REQUEST_BATCH_SIZE must be greater than zero.")

if SYNC_EVERY_REQUESTS < 1:
    raise ValueError("SYNC_EVERY_REQUESTS must be greater than zero.")

# OpenAI client used for every embeddings request; the key is read from the
# OPENAI_API_KEY environment variable.
client = OpenAI(api_key=os.environ.get("OPENAI_API_KEY"))


# Database connection; each connect() keyword is filled from its DB_*
# environment variable.
conn = psycopg2.connect(
    **{
        keyword: os.environ.get(variable)
        for keyword, variable in (
            ("host", "DB_HOST"),
            ("port", "DB_PORT"),
            ("dbname", "DB_NAME"),
            ("user", "DB_USER"),
            ("password", "DB_PASSWORD"),
        )
    }
)


def vector_to_pgvector(vector):
    """Render an iterable of numbers as a bracketed, comma-separated string.

    Every element is converted with ``str`` and the pieces are joined with
    commas (no spaces), e.g. ``[0.1, 0.2]`` becomes ``"[0.1,0.2]"``. An empty
    iterable produces ``"[]"``.
    """
    components = map(str, vector)
    return f"[{','.join(components)}]"


def chunks(items, size):
    """Yield consecutive slices of ``items`` with at most ``size`` elements.

    Args:
        items: A sequence supporting ``len()`` and slicing.
        size: Maximum number of elements per slice; must be greater than zero.

    Yields:
        Slices of ``items`` in order. The last slice may be shorter than
        ``size``; nothing is yielded for an empty sequence.

    Raises:
        ValueError: If ``size`` is not greater than zero. Because this is a
            generator, the error surfaces when iteration starts.
    """
    # A negative step would make range() empty and silently drop every item,
    # so reject non-positive sizes explicitly.
    if size <= 0:
        raise ValueError("size must be greater than zero.")

    for start in range(0, len(items), size):
        yield items[start:start + size]


def get_embeddings(texts):
    """Request embeddings for ``texts`` and return them ordered by index.

    Args:
        texts: The inputs to embed in a single API request.

    Returns:
        A list of embedding vectors sorted by each response item's
        ``index``. An empty list is returned, without calling the API, when
        ``texts`` is empty.
    """
    if not texts:
        # Nothing to embed: skip the request instead of sending empty input.
        return []

    response = client.embeddings.create(
        model=OPENAI_EMBEDDING_MODEL,
        input=texts,
        dimensions=OPENAI_EMBEDDING_DIMENSIONS,
        encoding_format="float",
    )

    # Order by the reported index rather than relying on response order.
    return [
        item.embedding
        for item in sorted(response.data, key=lambda item: item.index)
    ]


def update_embeddings(cur, batch):
    """Embed the texts in ``batch`` and write each vector to its property row.

    Args:
        cur: Database cursor used to execute the ``UPDATE`` statements. No
            commit is issued here; that is left to the caller.
        batch: Sequence of ``(property_id, embedding_text)`` pairs.

    Returns:
        The number of pairs in ``batch`` (``0`` for an empty batch).

    Raises:
        ValueError: If any embedding text is empty after stripping
            whitespace. Raised before any request or update is made.
        RuntimeError: If the number of embeddings returned differs from the
            number of properties in the batch.
    """
    if not batch:
        # Avoid an embeddings request with empty input.
        return 0

    property_ids = []
    texts = []

    # Validate every text up front so a bad row fails before any API call.
    for property_id, embedding_text in batch:
        text = str(embedding_text or "").strip()
        if not text:
            raise ValueError(f"Property {property_id} has empty embedding text.")
        property_ids.append(property_id)
        texts.append(text)

    embeddings = get_embeddings(texts)

    # zip() stops at the shorter input, so a short response would silently
    # leave rows without an embedding while still being counted as synced.
    if len(embeddings) != len(property_ids):
        raise RuntimeError(
            f"Expected {len(property_ids)} embeddings but received "
            f"{len(embeddings)}."
        )

    for property_id, embedding in zip(property_ids, embeddings):
        cur.execute("""
            UPDATE properties
            SET embedding = %s::vector,
                updated_at = CURRENT_TIMESTAMP
            WHERE property_id = %s
        """, (vector_to_pgvector(embedding), property_id))

    return len(property_ids)


# Backfill run: fetch every property that has no embedding yet, embed the
# texts in batches, and commit after every SYNC_EVERY_REQUESTS requests.
try:
    with conn.cursor() as cur:
        cur.execute("""
            SELECT property_id, embedding_text
            FROM vw_property_embedding_text
            WHERE property_id IN (
                SELECT property_id
                FROM properties
                WHERE embedding IS NULL
            )
        """)

        missing_rows = cur.fetchall()
        missing_total = len(missing_rows)

        print(f"Found {missing_total} properties without embeddings.")
        print(
            f"Generating OpenAI embeddings with {OPENAI_EMBEDDING_MODEL} "
            f"in batches of {OPENAI_REQUEST_BATCH_SIZE} texts..."
        )

        uncommitted_embeddings = 0
        # Stays 0 when there is nothing to process, so the final check below
        # skips the commit.
        requests_made = 0

        batches = chunks(missing_rows, OPENAI_REQUEST_BATCH_SIZE)
        for requests_made, batch in enumerate(batches, start=1):
            print(f"Requesting {len(batch)} embeddings from OpenAI...")
            uncommitted_embeddings += update_embeddings(cur, batch)

            # Only commit on every SYNC_EVERY_REQUESTS-th request.
            if requests_made % SYNC_EVERY_REQUESTS:
                continue

            conn.commit()
            print(f"Synced {uncommitted_embeddings} embeddings.")
            uncommitted_embeddings = 0

        # Commit whatever is left over from the last partial group of requests.
        if requests_made % SYNC_EVERY_REQUESTS:
            conn.commit()
            print(f"Synced {uncommitted_embeddings} embeddings.")

        print("Done. Embeddings saved successfully.")

finally:
    # Always release the connection, whether the run succeeded or raised.
    conn.close()
