asoria HF staff commited on
Commit
a5eff40
1 Parent(s): b5ecaeb

Progress bar by task

Browse files
Files changed (1) hide show
  1. app.py +194 -152
app.py CHANGED
@@ -178,9 +178,9 @@ def generate_topics(dataset, config, split, column, plot_type):
178
  topics_info, topic_plot = None, None
179
  full_processing = split_rows <= MAX_ROWS
180
  message = (
181
- f"⚙️ Processing full dataset: 0 of ({split_rows} rows)"
182
  if full_processing
183
- else f"⚙️ Processing partial dataset 0 of ({limit} rows)"
184
  )
185
  sub_title = (
186
  f"Data map for the entire dataset ({limit} rows) using the column '{column}'"
@@ -191,48 +191,140 @@ def generate_topics(dataset, config, split, column, plot_type):
191
  gr.Accordion(open=False),
192
  gr.DataFrame(value=[], interactive=False, visible=True),
193
  gr.Plot(value=None, visible=True),
194
- gr.Label({message: rows_processed / limit}, visible=True),
195
  "",
196
  )
197
 
198
- while offset < limit:
199
- docs = get_docs_from_parquet(parquet_urls, column, offset, CHUNK_SIZE)
200
- if not docs:
201
- break
 
202
 
203
- logging.info(
204
- f"----> Processing chunk: {offset=} {CHUNK_SIZE=} with {len(docs)} docs"
205
- )
206
 
207
- embeddings = calculate_embeddings(docs)
208
- new_model = fit_model(docs, embeddings, n_neighbors, n_components)
209
 
210
- if base_model is None:
211
- base_model = new_model
212
- logging.info(
213
- f"The following topics are newly found: {base_model.topic_labels_}"
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
214
  )
215
- else:
216
- updated_model = BERTopic.merge_models([base_model, new_model])
217
- nr_new_topics = len(set(updated_model.topics_)) - len(
218
- set(base_model.topics_)
 
 
 
219
  )
220
- new_topics = list(updated_model.topic_labels_.values())[-nr_new_topics:]
221
- logging.info(f"The following topics are newly found: {new_topics}")
222
- base_model = updated_model
223
 
224
- reduced_embeddings = reduce_umap_model.fit_transform(embeddings)
225
- reduced_embeddings_list.append(reduced_embeddings)
 
 
 
 
 
226
 
227
- all_docs.extend(docs)
228
- reduced_embeddings_array = np.vstack(reduced_embeddings_list)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
229
 
230
- topics_info = base_model.get_topic_info()
231
  all_topics = base_model.topics_
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
232
  topic_plot = (
233
  base_model.visualize_document_datamap(
234
  docs=all_docs,
235
  topics=all_topics,
 
236
  reduced_embeddings=reduced_embeddings_array,
237
  title="",
238
  sub_title=sub_title,
@@ -258,137 +350,87 @@ def generate_topics(dataset, config, split, column, plot_type):
258
  title="",
259
  )
260
  )
261
- rows_processed += len(docs)
262
- progress = min(rows_processed / limit, 1.0)
263
- logging.info(f"Progress: {progress} % - {rows_processed} of {limit}")
264
- message = (
265
- f"⚙️ Processing full dataset: {rows_processed} of {limit}"
266
- if full_processing
267
- else f"⚙️ Processing partial dataset: {rows_processed} of {limit} rows"
268
- )
269
-
270
  yield (
271
  gr.Accordion(open=False),
272
  topics_info,
273
  topic_plot,
274
- gr.Label({message: progress}, visible=True),
 
 
 
 
 
 
 
275
  "",
276
  )
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
277
 
278
- offset += CHUNK_SIZE
279
- del docs, embeddings, new_model, reduced_embeddings
280
- logging.info("Finished processing all data")
281
-
282
- dataset_clear_name = dataset.replace("/", "-")
283
- plot_png = f"{dataset_clear_name}-{plot_type.lower()}.png"
284
- if plot_type == "DataMapPlot":
285
- topic_plot.savefig(plot_png, format="png", dpi=300)
286
- else:
287
- topic_plot.write_image(plot_png)
288
-
289
- all_topics = base_model.topics_
290
- topic_info = base_model.get_topic_info()
291
 
292
- new_topics_by_text_generation = {}
293
- for _, row in topic_info.iterrows():
294
- logging.info(
295
- f"Processing topic: {row['Topic']} - Representation: {row['Representation']}"
 
 
 
 
 
 
 
 
 
296
  )
297
- prompt = f"{REPRESENTATION_PROMPT.replace('[KEYWORDS]', ','.join(row['Representation']))}"
298
- logging.info(prompt)
299
- topic_description = generator(prompt)
300
- logging.info(topic_description)
301
- new_topics_by_text_generation[row["Topic"]] = topic_description[0][
302
- "generated_text"
303
- ].replace(prompt, "")
304
- base_model.set_topic_labels(new_topics_by_text_generation)
305
-
306
- topics_info = base_model.get_topic_info()
307
-
308
- topic_plot = (
309
- base_model.visualize_document_datamap(
310
- docs=all_docs,
311
- topics=all_topics,
312
- custom_labels=True,
313
- reduced_embeddings=reduced_embeddings_array,
314
- title="",
315
- sub_title=sub_title,
316
- width=800,
317
- height=700,
318
- arrowprops={
319
- "arrowstyle": "wedge,tail_width=0.5",
320
- "connectionstyle": "arc3,rad=0.05",
321
- "linewidth": 0,
322
- "fc": "#33333377",
323
- },
324
- dynamic_label_size=True,
325
- # label_wrap_width=12,
326
- label_over_points=True,
327
- max_font_size=36,
328
- min_font_size=4,
329
  )
330
- if plot_type == "DataMapPlot"
331
- else base_model.visualize_documents(
332
- docs=all_docs,
333
- reduced_embeddings=reduced_embeddings_array,
334
- custom_labels=True,
335
- title="",
 
 
336
  )
337
- )
338
- custom_labels = base_model.custom_labels_
339
- topic_names_array = [custom_labels[doc_topic + 1] for doc_topic in all_topics]
340
-
341
- interactive_plot = datamapplot.create_interactive_plot(
342
- reduced_embeddings_array,
343
- topic_names_array,
344
- hover_text=all_docs,
345
- title=dataset,
346
- sub_title=sub_title.replace(
347
- "dataset",
348
- f"<a href='https://huggingface.co/datasets/{dataset}/viewer/{config}/{split}' target='_blank'>dataset</a>",
349
- ),
350
- enable_search=True,
351
- # TODO: Export data to .arrow and also serve it
352
- inline_data=True,
353
- # offline_data_prefix=dataset_clear_name,
354
- initial_zoom_fraction=0.8,
355
- )
356
- html_content = str(interactive_plot)
357
- html_file_path = f"{dataset_clear_name}.html"
358
- with open(html_file_path, "w", encoding="utf-8") as html_file:
359
- html_file.write(html_content)
360
-
361
- repo_id = f"{DATASETS_TOPICS_ORGANIZATION}/{dataset_clear_name}"
362
-
363
- space_id = create_space_with_content(
364
- api=api,
365
- repo_id=repo_id,
366
- dataset_id=dataset,
367
- html_file_path=html_file_path,
368
- plot_file_path=plot_png,
369
- space_card=SPACE_REPO_CARD_CONTENT,
370
- token=HF_TOKEN,
371
- )
372
-
373
- space_link = f"https://huggingface.co/spaces/{space_id}"
374
- yield (
375
- gr.Accordion(open=False),
376
- topics_info,
377
- topic_plot,
378
- gr.Label(
379
- {f"✅ Done: {rows_processed} rows have been processed": 1.0}, visible=True
380
- ),
381
- f"[![Go to interactive plot](https://img.shields.io/badge/%F0%9F%A4%97%20Hugging%20Face-Space-blue)]({space_link})",
382
- )
383
- del reduce_umap_model, all_docs, reduced_embeddings_list
384
- del (
385
- base_model,
386
- all_topics,
387
- topic_info,
388
- topic_names_array,
389
- interactive_plot,
390
- )
391
- cuda.empty_cache()
392
 
393
 
394
  with gr.Blocks() as demo:
@@ -437,11 +479,11 @@ with gr.Blocks() as demo:
437
  generate_button = gr.Button("Generate Topics", variant="primary")
438
 
439
  gr.Markdown("## Data map")
440
- full_topics_generation_label = gr.Label(visible=False, show_label=False)
441
  open_space_label = gr.Markdown()
442
  topics_plot = gr.Plot()
443
- with gr.Accordion("Topics Info", open=False):
444
- topics_df = gr.DataFrame(interactive=False, visible=True)
445
  gr.HTML(
446
  f"<p style='text-align: center; color:orange;'>⚠ This space processes datasets in batches of <b>{CHUNK_SIZE}</b>, with a maximum of <b>{MAX_ROWS}</b> rows. If you need further assistance, please open a new issue in the Community tab.</p>"
447
  )
@@ -463,7 +505,7 @@ with gr.Blocks() as demo:
463
  data_details_accordion,
464
  topics_df,
465
  topics_plot,
466
- full_topics_generation_label,
467
  open_space_label,
468
  ],
469
  )
 
178
  topics_info, topic_plot = None, None
179
  full_processing = split_rows <= MAX_ROWS
180
  message = (
181
+ f"Processing topics for full dataset: 0 of ({split_rows} rows)"
182
  if full_processing
183
+ else f"Processing topics for partial dataset 0 of ({limit} rows)"
184
  )
185
  sub_title = (
186
  f"Data map for the entire dataset ({limit} rows) using the column '{column}'"
 
191
  gr.Accordion(open=False),
192
  gr.DataFrame(value=[], interactive=False, visible=True),
193
  gr.Plot(value=None, visible=True),
194
+ gr.Label({"⏳ " + message: 0.0}, visible=True),
195
  "",
196
  )
197
 
198
+ try:
199
+ while offset < limit:
200
+ docs = get_docs_from_parquet(parquet_urls, column, offset, CHUNK_SIZE)
201
+ if not docs:
202
+ break
203
 
204
+ logging.info(
205
+ f"----> Processing chunk: {offset=} {CHUNK_SIZE=} with {len(docs)} docs"
206
+ )
207
 
208
+ embeddings = calculate_embeddings(docs)
209
+ new_model = fit_model(docs, embeddings, n_neighbors, n_components)
210
 
211
+ if base_model is None:
212
+ base_model = new_model
213
+ logging.info(
214
+ f"The following topics are newly found: {base_model.topic_labels_}"
215
+ )
216
+ else:
217
+ updated_model = BERTopic.merge_models([base_model, new_model])
218
+ nr_new_topics = len(set(updated_model.topics_)) - len(
219
+ set(base_model.topics_)
220
+ )
221
+ new_topics = list(updated_model.topic_labels_.values())[-nr_new_topics:]
222
+ logging.info(f"The following topics are newly found: {new_topics}")
223
+ base_model = updated_model
224
+
225
+ reduced_embeddings = reduce_umap_model.fit_transform(embeddings)
226
+ reduced_embeddings_list.append(reduced_embeddings)
227
+
228
+ all_docs.extend(docs)
229
+ reduced_embeddings_array = np.vstack(reduced_embeddings_list)
230
+
231
+ topics_info = base_model.get_topic_info()
232
+ all_topics = base_model.topics_
233
+ topic_plot = (
234
+ base_model.visualize_document_datamap(
235
+ docs=all_docs,
236
+ topics=all_topics,
237
+ reduced_embeddings=reduced_embeddings_array,
238
+ title="",
239
+ sub_title=sub_title,
240
+ width=800,
241
+ height=700,
242
+ arrowprops={
243
+ "arrowstyle": "wedge,tail_width=0.5",
244
+ "connectionstyle": "arc3,rad=0.05",
245
+ "linewidth": 0,
246
+ "fc": "#33333377",
247
+ },
248
+ dynamic_label_size=True,
249
+ # label_wrap_width=12,
250
+ label_over_points=True,
251
+ max_font_size=36,
252
+ min_font_size=4,
253
+ )
254
+ if plot_type == "DataMapPlot"
255
+ else base_model.visualize_documents(
256
+ docs=all_docs,
257
+ reduced_embeddings=reduced_embeddings_array,
258
+ custom_labels=True,
259
+ title="",
260
+ )
261
  )
262
+ rows_processed += len(docs)
263
+ progress = min(rows_processed / limit, 1.0)
264
+ logging.info(f"Progress: {progress} % - {rows_processed} of {limit}")
265
+ message = (
266
+ f"Processing topics for full dataset: {rows_processed} of {limit}"
267
+ if full_processing
268
+ else f"Processing topics for partial dataset: {rows_processed} of {limit} rows"
269
  )
 
 
 
270
 
271
+ yield (
272
+ gr.Accordion(open=False),
273
+ topics_info,
274
+ topic_plot,
275
+ gr.Label({"⏳ " + message: progress}, visible=True),
276
+ "",
277
+ )
278
 
279
+ offset += CHUNK_SIZE
280
+ del docs, embeddings, new_model, reduced_embeddings
281
+ logging.info("Finished processing topic modeling data")
282
+
283
+ yield (
284
+ gr.Accordion(open=False),
285
+ topics_info,
286
+ topic_plot,
287
+ gr.Label(
288
+ {
289
+ "✅ " + message: 1.0,
290
+ f"⏳ Generating topic names with {model_id}": 0.0,
291
+ },
292
+ visible=True,
293
+ ),
294
+ "",
295
+ )
296
+
297
+ dataset_clear_name = dataset.replace("/", "-")
298
+ plot_png = f"{dataset_clear_name}-{plot_type.lower()}.png"
299
+ if plot_type == "DataMapPlot":
300
+ topic_plot.savefig(plot_png, format="png", dpi=300)
301
+ else:
302
+ topic_plot.write_image(plot_png)
303
 
 
304
  all_topics = base_model.topics_
305
+ topics_info = base_model.get_topic_info()
306
+
307
+ new_topics_by_text_generation = {}
308
+ for _, row in topics_info.iterrows():
309
+ logging.info(
310
+ f"Processing topic: {row['Topic']} - Representation: {row['Representation']}"
311
+ )
312
+ prompt = f"{REPRESENTATION_PROMPT.replace('[KEYWORDS]', ','.join(row['Representation']))}"
313
+ logging.info(prompt)
314
+ topic_description = generator(prompt)
315
+ logging.info(topic_description)
316
+ new_topics_by_text_generation[row["Topic"]] = topic_description[0][
317
+ "generated_text"
318
+ ].replace(prompt, "")
319
+ base_model.set_topic_labels(new_topics_by_text_generation)
320
+
321
+ topics_info = base_model.get_topic_info()
322
+
323
  topic_plot = (
324
  base_model.visualize_document_datamap(
325
  docs=all_docs,
326
  topics=all_topics,
327
+ custom_labels=True,
328
  reduced_embeddings=reduced_embeddings_array,
329
  title="",
330
  sub_title=sub_title,
 
350
  title="",
351
  )
352
  )
353
+ custom_labels = base_model.custom_labels_
354
+ topic_names_array = [custom_labels[doc_topic + 1] for doc_topic in all_topics]
 
 
 
 
 
 
 
355
  yield (
356
  gr.Accordion(open=False),
357
  topics_info,
358
  topic_plot,
359
+ gr.Label(
360
+ {
361
+ "✅ " + message: 1.0,
362
+ f"✅ Generating topic names with {model_id}": 1.0,
363
+ "⏳ Creating Interactive Space": 0.0,
364
+ },
365
+ visible=True,
366
+ ),
367
  "",
368
  )
369
+ interactive_plot = datamapplot.create_interactive_plot(
370
+ reduced_embeddings_array,
371
+ topic_names_array,
372
+ hover_text=all_docs,
373
+ title=dataset,
374
+ sub_title=sub_title.replace(
375
+ "dataset",
376
+ f"<a href='https://huggingface.co/datasets/{dataset}/viewer/{config}/{split}' target='_blank'>dataset</a>",
377
+ ),
378
+ enable_search=True,
379
+ # TODO: Export data to .arrow and also serve it
380
+ inline_data=True,
381
+ # offline_data_prefix=dataset_clear_name,
382
+ initial_zoom_fraction=0.8,
383
+ )
384
+ html_content = str(interactive_plot)
385
+ html_file_path = f"{dataset_clear_name}.html"
386
+ with open(html_file_path, "w", encoding="utf-8") as html_file:
387
+ html_file.write(html_content)
388
+
389
+ repo_id = f"{DATASETS_TOPICS_ORGANIZATION}/{dataset_clear_name}"
390
+
391
+ space_id = create_space_with_content(
392
+ api=api,
393
+ repo_id=repo_id,
394
+ dataset_id=dataset,
395
+ html_file_path=html_file_path,
396
+ plot_file_path=plot_png,
397
+ space_card=SPACE_REPO_CARD_CONTENT,
398
+ token=HF_TOKEN,
399
+ )
400
 
401
+ space_link = f"https://huggingface.co/spaces/{space_id}"
 
 
 
 
 
 
 
 
 
 
 
 
402
 
403
+ yield (
404
+ gr.Accordion(open=False),
405
+ topics_info,
406
+ topic_plot,
407
+ gr.Label(
408
+ {
409
+ "✅ " + message: 1.0,
410
+ f"✅ Generating topic names with {model_id}": 1.0,
411
+ "✅ Creating Interactive Space": 1.0,
412
+ },
413
+ visible=True,
414
+ ),
415
+ f"[![Go to interactive plot](https://img.shields.io/badge/%F0%9F%A4%97%20Hugging%20Face-Space-blue)]({space_link})",
416
  )
417
+ del reduce_umap_model, all_docs, reduced_embeddings_list
418
+ del (
419
+ base_model,
420
+ all_topics,
421
+ topic_info,
422
+ topic_names_array,
423
+ interactive_plot,
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
424
  )
425
+ cuda.empty_cache()
426
+ except Exception as error:
427
+ return (
428
+ gr.Accordion(open=True),
429
+ gr.DataFrame(value=[], interactive=False, visible=True),
430
+ gr.Plot(value=None, visible=True),
431
+ gr.Label({f"❌ Error: {error}": 0.0}, visible=True),
432
+ "",
433
  )
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
434
 
435
 
436
  with gr.Blocks() as demo:
 
479
  generate_button = gr.Button("Generate Topics", variant="primary")
480
 
481
  gr.Markdown("## Data map")
482
+ progress_label = gr.Label(visible=False, show_label=False)
483
  open_space_label = gr.Markdown()
484
  topics_plot = gr.Plot()
485
+ # with gr.Accordion("Topics Info", open=False):
486
+ topics_df = gr.DataFrame(interactive=False, visible=True)
487
  gr.HTML(
488
  f"<p style='text-align: center; color:orange;'>⚠ This space processes datasets in batches of <b>{CHUNK_SIZE}</b>, with a maximum of <b>{MAX_ROWS}</b> rows. If you need further assistance, please open a new issue in the Community tab.</p>"
489
  )
 
505
  data_details_accordion,
506
  topics_df,
507
  topics_plot,
508
+ progress_label,
509
  open_space_label,
510
  ],
511
  )