asoria HF staff commited on
Commit
abbebb7
1 Parent(s): fd054e7

Adding progress bar

Browse files
Files changed (1) hide show
  1. app.py +21 -6
app.py CHANGED
@@ -26,7 +26,7 @@ from sentence_transformers import SentenceTransformer
26
  from dotenv import load_dotenv
27
  import os
28
 
29
- import spaces
30
  import gradio as gr
31
 
32
 
@@ -132,13 +132,13 @@ def get_docs_from_parquet(parquet_urls, column, offset, limit):
132
  return df[column].tolist()
133
 
134
 
135
- @spaces.GPU
136
  # TODO: Modify batch size to reduce memory consumption during embedding calculation, which value is better?
137
  def calculate_embeddings(docs):
138
  return sentence_model.encode(docs, show_progress_bar=True, batch_size=32)
139
 
140
 
141
- @spaces.GPU
142
  def fit_model(docs, embeddings):
143
  global global_topic_model
144
 
@@ -177,6 +177,11 @@ def generate_topics(dataset, config, split, column, nested_column):
177
  all_docs = []
178
  reduced_embeddings_list = []
179
  topics_info, topic_plot = None, None
 
 
 
 
 
180
  while offset < limit:
181
  docs = get_docs_from_parquet(parquet_urls, column, offset, chunk_size)
182
  if not docs:
@@ -220,14 +225,23 @@ def generate_topics(dataset, config, split, column, nested_column):
220
  )
221
 
222
  logging.info(f"Topics: {repr_model_topics}")
 
223
 
224
- yield topics_info, topic_plot
 
 
 
 
225
 
226
  offset += chunk_size
227
 
228
  logging.info("Finished processing all data")
229
  cuda.empty_cache() # Clear cache at the end of each chunk
230
- return topics_info, topic_plot
 
 
 
 
231
 
232
 
233
  with gr.Blocks() as demo:
@@ -267,6 +281,7 @@ with gr.Blocks() as demo:
267
  generate_button = gr.Button("Generate Topics", variant="primary")
268
 
269
  gr.Markdown("## Datamap")
 
270
  topics_plot = gr.Plot()
271
  with gr.Accordion("Topics Info", open=False):
272
  topics_df = gr.DataFrame(interactive=False, visible=True)
@@ -279,7 +294,7 @@ with gr.Blocks() as demo:
279
  text_column_dropdown,
280
  nested_text_column_dropdown,
281
  ],
282
- outputs=[topics_df, topics_plot],
283
  )
284
 
285
  def _resolve_dataset_selection(
 
26
  from dotenv import load_dotenv
27
  import os
28
 
29
+ # import spaces
30
  import gradio as gr
31
 
32
 
 
132
  return df[column].tolist()
133
 
134
 
135
+ # @spaces.GPU
136
  # TODO: Modify batch size to reduce memory consumption during embedding calculation, which value is better?
137
  def calculate_embeddings(docs):
138
  return sentence_model.encode(docs, show_progress_bar=True, batch_size=32)
139
 
140
 
141
+ # @spaces.GPU
142
  def fit_model(docs, embeddings):
143
  global global_topic_model
144
 
 
177
  all_docs = []
178
  reduced_embeddings_list = []
179
  topics_info, topic_plot = None, None
180
+ yield (
181
+ gr.DataFrame(interactive=False, visible=True),
182
+ gr.Plot(visible=True),
183
+ gr.Label({f"⚙️ Generating topics {dataset}": 0.0}, visible=True),
184
+ )
185
  while offset < limit:
186
  docs = get_docs_from_parquet(parquet_urls, column, offset, chunk_size)
187
  if not docs:
 
225
  )
226
 
227
  logging.info(f"Topics: {repr_model_topics}")
228
+ progress = min(offset / limit, 1.0)
229
 
230
+ yield (
231
+ topics_info,
232
+ topic_plot,
233
+ gr.Label({f"⚙️ Generating topics {dataset}": progress}, visible=True),
234
+ )
235
 
236
  offset += chunk_size
237
 
238
  logging.info("Finished processing all data")
239
  cuda.empty_cache() # Clear cache at the end of each chunk
240
+ return (
241
+ topics_info,
242
+ topic_plot,
243
+ gr.Label({f"⚙️ Generating topics {dataset}": 1.0}, visible=True),
244
+ )
245
 
246
 
247
  with gr.Blocks() as demo:
 
281
  generate_button = gr.Button("Generate Topics", variant="primary")
282
 
283
  gr.Markdown("## Datamap")
284
+ full_topics_generation_label = gr.Label(visible=False, show_label=False)
285
  topics_plot = gr.Plot()
286
  with gr.Accordion("Topics Info", open=False):
287
  topics_df = gr.DataFrame(interactive=False, visible=True)
 
294
  text_column_dropdown,
295
  nested_text_column_dropdown,
296
  ],
297
+ outputs=[topics_df, topics_plot, full_topics_generation_label],
298
  )
299
 
300
  def _resolve_dataset_selection(