Spaces:
Running
on
T4
Running
on
T4
Refactor
Browse files- app.py +22 -20
- prompts.py +5 -3
app.py
CHANGED
@@ -7,7 +7,6 @@ from bertopic import BERTopic
|
|
7 |
import gradio as gr
|
8 |
from bertopic.representation import (
|
9 |
KeyBERTInspired,
|
10 |
-
MaximalMarginalRelevance,
|
11 |
TextGeneration,
|
12 |
)
|
13 |
from umap import UMAP
|
@@ -19,8 +18,7 @@ from transformers import (
|
|
19 |
AutoModelForCausalLM,
|
20 |
pipeline,
|
21 |
)
|
22 |
-
from prompts import
|
23 |
-
from umap import UMAP
|
24 |
from hdbscan import HDBSCAN
|
25 |
from sklearn.feature_extraction.text import CountVectorizer
|
26 |
|
@@ -36,7 +34,6 @@ logging.basicConfig(
|
|
36 |
session = requests.Session()
|
37 |
sentence_model = SentenceTransformer("all-MiniLM-L6-v2")
|
38 |
keybert = KeyBERTInspired()
|
39 |
-
mmr = MaximalMarginalRelevance(diversity=0.3)
|
40 |
vectorizer_model = CountVectorizer(stop_words="english")
|
41 |
|
42 |
model_id = "meta-llama/Llama-2-7b-chat-hf"
|
@@ -52,7 +49,6 @@ bnb_config = BitsAndBytesConfig(
|
|
52 |
|
53 |
tokenizer = AutoTokenizer.from_pretrained(model_id)
|
54 |
|
55 |
-
# Llama 2 Model
|
56 |
model = AutoModelForCausalLM.from_pretrained(
|
57 |
model_id,
|
58 |
trust_remote_code=True,
|
@@ -68,13 +64,11 @@ generator = pipeline(
|
|
68 |
max_new_tokens=500,
|
69 |
repetition_penalty=1.1,
|
70 |
)
|
71 |
-
prompt = system_prompt + example_prompt + main_prompt
|
72 |
|
73 |
-
llama2 = TextGeneration(generator, prompt=
|
74 |
representation_model = {
|
75 |
"KeyBERT": keybert,
|
76 |
"Llama2": llama2,
|
77 |
-
# "MMR": mmr,
|
78 |
}
|
79 |
|
80 |
umap_model = UMAP(
|
@@ -132,9 +126,9 @@ def fit_model(base_model, docs, embeddings):
|
|
132 |
verbose=True,
|
133 |
min_topic_size=15,
|
134 |
)
|
135 |
-
logging.
|
136 |
new_model.fit(docs, embeddings)
|
137 |
-
logging.
|
138 |
|
139 |
if base_model is None:
|
140 |
return new_model, new_model
|
@@ -157,35 +151,43 @@ def generate_topics(dataset, config, split, column, nested_column):
|
|
157 |
offset = 0
|
158 |
base_model = None
|
159 |
all_docs = []
|
160 |
-
|
161 |
-
|
|
|
162 |
docs = get_docs_from_parquet(parquet_urls, column, offset, chunk_size)
|
|
|
|
|
|
|
163 |
logging.info(
|
164 |
-
f"
|
165 |
)
|
|
|
166 |
embeddings = calculate_embeddings(docs)
|
167 |
-
offset = offset + chunk_size
|
168 |
-
if not docs or offset >= limit:
|
169 |
-
break
|
170 |
base_model, _ = fit_model(base_model, docs, embeddings)
|
171 |
llama2_labels = [
|
172 |
label[0][0].split("\n")[0]
|
173 |
for label in base_model.get_topics(full=True)["Llama2"].values()
|
174 |
]
|
175 |
-
logging.info(f"Topics: {llama2_labels}")
|
176 |
base_model.set_topic_labels(llama2_labels)
|
177 |
|
178 |
reduced_embeddings = reduce_umap_model.fit_transform(embeddings)
|
|
|
179 |
|
180 |
all_docs.extend(docs)
|
181 |
-
|
182 |
topics_info = base_model.get_topic_info()
|
183 |
topic_plot = base_model.visualize_documents(
|
184 |
-
all_docs,
|
|
|
|
|
185 |
)
|
186 |
-
|
|
|
|
|
187 |
yield topics_info, topic_plot
|
188 |
|
|
|
|
|
189 |
logging.info("Finished processing all data")
|
190 |
return base_model.get_topic_info(), base_model.visualize_topics()
|
191 |
|
|
|
7 |
import gradio as gr
|
8 |
from bertopic.representation import (
|
9 |
KeyBERTInspired,
|
|
|
10 |
TextGeneration,
|
11 |
)
|
12 |
from umap import UMAP
|
|
|
18 |
AutoModelForCausalLM,
|
19 |
pipeline,
|
20 |
)
|
21 |
+
from prompts import REPRESENTATION_PROMPT
|
|
|
22 |
from hdbscan import HDBSCAN
|
23 |
from sklearn.feature_extraction.text import CountVectorizer
|
24 |
|
|
|
34 |
session = requests.Session()
|
35 |
sentence_model = SentenceTransformer("all-MiniLM-L6-v2")
|
36 |
keybert = KeyBERTInspired()
|
|
|
37 |
vectorizer_model = CountVectorizer(stop_words="english")
|
38 |
|
39 |
model_id = "meta-llama/Llama-2-7b-chat-hf"
|
|
|
49 |
|
50 |
tokenizer = AutoTokenizer.from_pretrained(model_id)
|
51 |
|
|
|
52 |
model = AutoModelForCausalLM.from_pretrained(
|
53 |
model_id,
|
54 |
trust_remote_code=True,
|
|
|
64 |
max_new_tokens=500,
|
65 |
repetition_penalty=1.1,
|
66 |
)
|
|
|
67 |
|
68 |
+
llama2 = TextGeneration(generator, prompt=REPRESENTATION_PROMPT)
|
69 |
representation_model = {
|
70 |
"KeyBERT": keybert,
|
71 |
"Llama2": llama2,
|
|
|
72 |
}
|
73 |
|
74 |
umap_model = UMAP(
|
|
|
126 |
verbose=True,
|
127 |
min_topic_size=15,
|
128 |
)
|
129 |
+
logging.debug("Fitting new model")
|
130 |
new_model.fit(docs, embeddings)
|
131 |
+
logging.debug("End fitting new model")
|
132 |
|
133 |
if base_model is None:
|
134 |
return new_model, new_model
|
|
|
151 |
offset = 0
|
152 |
base_model = None
|
153 |
all_docs = []
|
154 |
+
reduced_embeddings_list = []
|
155 |
+
|
156 |
+
while offset < limit:
|
157 |
docs = get_docs_from_parquet(parquet_urls, column, offset, chunk_size)
|
158 |
+
if not docs:
|
159 |
+
break
|
160 |
+
|
161 |
logging.info(
|
162 |
+
f"----> Processing chunk: {offset=} {chunk_size=} with {len(docs)} docs"
|
163 |
)
|
164 |
+
|
165 |
embeddings = calculate_embeddings(docs)
|
|
|
|
|
|
|
166 |
base_model, _ = fit_model(base_model, docs, embeddings)
|
167 |
llama2_labels = [
|
168 |
label[0][0].split("\n")[0]
|
169 |
for label in base_model.get_topics(full=True)["Llama2"].values()
|
170 |
]
|
|
|
171 |
base_model.set_topic_labels(llama2_labels)
|
172 |
|
173 |
reduced_embeddings = reduce_umap_model.fit_transform(embeddings)
|
174 |
+
reduced_embeddings_list.append(reduced_embeddings)
|
175 |
|
176 |
all_docs.extend(docs)
|
177 |
+
|
178 |
topics_info = base_model.get_topic_info()
|
179 |
topic_plot = base_model.visualize_documents(
|
180 |
+
all_docs,
|
181 |
+
reduced_embeddings=np.vstack(reduced_embeddings_list),
|
182 |
+
custom_labels=True,
|
183 |
)
|
184 |
+
|
185 |
+
logging.info(f"Topics: {llama2_labels}")
|
186 |
+
|
187 |
yield topics_info, topic_plot
|
188 |
|
189 |
+
offset += chunk_size
|
190 |
+
|
191 |
logging.info("Finished processing all data")
|
192 |
return base_model.get_topic_info(), base_model.visualize_topics()
|
193 |
|
prompts.py
CHANGED
@@ -1,10 +1,10 @@
|
|
1 |
-
|
2 |
<s>[INST] <<SYS>>
|
3 |
You are a helpful, respectful and honest assistant for labeling topics.
|
4 |
<</SYS>>
|
5 |
"""
|
6 |
|
7 |
-
|
8 |
I have a topic that contains the following documents:
|
9 |
- Traditional diets in most cultures were primarily plant-based with a little meat on top, but with the rise of industrial style meat production and factory farming, meat has become a staple food.
|
10 |
- Meat, but especially beef, is the word food in terms of emissions.
|
@@ -17,7 +17,7 @@ Based on the information about the topic above, please create a short label of t
|
|
17 |
[/INST] Environmental impacts of eating meat
|
18 |
"""
|
19 |
|
20 |
-
|
21 |
[INST]
|
22 |
I have a topic that contains the following documents:
|
23 |
[DOCUMENTS]
|
@@ -27,3 +27,5 @@ The topic is described by the following keywords: '[KEYWORDS]'.
|
|
27 |
Based on the information about the topic above, please create a short label of this topic. Make sure you to only return the label and nothing more.
|
28 |
[/INST]
|
29 |
"""
|
|
|
|
|
|
1 |
+
SYSTEM_PROMPT = """
|
2 |
<s>[INST] <<SYS>>
|
3 |
You are a helpful, respectful and honest assistant for labeling topics.
|
4 |
<</SYS>>
|
5 |
"""
|
6 |
|
7 |
+
EXAMPLE_PROMPT = """
|
8 |
I have a topic that contains the following documents:
|
9 |
- Traditional diets in most cultures were primarily plant-based with a little meat on top, but with the rise of industrial style meat production and factory farming, meat has become a staple food.
|
10 |
- Meat, but especially beef, is the word food in terms of emissions.
|
|
|
17 |
[/INST] Environmental impacts of eating meat
|
18 |
"""
|
19 |
|
20 |
+
MAIN_PROMPT = """
|
21 |
[INST]
|
22 |
I have a topic that contains the following documents:
|
23 |
[DOCUMENTS]
|
|
|
27 |
Based on the information about the topic above, please create a short label of this topic. Make sure you to only return the label and nothing more.
|
28 |
[/INST]
|
29 |
"""
|
30 |
+
|
31 |
+
REPRESENTATION_PROMPT = SYSTEM_PROMPT + EXAMPLE_PROMPT + MAIN_PROMPT
|