asoria's picture
asoria HF staff
Apply text generation layer at the end only
b5ecaeb
raw
history blame
19.8 kB
# import spaces
import gradio as gr
import logging
import os
import datamapplot
import numpy as np
from dotenv import load_dotenv
from gradio_huggingfacehub_search import HuggingfaceHubSearch
from bertopic import BERTopic
from bertopic.representation import KeyBERTInspired
from bertopic.representation import TextGeneration
from huggingface_hub import HfApi
from sklearn.feature_extraction.text import CountVectorizer
from sentence_transformers import SentenceTransformer
from torch import cuda, bfloat16
from transformers import (
BitsAndBytesConfig,
AutoTokenizer,
AutoModelForCausalLM,
pipeline,
)
from src.hub import create_space_with_content
from src.templates import REPRESENTATION_PROMPT, SPACE_REPO_CARD_CONTENT
from src.viewer_api import (
get_split_rows,
get_parquet_urls,
get_docs_from_parquet,
get_info,
)
# Load environment variables (python-dotenv also reads a local .env file).
load_dotenv()

HF_TOKEN = os.getenv("HF_TOKEN")
# Validate explicitly instead of with `assert`: assertions are stripped when
# Python runs with -O, which would let the app start without a token.
if not HF_TOKEN:
    raise RuntimeError("You need to set HF_TOKEN in your environment variables")

# Maximum number of rows processed per request, and rows fitted per chunk.
# int() accepts underscore digit separators such as "8_000".
MAX_ROWS = int(os.getenv("MAX_ROWS", "8_000"))
CHUNK_SIZE = int(os.getenv("CHUNK_SIZE", "2_000"))
# Hub organization that the result Spaces are created under.
DATASETS_TOPICS_ORGANIZATION = os.getenv(
    "DATASETS_TOPICS_ORGANIZATION", "datasets-topics"
)
USE_CUML = int(os.getenv("USE_CUML", "1"))

# Use cuml lib only if configured; otherwise fall back to the CPU packages,
# which are imported under the same class names.
if USE_CUML:
    from cuml.manifold import UMAP
    from cuml.cluster import HDBSCAN
else:
    from umap import UMAP
    from hdbscan import HDBSCAN

logging.basicConfig(
    level=logging.INFO, format="%(asctime)s - %(name)s - %(levelname)s - %(message)s"
)

# Authenticated Hub client used when publishing the result Space.
api = HfApi(token=HF_TOKEN)
# ---------------------------------------------------------------------------
# LLM used at the end of generate_topics to turn keywords into topic labels.
# ---------------------------------------------------------------------------
# Weights are loaded in 4-bit NF4 with double quantization; compute in bfloat16.
bnb_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_quant_type="nf4",
    bnb_4bit_use_double_quant=True,
    bnb_4bit_compute_dtype=bfloat16,
)
model_id = "meta-llama/Llama-2-7b-chat-hf"

tokenizer = AutoTokenizer.from_pretrained(model_id)
_model_load_options = {
    # NOTE(review): Llama-2 is supported natively by transformers, so this flag
    # is probably unnecessary (it allows running code from the model repo).
    "trust_remote_code": True,
    "quantization_config": bnb_config,
    "device_map": "auto",
}
model = AutoModelForCausalLM.from_pretrained(model_id, **_model_load_options)
model.eval()  # inference only

_generation_options = {
    "temperature": 0.1,
    "max_new_tokens": 500,
    "repetition_penalty": 1.1,
}
generator = pipeline(
    model=model,
    tokenizer=tokenizer,
    task="text-generation",
    **_generation_options,
)

# ---------------------------------------------------------------------------
# Sub-models shared by every BERTopic model built in fit_model.
# ---------------------------------------------------------------------------
embedding_model = SentenceTransformer("all-MiniLM-L6-v2")
vectorizer_model = CountVectorizer(stop_words="english")
representation_model = KeyBERTInspired()
def calculate_embeddings(docs):
    """Encode *docs* with the shared sentence-transformer model.

    Encoding runs in batches of 32 with a progress bar; the result is whatever
    ``SentenceTransformer.encode`` returns for the given input.
    """
    encode_options = {"show_progress_bar": True, "batch_size": 32}
    return embedding_model.encode(docs, **encode_options)
def calculate_n_neighbors_and_components(n_rows):
    """Choose UMAP hyperparameters that scale with the number of rows.

    Returns a ``(n_neighbors, n_components)`` tuple:

    * ``n_neighbors`` is ``n_rows // 20`` clamped to the range [15, 100];
    * ``n_components`` is 10 when there are more than 1000 rows, else 5.
    """
    lowest, highest = 15, 100
    # The middle of (low, value, high) is value clamped to [low, high].
    neighbors = sorted((lowest, n_rows // 20, highest))[1]

    # Larger datasets get a higher-dimensional reduced space.
    components = 5
    if n_rows > 1000:
        components = 10
    return neighbors, components
def fit_model(docs, embeddings, n_neighbors, n_components):
    """Fit a new BERTopic model on one chunk of documents.

    Args:
        docs: documents of the chunk.
        embeddings: precomputed embeddings for ``docs``.
        n_neighbors: UMAP neighbourhood size; also used to derive the HDBSCAN
            minimum cluster size and passed as BERTopic's ``min_topic_size``.
        n_components: dimensionality UMAP reduces to before clustering.

    Returns:
        The fitted ``BERTopic`` instance.
    """
    # Step 2 - dimensionality reduction (fixed seed for reproducibility).
    reducer = UMAP(
        n_neighbors=n_neighbors,
        n_components=n_components,
        min_dist=0.0,
        metric="cosine",
        random_state=42,
    )

    # Step 3 - clustering of the reduced embeddings. A smaller minimum cluster
    # size is used to get fewer outliers.
    smallest_cluster = max(5, n_neighbors // 2)
    clusterer = HDBSCAN(
        min_cluster_size=smallest_cluster,
        metric="euclidean",
        cluster_selection_method="eom",
        prediction_data=True,
    )

    # NOTE(review): vectorizer_model / representation_model / embedding_model are
    # module-level objects shared by every model built here; presumably each
    # fit refits the same CountVectorizer in place. Confirm that this does not
    # affect models fitted earlier.
    topic_model = BERTopic(
        language="english",
        embedding_model=embedding_model,  # Step 1 - extract embeddings
        umap_model=reducer,
        hdbscan_model=clusterer,
        vectorizer_model=vectorizer_model,  # Step 4 - tokenize topics
        representation_model=representation_model,  # Step 5 - label topics
        top_n_words=10,
        verbose=True,
        # NOTE(review): with a custom hdbscan_model supplied, min_topic_size may
        # have no effect on clustering. Confirm against the BERTopic docs.
        min_topic_size=n_neighbors,
    )

    logging.info("Fitting new model")
    topic_model.fit(docs, embeddings)
    logging.info("End fitting new model")
    return topic_model
def _error_outputs(message):
    """Return the five UI outputs that report *message* as an error.

    The data-details accordion is reopened and the table/plot are cleared.
    """
    return (
        gr.Accordion(open=True),
        gr.DataFrame(value=[], interactive=False, visible=True),
        gr.Plot(value=None, visible=True),
        gr.Label({message: 0.0}, visible=True),
        "",
    )


def _build_topic_plot(
    topic_model, plot_type, docs, topics, reduced_embeddings, sub_title, custom_labels
):
    """Render the 2-D topic map of *docs* for the selected *plot_type*.

    "DataMapPlot" uses ``visualize_document_datamap``; any other value uses
    ``visualize_documents``. *custom_labels* only applies to the DataMapPlot
    figure; the Plotly figure always requests custom labels.
    """
    if plot_type == "DataMapPlot":
        return topic_model.visualize_document_datamap(
            docs=docs,
            topics=topics,
            custom_labels=custom_labels,
            reduced_embeddings=reduced_embeddings,
            title="",
            sub_title=sub_title,
            width=800,
            height=700,
            arrowprops={
                "arrowstyle": "wedge,tail_width=0.5",
                "connectionstyle": "arc3,rad=0.05",
                "linewidth": 0,
                "fc": "#33333377",
            },
            dynamic_label_size=True,
            label_over_points=True,
            max_font_size=36,
            min_font_size=4,
        )
    return topic_model.visualize_documents(
        docs=docs,
        reduced_embeddings=reduced_embeddings,
        custom_labels=True,
        title="",
    )


def _generate_llm_topic_labels(topic_info):
    """Ask the text-generation pipeline for a label for every topic row.

    Returns a ``{topic_id: label}`` dict to pass to ``set_topic_labels``.
    """
    labels = {}
    for _, row in topic_info.iterrows():
        logging.info(
            f"Processing topic: {row['Topic']} - Representation: {row['Representation']}"
        )
        prompt = REPRESENTATION_PROMPT.replace(
            "[KEYWORDS]", ",".join(row["Representation"])
        )
        logging.info(prompt)
        topic_description = generator(prompt)
        logging.info(topic_description)
        # The generated text echoes the prompt; keep only the new text and
        # trim surrounding whitespace so labels render cleanly.
        labels[row["Topic"]] = (
            topic_description[0]["generated_text"].replace(prompt, "").strip()
        )
    return labels


# @spaces.GPU(duration=60 * 5)
def generate_topics(dataset, config, split, column, plot_type):
    """Fit topics on a Hub dataset split, streaming progress to the Gradio UI.

    This is a generator. Every ``yield`` is a 5-tuple matching the outputs of
    the "Generate Topics" button: data-details accordion, topics table, plot,
    progress label, and a markdown link to the published Space.

    Rows are read in chunks of at most ``CHUNK_SIZE`` and capped at
    ``MAX_ROWS``. Each chunk gets its own BERTopic model, which is merged into
    the running model. Afterwards the LLM relabels every topic, the plots are
    rebuilt, and an interactive data map is published to a Space.
    """
    logging.info(
        f"Generating topics for {dataset=} {config=} {split=} {column=} {plot_type=}"
    )

    parquet_urls = get_parquet_urls(dataset, config, split)
    split_rows = get_split_rows(dataset, config, split)
    if not split_rows:
        # In a generator, `return value` only sets StopIteration.value, which
        # Gradio never displays. The error outputs must be yielded.
        yield _error_outputs("❌ Error: No data found for the selected dataset")
        return

    logging.info(f"Split number of rows: {split_rows}")
    limit = min(split_rows, MAX_ROWS)
    n_neighbors, n_components = calculate_n_neighbors_and_components(limit)
    # Separate 2-D UMAP, used only to position points in the plots.
    reduce_umap_model = UMAP(
        n_neighbors=n_neighbors,
        n_components=2,
        min_dist=0.0,
        metric="cosine",
        random_state=42,
    )

    offset = 0
    rows_processed = 0
    base_model = None
    all_docs = []
    reduced_embeddings_list = []
    full_processing = split_rows <= MAX_ROWS
    message = (
        f"⚙️ Processing full dataset: 0 of ({split_rows} rows)"
        if full_processing
        else f"⚙️ Processing partial dataset 0 of ({limit} rows)"
    )
    sub_title = (
        f"Data map for the entire dataset ({limit} rows) using the column '{column}'"
        if full_processing
        else f"Data map for a sample of the dataset (first {limit} rows) using the column '{column}'"
    )
    yield (
        gr.Accordion(open=False),
        gr.DataFrame(value=[], interactive=False, visible=True),
        gr.Plot(value=None, visible=True),
        gr.Label({message: rows_processed / limit}, visible=True),
        "",
    )

    while offset < limit:
        # Never read past `limit`. When CHUNK_SIZE does not divide MAX_ROWS, a
        # full-size last chunk would overshoot the advertised sample size.
        docs = get_docs_from_parquet(
            parquet_urls, column, offset, min(CHUNK_SIZE, limit - offset)
        )
        if not docs:
            break
        logging.info(
            f"----> Processing chunk: {offset=} {CHUNK_SIZE=} with {len(docs)} docs"
        )

        embeddings = calculate_embeddings(docs)
        new_model = fit_model(docs, embeddings, n_neighbors, n_components)
        if base_model is None:
            base_model = new_model
            logging.info(
                f"The following topics are newly found: {base_model.topic_labels_}"
            )
        else:
            updated_model = BERTopic.merge_models([base_model, new_model])
            nr_new_topics = len(set(updated_model.topics_)) - len(
                set(base_model.topics_)
            )
            # Guard the slice: `labels[-0:]` is the whole list, which would
            # report every topic as new when the merge added none.
            new_topics = (
                list(updated_model.topic_labels_.values())[-nr_new_topics:]
                if nr_new_topics > 0
                else []
            )
            logging.info(f"The following topics are newly found: {new_topics}")
            base_model = updated_model

        # NOTE(review): the 2-D reducer is re-fit on every chunk, so chunks may
        # not share a common coordinate layout in the stacked plot. Confirm.
        reduced_embeddings = reduce_umap_model.fit_transform(embeddings)
        reduced_embeddings_list.append(reduced_embeddings)
        all_docs.extend(docs)
        reduced_embeddings_array = np.vstack(reduced_embeddings_list)

        topics_info = base_model.get_topic_info()
        all_topics = base_model.topics_
        topic_plot = _build_topic_plot(
            base_model,
            plot_type,
            all_docs,
            all_topics,
            reduced_embeddings_array,
            sub_title,
            custom_labels=False,
        )

        rows_processed += len(docs)
        progress = min(rows_processed / limit, 1.0)
        logging.info(f"Progress: {progress:.0%} - {rows_processed} of {limit}")
        message = (
            f"⚙️ Processing full dataset: {rows_processed} of {limit}"
            if full_processing
            else f"⚙️ Processing partial dataset: {rows_processed} of {limit} rows"
        )
        yield (
            gr.Accordion(open=False),
            topics_info,
            topic_plot,
            gr.Label({message: progress}, visible=True),
            "",
        )

        offset += CHUNK_SIZE
        del docs, embeddings, new_model, reduced_embeddings

    logging.info("Finished processing all data")
    if base_model is None:
        # The first chunk was already empty, so there is nothing to plot/publish.
        yield _error_outputs("❌ Error: No data found for the selected dataset")
        return

    dataset_clear_name = dataset.replace("/", "-")
    plot_png = f"{dataset_clear_name}-{plot_type.lower()}.png"
    # NOTE(review): the PNG is saved from the last in-loop plot, i.e. before the
    # LLM labels below are applied. Confirm that this is intended.
    if plot_type == "DataMapPlot":
        topic_plot.savefig(plot_png, format="png", dpi=300)
    else:
        topic_plot.write_image(plot_png)

    # Relabel every topic with the LLM, then rebuild the plot with those labels.
    all_topics = base_model.topics_
    topic_info = base_model.get_topic_info()
    base_model.set_topic_labels(_generate_llm_topic_labels(topic_info))
    topics_info = base_model.get_topic_info()
    topic_plot = _build_topic_plot(
        base_model,
        plot_type,
        all_docs,
        all_topics,
        reduced_embeddings_array,
        sub_title,
        custom_labels=True,
    )

    # Map topic id -> custom label by pairing custom_labels_ with the sorted
    # unique topic ids, rather than assuming an outlier topic (-1) sits at
    # index 0. That assumption mislabels every point when there are no outliers.
    # NOTE(review): relies on BERTopic storing custom_labels_ in ascending
    # topic-id order. Confirm against set_topic_labels.
    label_by_topic = dict(zip(sorted(set(all_topics)), base_model.custom_labels_))
    topic_names_array = [label_by_topic[doc_topic] for doc_topic in all_topics]

    interactive_plot = datamapplot.create_interactive_plot(
        reduced_embeddings_array,
        topic_names_array,
        hover_text=all_docs,
        title=dataset,
        sub_title=sub_title.replace(
            "dataset",
            f"<a href='https://huggingface.co/datasets/{dataset}/viewer/{config}/{split}' target='_blank'>dataset</a>",
        ),
        enable_search=True,
        # TODO: Export data to .arrow and also serve it
        inline_data=True,
        initial_zoom_fraction=0.8,
    )
    html_content = str(interactive_plot)
    html_file_path = f"{dataset_clear_name}.html"
    with open(html_file_path, "w", encoding="utf-8") as html_file:
        html_file.write(html_content)

    repo_id = f"{DATASETS_TOPICS_ORGANIZATION}/{dataset_clear_name}"
    space_id = create_space_with_content(
        api=api,
        repo_id=repo_id,
        dataset_id=dataset,
        html_file_path=html_file_path,
        plot_file_path=plot_png,
        space_card=SPACE_REPO_CARD_CONTENT,
        token=HF_TOKEN,
    )
    space_link = f"https://huggingface.co/spaces/{space_id}"

    yield (
        gr.Accordion(open=False),
        topics_info,
        topic_plot,
        gr.Label(
            {f"✅ Done: {rows_processed} rows have been processed": 1.0}, visible=True
        ),
        f"[![Go to interactive plot](https://img.shields.io/badge/%F0%9F%A4%97%20Hugging%20Face-Space-blue)]({space_link})",
    )

    # Drop references to the large objects before releasing cached GPU memory.
    del reduce_umap_model, all_docs, reduced_embeddings_list
    del (
        base_model,
        all_topics,
        topic_info,
        topic_names_array,
        interactive_plot,
    )
    cuda.empty_cache()
with gr.Blocks() as demo:
    gr.HTML("<h1 style='text-align: center;'>💠 Dataset Topic Discovery 🔭</h1>")
    gr.HTML(
        "<h3 style='text-align: center;'>Select a dataset and text column for topic modeling</h3>"
    )
    gr.HTML(
        "<p style='text-align: center; color:orange;'>⚠ This space is in progress, and we're actively working on it, so you might find some bugs! Please report any issues you have in the Community tab to help us make it better for all.</p>"
    )

    # Input panel; generate_topics opens/collapses it through its first output.
    data_details_accordion = gr.Accordion("Data details", open=True)
    with data_details_accordion:
        with gr.Row():
            with gr.Column(scale=3):
                dataset_name = HuggingfaceHubSearch(
                    label="Hub Dataset ID",
                    placeholder="Search for dataset id on Huggingface",
                    search_type="dataset",
                )
                # Made visible by the handlers below only when there is a choice.
                subset_dropdown = gr.Dropdown(label="Subset", visible=False)
                split_dropdown = gr.Dropdown(label="Split", visible=False)

        with gr.Accordion("Dataset preview", open=False):

            @gr.render(inputs=[dataset_name, subset_dropdown, split_dropdown])
            def embed(name, subset, split):
                """Embed the Hub dataset viewer for the current selection."""
                if not name:
                    # Nothing selected yet: avoid rendering a broken iframe.
                    return None
                html_code = f"""
<iframe
src="https://huggingface.co/datasets/{name}/embed/viewer/{subset}/{split}"
frameborder="0"
width="100%"
height="600px"
></iframe>
"""
                return gr.HTML(value=html_code)

        with gr.Row():
            text_column_dropdown = gr.Dropdown(label="Text column name")
            plot_type_radio = gr.Radio(
                ["DataMapPlot", "Plotly"],
                value="DataMapPlot",
                label="Choose the plot type",
                interactive=True,
            )
        generate_button = gr.Button("Generate Topics", variant="primary")

    gr.Markdown("## Data map")
    full_topics_generation_label = gr.Label(visible=False, show_label=False)
    open_space_label = gr.Markdown()
    topics_plot = gr.Plot()
    with gr.Accordion("Topics Info", open=False):
        topics_df = gr.DataFrame(interactive=False, visible=True)
    gr.HTML(
        f"<p style='text-align: center; color:orange;'>⚠ This space processes datasets in batches of <b>{CHUNK_SIZE}</b>, with a maximum of <b>{MAX_ROWS}</b> rows. If you need further assistance, please open a new issue in the Community tab.</p>"
    )
    gr.Markdown(
        "_Powered by [bertopic](https://maartengr.github.io/BERTopic/index.html) [datamapplot](https://datamapplot.readthedocs.io/en/latest/) and [duckdb](https://duckdb.org/)_"
    )

    # generate_topics is a generator, so each yield updates these five outputs.
    generate_button.click(
        generate_topics,
        inputs=[
            dataset_name,
            subset_dropdown,
            split_dropdown,
            text_column_dropdown,
            plot_type_radio,
        ],
        outputs=[
            data_details_accordion,
            topics_df,
            topics_plot,
            full_topics_generation_label,
            open_space_label,
        ],
    )

    def _resolve_dataset_selection(
        dataset: str, default_subset: str, default_split: str, text_feature
    ):
        """Build updates for the subset, split and text-column dropdowns.

        Args:
            dataset: Hub dataset id entered by the user ("namespace/name").
            default_subset: subset to keep if the dataset has it; otherwise the
                first subset reported by ``get_info`` is selected.
            default_split: same, for the split within the chosen subset.
            text_feature: currently selected text column. It does not change
                the returned updates; it is kept for the callers' signature.

        Returns:
            A dict mapping each of the three dropdowns to a ``gr.Dropdown`` update.
        """

        def _reset():
            # Hide subset/split; the text-column update only (re)sets its label.
            return {
                subset_dropdown: gr.Dropdown(visible=False),
                split_dropdown: gr.Dropdown(visible=False),
                text_column_dropdown: gr.Dropdown(label="Text column name"),
            }

        if not dataset or "/" not in dataset.strip().strip("/"):
            return _reset()
        try:
            info_resp = get_info(dataset)
        except Exception:
            # Best effort: keep the UI responsive, but record why it failed.
            logging.exception(f"Could not fetch info for dataset {dataset!r}")
            return _reset()

        subsets: list[str] = list(info_resp)
        if not subsets:
            return _reset()
        subset = default_subset if default_subset in subsets else subsets[0]
        splits: list[str] = list(info_resp[subset]["splits"])
        if not splits:
            return _reset()
        split = default_split if default_split in splits else splits[0]

        # Only columns whose feature dtype is "string" are offered as text input.
        text_features = [
            feature_name
            for feature_name, feature in info_resp[subset]["features"].items()
            if isinstance(feature, dict) and feature.get("dtype") == "string"
        ]
        return {
            subset_dropdown: gr.Dropdown(
                value=subset, choices=subsets, visible=len(subsets) > 1
            ),
            split_dropdown: gr.Dropdown(
                value=split, choices=splits, visible=len(splits) > 1
            ),
            text_column_dropdown: gr.Dropdown(
                choices=text_features, label="Text column name"
            ),
        }

    @dataset_name.change(
        inputs=[dataset_name],
        outputs=[
            subset_dropdown,
            split_dropdown,
            text_column_dropdown,
        ],
    )
    def show_input_from_dataset_name(dataset: str) -> dict:
        """Refresh the dropdowns when a new dataset is selected."""
        return _resolve_dataset_selection(
            dataset, default_subset="default", default_split="train", text_feature=None
        )

    @subset_dropdown.change(
        inputs=[dataset_name, subset_dropdown],
        outputs=[
            subset_dropdown,
            split_dropdown,
            text_column_dropdown,
        ],
    )
    def show_input_from_subset_dropdown(dataset: str, subset: str) -> dict:
        """Refresh split and text-column choices for the chosen subset."""
        return _resolve_dataset_selection(
            dataset, default_subset=subset, default_split="train", text_feature=None
        )

    @split_dropdown.change(
        inputs=[dataset_name, subset_dropdown, split_dropdown],
        outputs=[
            subset_dropdown,
            split_dropdown,
            text_column_dropdown,
        ],
    )
    def show_input_from_split_dropdown(dataset: str, subset: str, split: str) -> dict:
        """Refresh the dropdowns while keeping the chosen subset and split."""
        return _resolve_dataset_selection(
            dataset, default_subset=subset, default_split=split, text_feature=None
        )

    @text_column_dropdown.change(
        inputs=[dataset_name, subset_dropdown, split_dropdown, text_column_dropdown],
        outputs=[
            subset_dropdown,
            split_dropdown,
            text_column_dropdown,
        ],
    )
    def show_input_from_text_column_dropdown(
        dataset: str, subset: str, split: str, text_column
    ) -> dict:
        """Refresh the dropdowns after a text column is chosen."""
        return _resolve_dataset_selection(
            dataset,
            default_subset=subset,
            default_split=split,
            text_feature=text_column,
        )

demo.launch()