AdvancedTopics: embed_taskrequests.py

File embed_taskrequests.py, 4.1 KB (added by 231141, 36 hours ago)
Line 
1import psycopg2
2from sentence_transformers import SentenceTransformer
3from psycopg2.extras import execute_batch
4
# Run parameters: how many rows go into one encode/insert batch, and how
# many batches are written between commits.
BATCH_SIZE = 5000
COMMIT_EVERY = 5

# Sentence-embedding model used to encode every task-request text.
model = SentenceTransformer("all-MiniLM-L6-v2")

# NOTE(review): the connection settings, including the password, are
# hard-coded here; consider sourcing them from the environment.
_connection_settings = {
    "host": "localhost",
    "database": "tcs_local",
    "user": "postgres",
    "password": "postgres123",
}
conn = psycopg2.connect(**_connection_settings)

# Two cursors on the same connection: one runs the SELECT, the other the
# INSERTs.
read_cursor, write_cursor = conn.cursor(), conn.cursor()
18
19# We created a view that gives a string that represents each task request
20# CREATE OR REPLACE VIEW vw_task_request_embedding_text AS
21# SELECT
22# tr.id AS task_request_id,
23#
24# CONCAT(
25# 'Category: ', c.category_name, '. ',
26# 'Description: ', COALESCE(tr.description, ''), '. ',
27# 'Work mode: ', tr.work_mode, '. ',
28# 'Location: ', l.city, '. '
29# ) AS embedding_text
30#
31# FROM TaskRequest tr
32# JOIN Category c
33# ON c.id = tr.category_id
34# JOIN Location l
35# ON l.id = tr.location_id;
36
# Select the (task_request_id, embedding_text) pairs that still need an
# embedding.
#
# The ranked_tasks CTE numbers each worker's COMPLETED tasks from newest to
# oldest, ordering by updated_at and falling back to created_at when
# updated_at is NULL. The outer query keeps the task requests behind each
# worker's three newest completed tasks (rn <= 3), takes their text from
# vw_task_request_embedding_text, and uses the LEFT JOIN ... IS NULL
# anti-join to drop requests that already have a row in
# task_request_embeddings. DISTINCT collapses requests reached through more
# than one task; rows come back ordered by task_request_id.
37read_cursor.execute("""
38WITH ranked_tasks AS
39(
40 SELECT
41 o.worker_id,
42 o.task_request_id,
43 ROW_NUMBER()
44 OVER
45 (
46 PARTITION BY o.worker_id
47 ORDER BY
48 COALESCE(
49 t.updated_at,
50 t.created_at
51 ) DESC
52 ) AS rn
53 FROM Task t
54 JOIN Offer o
55 ON o.id = t.offer_id
56 WHERE
57 t.status = 'COMPLETED'
58)
59SELECT DISTINCT
60 v.task_request_id,
61 v.embedding_text
62FROM ranked_tasks rt
63JOIN
64 vw_task_request_embedding_text v
65 ON v.task_request_id =
66 rt.task_request_id
67LEFT JOIN
68 task_request_embeddings te
69 ON te.task_request_id =
70 rt.task_request_id
71WHERE
72 rt.rn <= 3
73 AND te.task_request_id IS NULL
74ORDER BY
75 v.task_request_id
76""")
77
78# Variant 2 - we reused this script to embed the open task requests (the ones we will recommend), replacing the query above with the one below.
79# SELECT
80# v.task_request_id,
81# v.embedding_text
82# FROM vw_task_request_embedding_text v
83#
84# JOIN TaskRequest tr
85# ON tr.id = v.task_request_id
86#
87# LEFT JOIN task_request_embeddings te
88# ON te.task_request_id = v.task_request_id
89#
90# WHERE
91# tr.status = 'OPEN'
92# AND te.task_request_id IS NULL
93#
94# ORDER BY
95# v.task_request_id;
96
# Pull the whole result set into memory so the total is known before any
# encoding starts.
rows = read_cursor.fetchall()
total_rows = len(rows)
print(f"Found {total_rows} task requests without embeddings.")
print()

processed = 0       # embeddings actually written; reported at the end
commit_counter = 0  # batches written since the last commit

for start in range(0, total_rows, BATCH_SIZE):
    batch = rows[start:start + BATCH_SIZE]

    # Keep ids and texts in lockstep, skipping rows whose text is NULL or
    # blank so they are never sent to the model.
    ids = []
    texts = []
    for task_request_id, embedding_text in batch:
        text = str(embedding_text or "").strip()
        if not text:
            continue
        ids.append(task_request_id)
        texts.append(text)

    if not texts:
        continue

    vectors = model.encode(
        texts,
        convert_to_numpy=True,
        normalize_embeddings=True,
        show_progress_bar=False,
    )

    # Each vector is serialised as the text literal "[v1,v2,...]", which the
    # INSERT below casts with ::vector (presumably the pgvector type --
    # confirm against the table definition).
    insert_data = [
        (task_request_id, "[" + ",".join(map(str, vector.tolist())) + "]")
        for task_request_id, vector in zip(ids, vectors)
    ]

    # Upsert: a row that already exists for the task request gets its
    # embedding and timestamp refreshed instead of raising a conflict.
    execute_batch(
        write_cursor,
        """
        INSERT INTO
            task_request_embeddings
        (
            task_request_id,
            embedding,
            embedded_at
        )
        VALUES
        (
            %s,
            %s::vector,
            CURRENT_TIMESTAMP
        )
        ON CONFLICT
        (
            task_request_id
        )
        DO UPDATE SET
            embedding =
                EXCLUDED.embedding,
            embedded_at =
                CURRENT_TIMESTAMP
        """,
        insert_data,
        page_size=1000,
    )

    # Count what was written; without this the summary always reports 0.
    processed += len(insert_data)

    # Commit only every COMMIT_EVERY batches to limit commit overhead.
    commit_counter += 1
    if commit_counter >= COMMIT_EVERY:
        conn.commit()
        commit_counter = 0

# Flush any batches written since the last periodic commit.
if commit_counter > 0:
    conn.commit()

print(f"Total embeddings: {processed}")
171
# Release the database resources: both cursors first, then the connection.
for resource in (read_cursor, write_cursor, conn):
    resource.close()