from itertools import groupby
from operator import itemgetter

import numpy as np
import psycopg2
from psycopg2.extras import execute_batch

BATCH_SIZE = 5000    # profiles accumulated before one execute_batch() call
COMMIT_EVERY = 5     # number of written batches between intermediate commits
PAGE_SIZE = 1000     # page_size handed to psycopg2's execute_batch()

# NOTE(review): credentials are hard-coded; consider loading them from the
# environment or a secrets store before running this outside a local setup.
DB_CONFIG = {
    "host": "localhost",
    "database": "tcs_local",
    "user": "postgres",
    "password": "postgres123",
}

# For every worker, rank COMPLETED tasks newest-first (by updated_at, falling
# back to created_at) and return the embeddings of the task requests behind
# the three newest ones.  Rows are ordered by worker_id so that each worker's
# rows arrive consecutively.  The inner join drops tasks without an embedding
# row, so a worker may contribute fewer than three vectors.
RECENT_EMBEDDINGS_SQL = """
WITH ranked_tasks AS
(
    SELECT
        o.worker_id,
        o.task_request_id,
        ROW_NUMBER()
        OVER
        (
            PARTITION BY o.worker_id
            ORDER BY
                COALESCE(
                    t.updated_at,
                    t.created_at
                ) DESC
        ) AS rn
    FROM Task t
    JOIN Offer o
        ON o.id = t.offer_id
    WHERE
        t.status = 'COMPLETED'
)
SELECT
    rt.worker_id,
    tre.embedding
FROM ranked_tasks rt
JOIN task_request_embeddings tre
    ON tre.task_request_id =
       rt.task_request_id
WHERE
    rt.rn <= 3
ORDER BY
    rt.worker_id;
"""

# Insert a worker's profile, or overwrite the stored one if it already exists.
UPSERT_PROFILE_SQL = """
INSERT INTO worker_recommendation_profiles
(
    worker_id,
    preference_embedding,
    updated_at
)
VALUES
(
    %s,
    %s::vector,
    CURRENT_TIMESTAMP
)
ON CONFLICT (worker_id)
DO UPDATE SET
    preference_embedding =
        EXCLUDED.preference_embedding,
    updated_at =
        CURRENT_TIMESTAMP
"""


def parse_embedding(embedding):
    """Return *embedding* as a 1-D ``float32`` NumPy array.

    Accepts the bracketed, comma-separated text form (``"[0.1,0.2]"``) as well
    as any numeric sequence/array.  ``"[]"`` yields an empty array.

    Raises:
        ValueError: if a text component is not a valid float.
    """
    if not isinstance(embedding, str):
        return np.asarray(embedding, dtype=np.float32)

    inner = embedding.strip().strip("[]").strip()
    if not inner:
        return np.empty(0, dtype=np.float32)

    # Explicit parsing replaces the deprecated text mode of np.fromstring().
    return np.array(
        [float(part) for part in inner.split(",")],
        dtype=np.float32,
    )


def format_vector(vector):
    """Render a NumPy vector as bracketed, comma-separated text."""
    return "[" + ",".join(map(str, vector.tolist())) + "]"


def iter_worker_profiles(rows):
    """Yield ``(worker_id, vector_text)`` for each run of equal worker ids.

    *rows* is an iterable of ``(worker_id, embedding)`` pairs in which all
    rows of one worker are adjacent (the query guarantees this through its
    ``ORDER BY``).  The yielded text is the element-wise mean of the worker's
    embeddings.
    """
    for worker_id, worker_rows in groupby(rows, key=itemgetter(0)):
        vectors = [parse_embedding(embedding) for _, embedding in worker_rows]
        yield worker_id, format_vector(np.mean(vectors, axis=0))


def write_profiles(cursor, profiles):
    """Upsert a list of ``(worker_id, vector_text)`` pairs through *cursor*."""
    execute_batch(cursor, UPSERT_PROFILE_SQL, profiles, page_size=PAGE_SIZE)


def main():
    """Rebuild worker recommendation profiles from recent completed tasks.

    Reads the embeddings selected by ``RECENT_EMBEDDINGS_SQL``, averages them
    per worker and upserts the result into ``worker_recommendation_profiles``
    in batches of ``BATCH_SIZE``, committing after every ``COMMIT_EVERY``
    batches and once more at the end.

    Returns:
        The number of worker profiles written.
    """
    conn = psycopg2.connect(**DB_CONFIG)
    written = 0
    try:
        with conn.cursor() as read_cursor, conn.cursor() as write_cursor:
            read_cursor.execute(RECENT_EMBEDDINGS_SQL)

            pending = []
            batches_since_commit = 0

            # Iterate the cursor directly instead of copying every row into a
            # second list with fetchall().
            for profile in iter_worker_profiles(read_cursor):
                pending.append(profile)
                if len(pending) < BATCH_SIZE:
                    continue

                write_profiles(write_cursor, pending)
                written += len(pending)
                pending = []

                batches_since_commit += 1
                if batches_since_commit >= COMMIT_EVERY:
                    conn.commit()
                    batches_since_commit = 0

            if pending:
                write_profiles(write_cursor, pending)
                written += len(pending)

        conn.commit()
    finally:
        # Always release the connection; closing it discards any work that
        # was not committed before an error.
        conn.close()

    return written


if __name__ == "__main__":
    main()