asoria HF staff commited on
Commit
b5ecaeb
1 Parent(s): 8712d35

Apply text generation layer at the end only

Browse files
Files changed (1) hide show
  1. app.py +76 -40
app.py CHANGED
@@ -44,7 +44,6 @@ DATASETS_TOPICS_ORGANIZATION = os.getenv(
44
  "DATASETS_TOPICS_ORGANIZATION", "datasets-topics"
45
  )
46
  USE_CUML = int(os.getenv("USE_CUML", "1"))
47
- USE_LLM_TEXT_GENERATION = int(os.getenv("USE_LLM_TEXT_GENERATION", "1"))
48
 
49
  # Use cuml lib only if configured
50
  if USE_CUML:
@@ -60,43 +59,39 @@ logging.basicConfig(
60
  )
61
 
62
  api = HfApi(token=HF_TOKEN)
63
- sentence_model = SentenceTransformer("all-MiniLM-L6-v2")
64
-
65
- # Representation model
66
- if USE_LLM_TEXT_GENERATION:
67
- bnb_config = BitsAndBytesConfig(
68
- load_in_4bit=True,
69
- bnb_4bit_quant_type="nf4",
70
- bnb_4bit_use_double_quant=True,
71
- bnb_4bit_compute_dtype=bfloat16,
72
- )
73
 
74
- model_id = "meta-llama/Llama-2-7b-chat-hf"
75
- tokenizer = AutoTokenizer.from_pretrained(model_id)
76
- model = AutoModelForCausalLM.from_pretrained(
77
- model_id,
78
- trust_remote_code=True,
79
- quantization_config=bnb_config,
80
- device_map="auto",
81
- )
82
- model.eval()
83
- generator = pipeline(
84
- model=model,
85
- tokenizer=tokenizer,
86
- task="text-generation",
87
- temperature=0.1,
88
- max_new_tokens=500,
89
- repetition_penalty=1.1,
90
- )
91
- representation_model = TextGeneration(generator, prompt=REPRESENTATION_PROMPT)
92
- else:
93
- representation_model = KeyBERTInspired()
 
 
 
 
94
 
 
95
  vectorizer_model = CountVectorizer(stop_words="english")
 
96
 
97
 
98
  def calculate_embeddings(docs):
99
- return sentence_model.encode(docs, show_progress_bar=True, batch_size=32)
100
 
101
 
102
  def calculate_n_neighbors_and_components(n_rows):
@@ -126,7 +121,7 @@ def fit_model(docs, embeddings, n_neighbors, n_components):
126
  new_model = BERTopic(
127
  language="english",
128
  # Sub-models
129
- embedding_model=sentence_model, # Step 1 - Extract embeddings
130
  umap_model=umap_model, # Step 2 - UMAP model
131
  hdbscan_model=hdbscan_model, # Step 3 - Cluster reduced embeddings
132
  vectorizer_model=vectorizer_model, # Step 4 - Tokenize topics
@@ -294,13 +289,55 @@ def generate_topics(dataset, config, split, column, plot_type):
294
  all_topics = base_model.topics_
295
  topic_info = base_model.get_topic_info()
296
 
297
- topic_names = {row["Topic"]: row["Name"] for _, row in topic_info.iterrows()}
298
- topic_names_array = np.array(
299
- [
300
- topic_names.get(topic, "No Topic").split("_")[1].strip("-")
301
- for topic in all_topics
302
- ]
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
303
  )
 
 
 
304
  interactive_plot = datamapplot.create_interactive_plot(
305
  reduced_embeddings_array,
306
  topic_names_array,
@@ -348,7 +385,6 @@ def generate_topics(dataset, config, split, column, plot_type):
348
  base_model,
349
  all_topics,
350
  topic_info,
351
- topic_names,
352
  topic_names_array,
353
  interactive_plot,
354
  )
 
44
  "DATASETS_TOPICS_ORGANIZATION", "datasets-topics"
45
  )
46
  USE_CUML = int(os.getenv("USE_CUML", "1"))
 
47
 
48
  # Use cuml lib only if configured
49
  if USE_CUML:
 
59
  )
60
 
61
  api = HfApi(token=HF_TOKEN)
 
 
 
 
 
 
 
 
 
 
62
 
63
+ bnb_config = BitsAndBytesConfig(
64
+ load_in_4bit=True,
65
+ bnb_4bit_quant_type="nf4",
66
+ bnb_4bit_use_double_quant=True,
67
+ bnb_4bit_compute_dtype=bfloat16,
68
+ )
69
+
70
+ model_id = "meta-llama/Llama-2-7b-chat-hf"
71
+ tokenizer = AutoTokenizer.from_pretrained(model_id)
72
+ model = AutoModelForCausalLM.from_pretrained(
73
+ model_id,
74
+ trust_remote_code=True,
75
+ quantization_config=bnb_config,
76
+ device_map="auto",
77
+ )
78
+ model.eval()
79
+ generator = pipeline(
80
+ model=model,
81
+ tokenizer=tokenizer,
82
+ task="text-generation",
83
+ temperature=0.1,
84
+ max_new_tokens=500,
85
+ repetition_penalty=1.1,
86
+ )
87
 
88
+ embedding_model = SentenceTransformer("all-MiniLM-L6-v2")
89
  vectorizer_model = CountVectorizer(stop_words="english")
90
+ representation_model = KeyBERTInspired()
91
 
92
 
93
  def calculate_embeddings(docs):
94
+ return embedding_model.encode(docs, show_progress_bar=True, batch_size=32)
95
 
96
 
97
  def calculate_n_neighbors_and_components(n_rows):
 
121
  new_model = BERTopic(
122
  language="english",
123
  # Sub-models
124
+ embedding_model=embedding_model, # Step 1 - Extract embeddings
125
  umap_model=umap_model, # Step 2 - UMAP model
126
  hdbscan_model=hdbscan_model, # Step 3 - Cluster reduced embeddings
127
  vectorizer_model=vectorizer_model, # Step 4 - Tokenize topics
 
289
  all_topics = base_model.topics_
290
  topic_info = base_model.get_topic_info()
291
 
292
+ new_topics_by_text_generation = {}
293
+ for _, row in topic_info.iterrows():
294
+ logging.info(
295
+ f"Processing topic: {row['Topic']} - Representation: {row['Representation']}"
296
+ )
297
+ prompt = f"{REPRESENTATION_PROMPT.replace('[KEYWORDS]', ','.join(row['Representation']))}"
298
+ logging.info(prompt)
299
+ topic_description = generator(prompt)
300
+ logging.info(topic_description)
301
+ new_topics_by_text_generation[row["Topic"]] = topic_description[0][
302
+ "generated_text"
303
+ ].replace(prompt, "")
304
+ base_model.set_topic_labels(new_topics_by_text_generation)
305
+
306
+ topics_info = base_model.get_topic_info()
307
+
308
+ topic_plot = (
309
+ base_model.visualize_document_datamap(
310
+ docs=all_docs,
311
+ topics=all_topics,
312
+ custom_labels=True,
313
+ reduced_embeddings=reduced_embeddings_array,
314
+ title="",
315
+ sub_title=sub_title,
316
+ width=800,
317
+ height=700,
318
+ arrowprops={
319
+ "arrowstyle": "wedge,tail_width=0.5",
320
+ "connectionstyle": "arc3,rad=0.05",
321
+ "linewidth": 0,
322
+ "fc": "#33333377",
323
+ },
324
+ dynamic_label_size=True,
325
+ # label_wrap_width=12,
326
+ label_over_points=True,
327
+ max_font_size=36,
328
+ min_font_size=4,
329
+ )
330
+ if plot_type == "DataMapPlot"
331
+ else base_model.visualize_documents(
332
+ docs=all_docs,
333
+ reduced_embeddings=reduced_embeddings_array,
334
+ custom_labels=True,
335
+ title="",
336
+ )
337
  )
338
+ custom_labels = base_model.custom_labels_
339
+ topic_names_array = [custom_labels[doc_topic + 1] for doc_topic in all_topics]
340
+
341
  interactive_plot = datamapplot.create_interactive_plot(
342
  reduced_embeddings_array,
343
  topic_names_array,
 
385
  base_model,
386
  all_topics,
387
  topic_info,
 
388
  topic_names_array,
389
  interactive_plot,
390
  )