# AdvancedTopics: embed_worker_profiles.py
#
# File embed_worker_profiles.py, 3.7 KB (added by 231141, 36 hours ago)
#
import os
from itertools import groupby
from operator import itemgetter

import numpy as np
import psycopg2
from psycopg2.extras import execute_batch

# Rebuild worker_recommendation_profiles: for each worker, average the
# embeddings of the task requests behind their most recently completed tasks
# and upsert that mean vector as the worker's preference embedding.

BATCH_SIZE = 5000  # profiles buffered before each execute_batch() flush
COMMIT_EVERY = 5  # flushed batches between intermediate commits
RECENT_TASKS_PER_WORKER = 3  # recent completed tasks that feed one profile
UPSERT_PAGE_SIZE = 1000  # statements per server round trip in execute_batch()

# Embeddings of each worker's most recent completed tasks, ordered by
# worker_id so all rows of one worker are adjacent (iter_worker_profiles
# relies on that).
# - NULLS LAST: under DESC, Postgres sorts NULLs first, so a task with neither
#   timestamp would otherwise outrank every dated task.
# - embedding IS NOT NULL: a NULL embedding used to crash the parser.
# NOTE(review): ranking happens before the embedding join, so a worker whose
# most recent tasks lack embeddings gets fewer than RECENT_TASKS_PER_WORKER
# vectors instead of falling back to older tasks -- confirm this is intended.
RECENT_EMBEDDINGS_SQL = """
WITH ranked_tasks AS
(
    SELECT
        o.worker_id,
        o.task_request_id,
        ROW_NUMBER() OVER (
            PARTITION BY o.worker_id
            ORDER BY COALESCE(t.updated_at, t.created_at) DESC NULLS LAST
        ) AS rn
    FROM Task t
    JOIN Offer o
        ON o.id = t.offer_id
    WHERE t.status = 'COMPLETED'
)
SELECT
    rt.worker_id,
    tre.embedding
FROM ranked_tasks rt
JOIN task_request_embeddings tre
    ON tre.task_request_id = rt.task_request_id
WHERE rt.rn <= %(top_n)s
  AND tre.embedding IS NOT NULL
ORDER BY rt.worker_id
"""

# Insert a worker's profile, or refresh it if the worker already has one.
UPSERT_PROFILE_SQL = """
INSERT INTO worker_recommendation_profiles
    (worker_id, preference_embedding, updated_at)
VALUES
    (%s, %s::vector, CURRENT_TIMESTAMP)
ON CONFLICT (worker_id)
DO UPDATE SET
    preference_embedding = EXCLUDED.preference_embedding,
    updated_at = CURRENT_TIMESTAMP
"""


def connect():
    """Open the database connection.

    Settings come from the standard libpq environment variables (PGHOST,
    PGDATABASE, PGUSER, PGPASSWORD). When a variable is unset, the previous
    hard-coded local-development value is used.
    """
    # TODO: drop the password fallback once every environment sets PGPASSWORD.
    return psycopg2.connect(
        host=os.environ.get("PGHOST", "localhost"),
        database=os.environ.get("PGDATABASE", "tcs_local"),
        user=os.environ.get("PGUSER", "postgres"),
        password=os.environ.get("PGPASSWORD", "postgres123"),
    )


def parse_vector(value):
    """Convert one embedding column value to a float32 numpy array.

    Accepts pgvector's text form such as ``"[0.1,0.2]"`` (psycopg2 returns
    that string when no vector adapter is registered). Also accepts an
    already-decoded sequence of numbers.

    Raises:
        ValueError: if *value* is None or has a non-numeric element.
    """
    if value is None:
        raise ValueError("embedding is NULL")
    if isinstance(value, str):
        # float() rejects malformed elements. np.fromstring(sep=",") only
        # warned and could return a silently truncated array.
        value = [float(part) for part in value.strip("[]").split(",")]
    return np.asarray(value, dtype=np.float32)


def format_vector(vector):
    """Render a 1-D numpy *vector* as a pgvector text literal like ``"[0.5,1.5]"``."""
    return "[" + ",".join(map(str, vector.tolist())) + "]"


def iter_worker_profiles(rows):
    """Yield ``(worker_id, vector_literal)`` holding each worker's mean embedding.

    Args:
        rows: iterable of ``(worker_id, embedding)`` pairs in which all rows of
            one worker are adjacent (e.g. sorted by worker_id). ``groupby``
            only merges adjacent rows.

    Yields:
        One tuple per run of equal worker_id. The vector is the element-wise
        mean of that run's embeddings, formatted by ``format_vector``.
    """
    for worker_id, worker_rows in groupby(rows, key=itemgetter(0)):
        vectors = [parse_vector(embedding) for _, embedding in worker_rows]
        yield worker_id, format_vector(np.mean(vectors, axis=0))


def upsert_profiles(cursor, profiles):
    """Insert or refresh *profiles*, a list of ``(worker_id, vector_text)``."""
    execute_batch(
        cursor, UPSERT_PROFILE_SQL, profiles, page_size=UPSERT_PAGE_SIZE
    )


def main():
    """Rebuild every worker's preference profile.

    Commits after every COMMIT_EVERY flushed batches and once more at the end.
    """
    conn = connect()
    try:
        # A named (server-side) cursor streams rows instead of loading the
        # whole result with fetchall(). withhold=True declares it WITH HOLD,
        # so it survives the intermediate commits below; a plain named cursor
        # is closed when its transaction ends.
        read_cursor = conn.cursor(name="worker_recent_embeddings", withhold=True)
        write_cursor = conn.cursor()
        read_cursor.execute(
            RECENT_EMBEDDINGS_SQL, {"top_n": RECENT_TASKS_PER_WORKER}
        )

        processed_workers = 0
        batches_since_commit = 0
        pending = []

        for profile in iter_worker_profiles(read_cursor):
            pending.append(profile)
            processed_workers += 1
            if len(pending) < BATCH_SIZE:
                continue
            upsert_profiles(write_cursor, pending)
            pending = []
            batches_since_commit += 1
            if batches_since_commit >= COMMIT_EVERY:
                conn.commit()
                batches_since_commit = 0

        read_cursor.close()
        if pending:
            upsert_profiles(write_cursor, pending)
        conn.commit()
        write_cursor.close()

        print(f"Updated {processed_workers} worker profiles")
    finally:
        # Ending the session also discards the WITH HOLD cursor. On error, the
        # server drops whatever was written since the last commit.
        conn.close()


if __name__ == "__main__":
    main()