| 1 | import psycopg2
|
|---|
| 2 | from sentence_transformers import SentenceTransformer
|
|---|
| 3 | from psycopg2.extras import execute_batch
|
|---|
| 4 |
|
|---|
# Number of rows encoded per model.encode() call, and number of such
# batches written between two commits.
BATCH_SIZE = 5000
COMMIT_EVERY = 5

# Sentence-embedding model used to encode every task-request text.
model = SentenceTransformer("all-MiniLM-L6-v2")

# Connection settings for the PostgreSQL database.
_DB_SETTINGS = {
    "host": "localhost",
    "database": "tcs_local",
    "user": "postgres",
    "password": "postgres123",
}
conn = psycopg2.connect(**_DB_SETTINGS)

# Two cursors on the same connection: one for the selection query,
# one for writing the embeddings.
read_cursor, write_cursor = conn.cursor(), conn.cursor()
| 18 |
|
|---|
| 19 | # We created a view that gives a string that represents each task request
|
|---|
| 20 | # CREATE OR REPLACE VIEW vw_task_request_embedding_text AS
|
|---|
| 21 | # SELECT
|
|---|
| 22 | # tr.id AS task_request_id,
|
|---|
| 23 | #
|
|---|
| 24 | # CONCAT(
|
|---|
| 25 | # 'Category: ', c.category_name, '. ',
|
|---|
| 26 | # 'Description: ', COALESCE(tr.description, ''), '. ',
|
|---|
| 27 | # 'Work mode: ', tr.work_mode, '. ',
|
|---|
| 28 | # 'Location: ', l.city, '. '
|
|---|
| 29 | # ) AS embedding_text
|
|---|
| 30 | #
|
|---|
| 31 | # FROM TaskRequest tr
|
|---|
| 32 | # JOIN Category c
|
|---|
| 33 | # ON c.id = tr.category_id
|
|---|
| 34 | # JOIN Location l
|
|---|
| 35 | # ON l.id = tr.location_id;
|
|---|
| 36 |
|
|---|
# Select the task requests that still need an embedding.
#   * ranked_tasks numbers each worker's COMPLETED tasks from newest to
#     oldest (by updated_at, falling back to created_at).
#   * The outer query keeps the three newest per worker, joins the
#     embedding-text view, and drops every task request that already has
#     a row in task_request_embeddings (LEFT JOIN ... IS NULL).
pending_requests_sql = """
WITH ranked_tasks AS
(
    SELECT
        o.worker_id,
        o.task_request_id,
        ROW_NUMBER()
        OVER
        (
            PARTITION BY o.worker_id
            ORDER BY
                COALESCE(
                    t.updated_at,
                    t.created_at
                ) DESC
        ) AS rn
    FROM Task t
    JOIN Offer o
        ON o.id = t.offer_id
    WHERE
        t.status = 'COMPLETED'
)
SELECT DISTINCT
    v.task_request_id,
    v.embedding_text
FROM ranked_tasks rt
JOIN
    vw_task_request_embedding_text v
    ON v.task_request_id =
        rt.task_request_id
LEFT JOIN
    task_request_embeddings te
    ON te.task_request_id =
        rt.task_request_id
WHERE
    rt.rn <= 3
    AND te.task_request_id IS NULL
ORDER BY
    v.task_request_id
"""

read_cursor.execute(pending_requests_sql)
| 77 |
|
|---|
| 78 | # Variant 2 - we used the same script to embed the open task requests that we will recommend, with the following change to the query:
|
|---|
| 79 | # SELECT
|
|---|
| 80 | # v.task_request_id,
|
|---|
| 81 | # v.embedding_text
|
|---|
| 82 | # FROM vw_task_request_embedding_text v
|
|---|
| 83 | #
|
|---|
| 84 | # JOIN TaskRequest tr
|
|---|
| 85 | # ON tr.id = v.task_request_id
|
|---|
| 86 | #
|
|---|
| 87 | # LEFT JOIN task_request_embeddings te
|
|---|
| 88 | # ON te.task_request_id = v.task_request_id
|
|---|
| 89 | #
|
|---|
| 90 | # WHERE
|
|---|
| 91 | # tr.status = 'OPEN'
|
|---|
| 92 | # AND te.task_request_id IS NULL
|
|---|
| 93 | #
|
|---|
| 94 | # ORDER BY
|
|---|
| 95 | # v.task_request_id;
|
|---|
| 96 |
|
|---|
# Embed every selected task request in batches of BATCH_SIZE and upsert the
# resulting vectors into task_request_embeddings, committing after every
# COMMIT_EVERY written batches. The cursors and the connection are released
# in the ``finally`` clause even when encoding or a database call fails.

# Insert a new embedding, or overwrite the stored one (and its timestamp)
# when the task request already has a row.
upsert_embedding_sql = """
    INSERT INTO
        task_request_embeddings
        (
            task_request_id,
            embedding,
            embedded_at
        )
    VALUES
        (
            %s,
            %s::vector,
            CURRENT_TIMESTAMP
        )
    ON CONFLICT
        (
            task_request_id
        )
    DO UPDATE SET
        embedding =
            EXCLUDED.embedding,
        embedded_at =
            CURRENT_TIMESTAMP
"""

try:
    rows = read_cursor.fetchall()
    total_rows = len(rows)
    print(f"Found {total_rows} task requests without embeddings.")
    print()

    processed = 0        # embeddings written so far (reported at the end)
    commit_counter = 0   # batches written since the last commit

    for start in range(0, total_rows, BATCH_SIZE):
        batch = rows[start:start + BATCH_SIZE]

        # Keep only rows with non-blank text; NULL text is treated as empty.
        ids = []
        texts = []
        for task_request_id, embedding_text in batch:
            text = str(embedding_text or "").strip()
            if not text:
                continue
            ids.append(task_request_id)
            texts.append(text)

        if not texts:
            continue

        vectors = model.encode(
            texts,
            convert_to_numpy=True,
            normalize_embeddings=True,
            show_progress_bar=False,
        )

        # Each vector is sent as the text "[v1,v2,...]" and cast by the
        # %s::vector placeholder in the statement.
        insert_data = [
            (task_request_id, "[" + ",".join(map(str, vector.tolist())) + "]")
            for task_request_id, vector in zip(ids, vectors)
        ]

        execute_batch(
            write_cursor,
            upsert_embedding_sql,
            insert_data,
            page_size=1000,
        )
        # The original never updated this counter, so the final report
        # always showed 0.
        processed += len(insert_data)

        commit_counter += 1
        if commit_counter >= COMMIT_EVERY:
            conn.commit()
            commit_counter = 0

    # Commit whatever is left from the last, incomplete group of batches.
    if commit_counter > 0:
        conn.commit()

    print(f"Total embeddings: {processed}")
finally:
    read_cursor.close()
    write_cursor.close()
    conn.close()