"""Rebuild worker preference embeddings from recently completed tasks.

For every worker, the embeddings of the task requests behind their most
recent COMPLETED tasks (ranked by ``COALESCE(updated_at, created_at)``) are
averaged element-wise. The mean vector is upserted into
``worker_recommendation_profiles`` with ``updated_at = CURRENT_TIMESTAMP``.

Run as a script. Importing the module has no side effects.
"""

import os
from itertools import groupby, islice
from operator import itemgetter
from typing import Any, Dict, Iterable, Iterator, List, Sequence, Tuple, TypeVar

import numpy as np

BATCH_SIZE = 5000  # profiles buffered per execute_batch call
COMMIT_EVERY = 5  # commit after this many batches
PAGE_SIZE = 1000  # statements per server round trip inside execute_batch
RECENT_TASKS_PER_WORKER = 3  # how many of a worker's latest completed tasks to average

# Rows come back ordered by worker_id; iter_worker_profiles relies on that.
# NOTE: rn is computed before joining embeddings, so a task whose request has
# no embedding row still uses up one of the RECENT_TASKS_PER_WORKER slots.
# NULLS LAST: PostgreSQL sorts NULLs first under DESC by default, which would
# rank tasks without timestamps as the most recent ones.
FETCH_SQL = """
    WITH ranked_tasks AS (
        SELECT
            o.worker_id,
            o.task_request_id,
            ROW_NUMBER() OVER (
                PARTITION BY o.worker_id
                ORDER BY COALESCE(t.updated_at, t.created_at) DESC NULLS LAST
            ) AS rn
        FROM Task t
        JOIN Offer o
            ON o.id = t.offer_id
        WHERE t.status = 'COMPLETED'
    )
    SELECT
        rt.worker_id,
        tre.embedding
    FROM ranked_tasks rt
    JOIN task_request_embeddings tre
        ON tre.task_request_id = rt.task_request_id
    WHERE rt.rn <= %s
      AND tre.embedding IS NOT NULL
    ORDER BY rt.worker_id
"""

UPSERT_SQL = """
    INSERT INTO worker_recommendation_profiles
        (worker_id, preference_embedding, updated_at)
    VALUES
        (%s, %s::vector, CURRENT_TIMESTAMP)
    ON CONFLICT (worker_id)
    DO UPDATE SET
        preference_embedding = EXCLUDED.preference_embedding,
        updated_at = CURRENT_TIMESTAMP
"""

T = TypeVar("T")


def connection_params() -> Dict[str, str]:
    """Return keyword arguments for ``psycopg2.connect``.

    Each value can be overridden through the standard libpq environment
    variable (PGHOST, PGDATABASE, PGUSER, PGPASSWORD). The defaults are the
    script's original local-development settings.
    """
    # NOTE(review): a plaintext default password should not be relied on
    # outside local development; prefer PGPASSWORD or ~/.pgpass.
    return {
        "host": os.environ.get("PGHOST", "localhost"),
        "database": os.environ.get("PGDATABASE", "tcs_local"),
        "user": os.environ.get("PGUSER", "postgres"),
        "password": os.environ.get("PGPASSWORD", "postgres123"),
    }


def parse_embedding(embedding: Any) -> np.ndarray:
    """Convert one ``embedding`` column value into a float64 NumPy vector.

    Accepts bracketed comma-separated text such as ``"[0.1,0.2]"``. This is
    presumably pgvector's text output, returned when no vector adapter is
    registered. Also accepts any numeric sequence or array.

    Raises:
        ValueError: if a text component is not a valid number.
    """
    if isinstance(embedding, str):
        # float() rejects malformed components. np.fromstring(..., sep=",")
        # would instead return a silently truncated vector.
        components = embedding.strip("[]").split(",")
        return np.array([float(part) for part in components], dtype=np.float64)
    return np.asarray(embedding, dtype=np.float64)


def to_vector_literal(vector: np.ndarray) -> str:
    """Format a 1-D vector as ``[x1,x2,...]`` for a ``%s::vector`` parameter."""
    return "[" + ",".join(map(str, vector.tolist())) + "]"


def iter_worker_profiles(
    rows: Iterable[Sequence[Any]],
) -> Iterator[Tuple[Any, str]]:
    """Yield ``(worker_id, mean_vector_literal)`` once per worker.

    ``rows`` are ``(worker_id, embedding)`` pairs. They must be grouped by
    worker_id (FETCH_SQL orders by it), because ``groupby`` only merges
    adjacent rows.
    """
    for worker_id, group in groupby(rows, key=itemgetter(0)):
        vectors = [parse_embedding(embedding) for _, embedding in group]
        # Accumulate in float64; the ::vector cast stores float4 anyway.
        yield worker_id, to_vector_literal(np.mean(vectors, axis=0))


def chunked(iterable: Iterable[T], size: int) -> Iterator[List[T]]:
    """Yield consecutive lists of up to ``size`` items from ``iterable``."""
    iterator = iter(iterable)
    while True:
        chunk = list(islice(iterator, size))
        if not chunk:
            return
        yield chunk


def main() -> int:
    """Recompute and upsert every worker profile; return how many were written.

    A commit is issued every COMMIT_EVERY batches and once more at the end.
    An interrupted run therefore keeps the batches committed before the
    failure.
    """
    # Imported here so the helpers above can be imported (e.g. by tests)
    # without the PostgreSQL driver installed.
    import psycopg2
    from psycopg2.extras import execute_batch

    conn = psycopg2.connect(**connection_params())
    try:
        # A named cursor is server-side: rows are fetched in chunks instead of
        # the whole result being buffered client-side. withhold=True keeps it
        # usable across the intermediate commits below; without it, commit()
        # would close it.
        read_cursor = conn.cursor(name="worker_recent_embeddings", withhold=True)
        write_cursor = conn.cursor()
        read_cursor.execute(FETCH_SQL, (RECENT_TASKS_PER_WORKER,))

        written = 0
        batches = chunked(iter_worker_profiles(read_cursor), BATCH_SIZE)
        for batch_number, batch in enumerate(batches, start=1):
            execute_batch(write_cursor, UPSERT_SQL, batch, page_size=PAGE_SIZE)
            written += len(batch)
            if batch_number % COMMIT_EVERY == 0:
                conn.commit()

        read_cursor.close()
        write_cursor.close()
        conn.commit()
        return written
    finally:
        # Always release the connection; closing without a commit discards
        # any uncommitted batch, as psycopg2 performs an implicit rollback.
        conn.close()


if __name__ == "__main__":
    main()
|---|