import os

import psycopg2
from psycopg2.extras import execute_batch
from sentence_transformers import SentenceTransformer

# Number of task requests encoded and inserted per loop iteration.
BATCH_SIZE = 5000
# Commit once this many batches have been written since the last commit.
COMMIT_EVERY = 5

# Sentence-embedding model used to encode each task request's text.
model = SentenceTransformer("all-MiniLM-L6-v2")

# Connection settings can be overridden through the standard PG* environment
# variables so that credentials do not have to live in the source file. The
# fallbacks are the original local-development values, so running the script
# with none of the variables set behaves exactly as before.
conn = psycopg2.connect(
    host=os.environ.get("PGHOST", "localhost"),
    database=os.environ.get("PGDATABASE", "tcs_local"),
    user=os.environ.get("PGUSER", "postgres"),
    password=os.environ.get("PGPASSWORD", "postgres123"),
)

# One cursor for the SELECT of pending rows, one for the INSERTs.
read_cursor = conn.cursor()
write_cursor = conn.cursor()

# We created a view that gives a string that represents each task request
# CREATE OR REPLACE VIEW vw_task_request_embedding_text AS
# SELECT
#     tr.id AS task_request_id,
#
#     CONCAT(
#         'Category: ', c.category_name, '. ',
#         'Description: ', COALESCE(tr.description, ''), '. ',
#         'Work mode: ', tr.work_mode, '. ',
#         'Location: ', l.city, '. '
#     ) AS embedding_text
#
# FROM TaskRequest tr
# JOIN Category c
#     ON c.id = tr.category_id
# JOIN Location l
#     ON l.id = tr.location_id;

# Select the task requests that still need an embedding.
#   * ranked_tasks numbers each worker's COMPLETED tasks, newest first
#     (ordered by updated_at, falling back to created_at when it is NULL).
#   * Only the three most recent tasks per worker are kept (rn <= 3).
#   * The LEFT JOIN + "IS NULL" filter drops task requests that already have a
#     row in task_request_embeddings, so re-running the script only picks up
#     what is missing.
#   * DISTINCT collapses a task request that is selected more than once.
# Each result row is (task_request_id, embedding_text).
read_cursor.execute("""
WITH ranked_tasks AS
(
    SELECT
        o.worker_id,
        o.task_request_id,
        ROW_NUMBER()
        OVER
        (
            PARTITION BY o.worker_id
            ORDER BY
                COALESCE(
                    t.updated_at,
                    t.created_at
                ) DESC
        ) AS rn
    FROM Task t
    JOIN Offer o
        ON o.id = t.offer_id
    WHERE
        t.status = 'COMPLETED'
)
SELECT DISTINCT
    v.task_request_id,
    v.embedding_text
FROM ranked_tasks rt
JOIN
    vw_task_request_embedding_text v
        ON v.task_request_id =
           rt.task_request_id
LEFT JOIN
    task_request_embeddings te
        ON te.task_request_id =
           rt.task_request_id
WHERE
    rt.rn <= 3
    AND te.task_request_id IS NULL
ORDER BY
    v.task_request_id
""")

# Variant 2 - the same script is also used to embed the OPEN task requests that we will recommend; to do so, replace the query above with this one.
# SELECT
#     v.task_request_id,
#     v.embedding_text
# FROM vw_task_request_embedding_text v
#
# JOIN TaskRequest tr
#     ON tr.id = v.task_request_id
#
# LEFT JOIN task_request_embeddings te
#     ON te.task_request_id = v.task_request_id
#
# WHERE
#     tr.status = 'OPEN'
#     AND te.task_request_id IS NULL
#
# ORDER BY
#     v.task_request_id;

# Load every pending (task_request_id, embedding_text) row into memory; the
# batches below are slices of this list.
rows = read_cursor.fetchall()
total_rows = len(rows)
print(f"Found {total_rows} task requests without embeddings.")
print()

processed = 0       # number of embeddings written, reported at the end
commit_counter = 0  # batches written since the last commit

try:
    for start in range(0, total_rows, BATCH_SIZE):
        batch = rows[start:start + BATCH_SIZE]

        # Build ids/texts in lockstep, skipping rows whose text is NULL or
        # blank so there is nothing meaningless to encode.
        ids = []
        texts = []
        for task_request_id, embedding_text in batch:
            text = str(embedding_text or "").strip()
            if not text:
                continue
            ids.append(task_request_id)
            texts.append(text)

        if not texts:
            continue

        # Encode the whole batch in one call; normalize_embeddings=True asks
        # the model for normalised vectors.
        vectors = model.encode(
            texts,
            convert_to_numpy=True,
            normalize_embeddings=True,
            show_progress_bar=False
        )

        # Serialise each vector as the "[v1,v2,...]" text literal that the
        # INSERT below casts with ::vector.
        insert_data = [
            (task_request_id, "[" + ",".join(map(str, vector.tolist())) + "]")
            for task_request_id, vector in zip(ids, vectors)
        ]

        # Upsert: a task request that already has an embedding gets it
        # replaced and its embedded_at timestamp refreshed.
        execute_batch(
            write_cursor,
            """
        INSERT INTO
        task_request_embeddings
        (
            task_request_id,
            embedding,
            embedded_at
        )
        VALUES
        (
            %s,
            %s::vector,
            CURRENT_TIMESTAMP
        )
        ON CONFLICT
        (
            task_request_id
        )
        DO UPDATE SET
            embedding =
                EXCLUDED.embedding,
            embedded_at =
                CURRENT_TIMESTAMP
        """,
            insert_data,
            page_size=1000
        )

        # Previously this counter was never updated, so the final report
        # always printed 0.
        processed += len(insert_data)

        # Commit every COMMIT_EVERY batches rather than after each one.
        commit_counter += 1
        if commit_counter >= COMMIT_EVERY:
            conn.commit()
            commit_counter = 0

    # Flush whatever is left from a final, partial group of batches.
    if commit_counter > 0:
        conn.commit()

    print(f"Total embeddings: {processed}")
finally:
    # Always release the cursors and the connection, even when encoding or
    # inserting raised part-way through.
    read_cursor.close()
    write_cursor.close()
    conn.close()
