Spaces:
Sleeping
Sleeping
Apply text generation layer at the end only
Browse files
app.py
CHANGED
@@ -44,7 +44,6 @@ DATASETS_TOPICS_ORGANIZATION = os.getenv(
|
|
44 |
"DATASETS_TOPICS_ORGANIZATION", "datasets-topics"
|
45 |
)
|
46 |
USE_CUML = int(os.getenv("USE_CUML", "1"))
|
47 |
-
USE_LLM_TEXT_GENERATION = int(os.getenv("USE_LLM_TEXT_GENERATION", "1"))
|
48 |
|
49 |
# Use cuml lib only if configured
|
50 |
if USE_CUML:
|
@@ -60,43 +59,39 @@ logging.basicConfig(
|
|
60 |
)
|
61 |
|
62 |
api = HfApi(token=HF_TOKEN)
|
63 |
-
sentence_model = SentenceTransformer("all-MiniLM-L6-v2")
|
64 |
-
|
65 |
-
# Representation model
|
66 |
-
if USE_LLM_TEXT_GENERATION:
|
67 |
-
bnb_config = BitsAndBytesConfig(
|
68 |
-
load_in_4bit=True,
|
69 |
-
bnb_4bit_quant_type="nf4",
|
70 |
-
bnb_4bit_use_double_quant=True,
|
71 |
-
bnb_4bit_compute_dtype=bfloat16,
|
72 |
-
)
|
73 |
|
74 |
-
|
75 |
-
|
76 |
-
|
77 |
-
|
78 |
-
|
79 |
-
|
80 |
-
|
81 |
-
|
82 |
-
|
83 |
-
|
84 |
-
|
85 |
-
|
86 |
-
|
87 |
-
|
88 |
-
|
89 |
-
|
90 |
-
|
91 |
-
|
92 |
-
|
93 |
-
|
|
|
|
|
|
|
|
|
94 |
|
|
|
95 |
vectorizer_model = CountVectorizer(stop_words="english")
|
|
|
96 |
|
97 |
|
98 |
def calculate_embeddings(docs):
|
99 |
-
return
|
100 |
|
101 |
|
102 |
def calculate_n_neighbors_and_components(n_rows):
|
@@ -126,7 +121,7 @@ def fit_model(docs, embeddings, n_neighbors, n_components):
|
|
126 |
new_model = BERTopic(
|
127 |
language="english",
|
128 |
# Sub-models
|
129 |
-
embedding_model=
|
130 |
umap_model=umap_model, # Step 2 - UMAP model
|
131 |
hdbscan_model=hdbscan_model, # Step 3 - Cluster reduced embeddings
|
132 |
vectorizer_model=vectorizer_model, # Step 4 - Tokenize topics
|
@@ -294,13 +289,55 @@ def generate_topics(dataset, config, split, column, plot_type):
|
|
294 |
all_topics = base_model.topics_
|
295 |
topic_info = base_model.get_topic_info()
|
296 |
|
297 |
-
|
298 |
-
|
299 |
-
|
300 |
-
|
301 |
-
|
302 |
-
]
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
303 |
)
|
|
|
|
|
|
|
304 |
interactive_plot = datamapplot.create_interactive_plot(
|
305 |
reduced_embeddings_array,
|
306 |
topic_names_array,
|
@@ -348,7 +385,6 @@ def generate_topics(dataset, config, split, column, plot_type):
|
|
348 |
base_model,
|
349 |
all_topics,
|
350 |
topic_info,
|
351 |
-
topic_names,
|
352 |
topic_names_array,
|
353 |
interactive_plot,
|
354 |
)
|
|
|
44 |
"DATASETS_TOPICS_ORGANIZATION", "datasets-topics"
|
45 |
)
|
46 |
USE_CUML = int(os.getenv("USE_CUML", "1"))
|
|
|
47 |
|
48 |
# Use cuml lib only if configured
|
49 |
if USE_CUML:
|
|
|
59 |
)
|
60 |
|
61 |
api = HfApi(token=HF_TOKEN)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
62 |
|
63 |
+
bnb_config = BitsAndBytesConfig(
|
64 |
+
load_in_4bit=True,
|
65 |
+
bnb_4bit_quant_type="nf4",
|
66 |
+
bnb_4bit_use_double_quant=True,
|
67 |
+
bnb_4bit_compute_dtype=bfloat16,
|
68 |
+
)
|
69 |
+
|
70 |
+
model_id = "meta-llama/Llama-2-7b-chat-hf"
|
71 |
+
tokenizer = AutoTokenizer.from_pretrained(model_id)
|
72 |
+
model = AutoModelForCausalLM.from_pretrained(
|
73 |
+
model_id,
|
74 |
+
trust_remote_code=True,
|
75 |
+
quantization_config=bnb_config,
|
76 |
+
device_map="auto",
|
77 |
+
)
|
78 |
+
model.eval()
|
79 |
+
generator = pipeline(
|
80 |
+
model=model,
|
81 |
+
tokenizer=tokenizer,
|
82 |
+
task="text-generation",
|
83 |
+
temperature=0.1,
|
84 |
+
max_new_tokens=500,
|
85 |
+
repetition_penalty=1.1,
|
86 |
+
)
|
87 |
|
88 |
+
embedding_model = SentenceTransformer("all-MiniLM-L6-v2")
|
89 |
vectorizer_model = CountVectorizer(stop_words="english")
|
90 |
+
representation_model = KeyBERTInspired()
|
91 |
|
92 |
|
93 |
def calculate_embeddings(docs):
|
94 |
+
return embedding_model.encode(docs, show_progress_bar=True, batch_size=32)
|
95 |
|
96 |
|
97 |
def calculate_n_neighbors_and_components(n_rows):
|
|
|
121 |
new_model = BERTopic(
|
122 |
language="english",
|
123 |
# Sub-models
|
124 |
+
embedding_model=embedding_model, # Step 1 - Extract embeddings
|
125 |
umap_model=umap_model, # Step 2 - UMAP model
|
126 |
hdbscan_model=hdbscan_model, # Step 3 - Cluster reduced embeddings
|
127 |
vectorizer_model=vectorizer_model, # Step 4 - Tokenize topics
|
|
|
289 |
all_topics = base_model.topics_
|
290 |
topic_info = base_model.get_topic_info()
|
291 |
|
292 |
+
new_topics_by_text_generation = {}
|
293 |
+
for _, row in topic_info.iterrows():
|
294 |
+
logging.info(
|
295 |
+
f"Processing topic: {row['Topic']} - Representation: {row['Representation']}"
|
296 |
+
)
|
297 |
+
prompt = f"{REPRESENTATION_PROMPT.replace('[KEYWORDS]', ','.join(row['Representation']))}"
|
298 |
+
logging.info(prompt)
|
299 |
+
topic_description = generator(prompt)
|
300 |
+
logging.info(topic_description)
|
301 |
+
new_topics_by_text_generation[row["Topic"]] = topic_description[0][
|
302 |
+
"generated_text"
|
303 |
+
].replace(prompt, "")
|
304 |
+
base_model.set_topic_labels(new_topics_by_text_generation)
|
305 |
+
|
306 |
+
topics_info = base_model.get_topic_info()
|
307 |
+
|
308 |
+
topic_plot = (
|
309 |
+
base_model.visualize_document_datamap(
|
310 |
+
docs=all_docs,
|
311 |
+
topics=all_topics,
|
312 |
+
custom_labels=True,
|
313 |
+
reduced_embeddings=reduced_embeddings_array,
|
314 |
+
title="",
|
315 |
+
sub_title=sub_title,
|
316 |
+
width=800,
|
317 |
+
height=700,
|
318 |
+
arrowprops={
|
319 |
+
"arrowstyle": "wedge,tail_width=0.5",
|
320 |
+
"connectionstyle": "arc3,rad=0.05",
|
321 |
+
"linewidth": 0,
|
322 |
+
"fc": "#33333377",
|
323 |
+
},
|
324 |
+
dynamic_label_size=True,
|
325 |
+
# label_wrap_width=12,
|
326 |
+
label_over_points=True,
|
327 |
+
max_font_size=36,
|
328 |
+
min_font_size=4,
|
329 |
+
)
|
330 |
+
if plot_type == "DataMapPlot"
|
331 |
+
else base_model.visualize_documents(
|
332 |
+
docs=all_docs,
|
333 |
+
reduced_embeddings=reduced_embeddings_array,
|
334 |
+
custom_labels=True,
|
335 |
+
title="",
|
336 |
+
)
|
337 |
)
|
338 |
+
custom_labels = base_model.custom_labels_
|
339 |
+
topic_names_array = [custom_labels[doc_topic + 1] for doc_topic in all_topics]
|
340 |
+
|
341 |
interactive_plot = datamapplot.create_interactive_plot(
|
342 |
reduced_embeddings_array,
|
343 |
topic_names_array,
|
|
|
385 |
base_model,
|
386 |
all_topics,
|
387 |
topic_info,
|
|
|
388 |
topic_names_array,
|
389 |
interactive_plot,
|
390 |
)
|