Spaces: Running on T4
Changing sentence transformer
Browse files
- app.py +54 -29
- requirements.txt +3 -1
app.py
CHANGED
@@ -6,7 +6,11 @@ from bertopic import BERTopic
|
|
6 |
import pandas as pd
|
7 |
import gradio as gr
|
8 |
from bertopic.representation import KeyBERTInspired
|
9 |
-
import
|
|
|
|
|
|
|
|
|
10 |
|
11 |
logging.basicConfig(
|
12 |
level=logging.INFO, format="%(asctime)s - %(name)s - %(levelname)s - %(message)s"
|
@@ -24,7 +28,7 @@ def get_parquet_urls(dataset, config, split):
|
|
24 |
if "error" in parquet_files:
|
25 |
raise Exception(f"Error fetching parquet files: {parquet_files['error']}")
|
26 |
parquet_urls = [file["url"] for file in parquet_files["parquet_files"]]
|
27 |
-
logging.
|
28 |
return ",".join(f"'{url}'" for url in parquet_urls)
|
29 |
|
30 |
|
@@ -34,7 +38,7 @@ def get_docs_from_parquet(parquet_urls, column, offset, limit):
|
|
34 |
logging.debug(f"Dataframe: {df.head(5)}")
|
35 |
return df[column].tolist()
|
36 |
|
37 |
-
|
38 |
def generate_topics(dataset, config, split, column, nested_column):
|
39 |
logging.info(
|
40 |
f"Generating topics for {dataset} with config {config} {split} {column} {nested_column}"
|
@@ -45,39 +49,60 @@ def generate_topics(dataset, config, split, column, nested_column):
|
|
45 |
chunk_size = 300
|
46 |
offset = 0
|
47 |
representation_model = KeyBERTInspired()
|
48 |
-
|
49 |
-
docs = get_docs_from_parquet(parquet_urls, column, offset, chunk_size)
|
50 |
-
|
51 |
-
base_model = BERTopic(
|
52 |
-
|
53 |
-
)
|
54 |
-
|
55 |
-
|
56 |
-
|
|
|
|
|
|
|
|
|
57 |
while True:
|
|
|
|
|
|
|
|
|
58 |
offset = offset + chunk_size
|
59 |
if not docs or offset >= limit:
|
60 |
break
|
61 |
|
62 |
-
docs = get_docs_from_parquet(parquet_urls, column, offset, chunk_size)
|
63 |
-
logging.info(f"------------> New chunk data {offset=} {chunk_size=}")
|
64 |
-
logging.info(docs[:5])
|
65 |
-
|
66 |
new_model = BERTopic(
|
67 |
-
"english",
|
68 |
-
|
69 |
-
|
70 |
-
|
71 |
-
|
72 |
-
logging.info("
|
73 |
-
|
74 |
-
|
75 |
-
|
76 |
-
|
77 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
78 |
logging.info(base_model.get_topic_info())
|
79 |
-
|
80 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
81 |
return base_model.get_topic_info(), base_model.visualize_topics()
|
82 |
|
83 |
|
|
|
6 |
import pandas as pd
|
7 |
import gradio as gr
|
8 |
from bertopic.representation import KeyBERTInspired
|
9 |
+
from umap import UMAP
|
10 |
+
|
11 |
+
# from cuml.cluster import HDBSCAN
|
12 |
+
# from cuml.manifold import UMAP
|
13 |
+
from sentence_transformers import SentenceTransformer
|
14 |
|
15 |
logging.basicConfig(
|
16 |
level=logging.INFO, format="%(asctime)s - %(name)s - %(levelname)s - %(message)s"
|
|
|
28 |
if "error" in parquet_files:
|
29 |
raise Exception(f"Error fetching parquet files: {parquet_files['error']}")
|
30 |
parquet_urls = [file["url"] for file in parquet_files["parquet_files"]]
|
31 |
+
logging.debug(f"Parquet files: {parquet_urls}")
|
32 |
return ",".join(f"'{url}'" for url in parquet_urls)
|
33 |
|
34 |
|
|
|
38 |
logging.debug(f"Dataframe: {df.head(5)}")
|
39 |
return df[column].tolist()
|
40 |
|
41 |
+
|
42 |
def generate_topics(dataset, config, split, column, nested_column):
|
43 |
logging.info(
|
44 |
f"Generating topics for {dataset} with config {config} {split} {column} {nested_column}"
|
|
|
49 |
chunk_size = 300
|
50 |
offset = 0
|
51 |
representation_model = KeyBERTInspired()
|
52 |
+
base_model = None
|
53 |
+
# docs = get_docs_from_parquet(parquet_urls, column, offset, chunk_size)
|
54 |
+
|
55 |
+
# base_model = BERTopic(
|
56 |
+
# "english", representation_model=representation_model, min_topic_size=15
|
57 |
+
# )
|
58 |
+
# base_model.fit_transform(docs)
|
59 |
+
|
60 |
+
# yield base_model.get_topic_info(), base_model.visualize_topics()
|
61 |
+
# Create instances of GPU-accelerated UMAP and HDBSCAN
|
62 |
+
# umap_model = UMAP(n_components=5, n_neighbors=15, min_dist=0.0)
|
63 |
+
# hdbscan_model = HDBSCAN(min_samples=10, gen_min_span_tree=True)
|
64 |
+
sentence_model = SentenceTransformer("all-MiniLM-L6-v2", device="cuda")
|
65 |
while True:
|
66 |
+
docs = get_docs_from_parquet(parquet_urls, column, offset, chunk_size)
|
67 |
+
logging.info(f"------------> New chunk data {offset=} {chunk_size=}")
|
68 |
+
embeddings = sentence_model.encode(docs, show_progress_bar=True, batch_size=100)
|
69 |
+
logging.info(f"Embeddings shape: {embeddings.shape}")
|
70 |
offset = offset + chunk_size
|
71 |
if not docs or offset >= limit:
|
72 |
break
|
73 |
|
|
|
|
|
|
|
|
|
74 |
new_model = BERTopic(
|
75 |
+
"english",
|
76 |
+
embedding_model=sentence_model,
|
77 |
+
representation_model=representation_model,
|
78 |
+
min_topic_size=15, # umap_model=umap_model, hdbscan_model=hdbscan_model
|
79 |
+
)
|
80 |
+
logging.info("Fitting new model")
|
81 |
+
new_model.fit(docs, embeddings)
|
82 |
+
logging.info("End fitting new model")
|
83 |
+
if base_model is not None:
|
84 |
+
updated_model = BERTopic.merge_models([base_model, new_model])
|
85 |
+
nr_new_topics = len(set(updated_model.topics_)) - len(
|
86 |
+
set(base_model.topics_)
|
87 |
+
)
|
88 |
+
new_topics = list(updated_model.topic_labels_.values())[-nr_new_topics:]
|
89 |
+
logging.info("The following topics are newly found:")
|
90 |
+
logging.info(f"{new_topics}\n")
|
91 |
+
base_model = updated_model
|
92 |
+
else:
|
93 |
+
base_model = new_model
|
94 |
logging.info(base_model.get_topic_info())
|
95 |
+
reduced_embeddings = UMAP(
|
96 |
+
n_neighbors=10, n_components=2, min_dist=0.0, metric="cosine"
|
97 |
+
).fit_transform(embeddings)
|
98 |
+
logging.info(f"Reduced embeddings shape: {reduced_embeddings.shape}")
|
99 |
+
yield (
|
100 |
+
base_model.get_topic_info(),
|
101 |
+
new_model.visualize_documents(
|
102 |
+
docs, embeddings=embeddings
|
103 |
+
), # TODO: Visualize the merged models
|
104 |
+
)
|
105 |
+
logging.info("Finished processing all data")
|
106 |
return base_model.get_topic_info(), base_model.visualize_topics()
|
107 |
|
108 |
|
requirements.txt
CHANGED
@@ -4,4 +4,6 @@ umap-learn
|
|
4 |
sentence-transformers
|
5 |
datamapplot
|
6 |
bertopic
|
7 |
-
pandas
|
|
|
|
|
|
4 |
sentence-transformers
|
5 |
datamapplot
|
6 |
bertopic
|
7 |
+
pandas
|
8 |
+
torch
|
9 |
+
cuml-cu11
|