ccm's picture
Update main.py
a1a6086 verified
raw
history blame
No virus
6.49 kB
import json # to work with JSON
import threading # for threading
import time # for better HCI
from collections.abc import Iterator # for type hints

import datasets # to load the dataset
import faiss # to create an index
import gradio # for the interface
import numpy # to work with vectors
import pandas # to work with pandas
import sentence_transformers # to load an embedding model
import spaces # for GPU
import transformers # to load an LLM
# Constants

# Opening message shown as the assistant's first turn in the chat window (Markdown)
GREETING = (
    "Howdy! I'm an AI agent that uses a "
    "[retrieval-augmented generation](https://en.wikipedia.org/wiki/Retrieval-augmented_generation) "
    "pipeline to answer questions about research by the "
    "[Design Research Collective](https://cmudrc.github.io/). "
    "And the best part is that I always cite my sources! "
    "What can I tell you about today?"
)

# Sample prompts offered as clickable examples in the interface
EXAMPLE_QUERIES = [
    "Tell me about new research at the intersection of additive manufacturing and machine learning",
    "What is a physics-informed neural network and what can it be used for?",
    "What can agent-based models do about climate change?",
]

# Sentence-transformers model used to embed queries
EMBEDDING_MODEL_NAME = "allenai-specter"
# Hugging Face model id of the chat LLM that writes the answers
LLM_MODEL_NAME = "Qwen/Qwen2-7B-Instruct"
# Load the dataset and convert to pandas
data = datasets.load_dataset("ccm/publications")["train"].to_pandas()
# Filter out any publications without an abstract. Inspecting the key directly is
# more robust than searching the JSON serialization for '"abstract": null'. It also
# drops records that lack the key entirely; search() would otherwise fail on them
# with a KeyError. Series.map keeps the mask aligned with data's own index.
has_abstract = data["bib_dict"].map(
    lambda bibdict: bibdict.get("abstract") is not None
)
data = data[has_abstract]
# Renumber rows 0..n-1 so FAISS row ids can be used directly with data.loc
data.reset_index(inplace=True)
# Create a FAISS index for fast similarity search. The embeddings are
# L2-normalized below, so inner-product search ranks papers by cosine similarity.
metric = faiss.METRIC_INNER_PRODUCT
# FAISS expects a C-contiguous float32 matrix with one row per paper
vectors = numpy.ascontiguousarray(
    numpy.stack(data["embedding"].tolist(), axis=0), dtype=numpy.float32
)
faiss.normalize_L2(vectors)
# Build a flat index with the inner-product metric directly (rather than overriding
# metric_type on an L2 index). Flat indexes need no training.
index = faiss.IndexFlat(vectors.shape[1], metric)
index.add(vectors)
# Load the model for later use in embeddings
model = sentence_transformers.SentenceTransformer(EMBEDDING_MODEL_NAME)
def _format_reference(rank: int, bib_dict: dict, author_pub_id: str) -> str:
    """
    Formats one numbered Markdown reference that links to the paper's Google Scholar page

    Args:
        rank (int): 1-based position of the paper in the reference list
        bib_dict (dict): The paper's bibliographic record; reads "author" (names
            joined by " and "), "pub_year" and "title"
        author_pub_id (str): Google Scholar citation id used in the link

    Returns:
        str: A line like "1. Doe, Public. (2021). [Title](https://scholar...).\\n"
    """
    # Cite each author by the last space-separated word of their name
    surnames = ", ".join(
        author.split(" ")[-1] for author in bib_dict["author"].split(" and ")
    )
    return (
        f"{rank}. {surnames}. ({int(bib_dict['pub_year'])}). [{bib_dict['title']}]"
        "(https://scholar.google.com/citations?view_op=view_citation&citation_for_view="
        f"{author_pub_id}).\n"
    )


def search(query: str, k: int) -> tuple[str, str]:
    """
    Searches the dataset for the top k most relevant papers to the query

    Args:
        query (str): The user's query
        k (int): The maximum number of results to return

    Returns:
        tuple[str, str]: The retrieval prompt (instructions followed by the
        retrieved abstracts) and a Markdown reference list for those papers
    """
    # Embed the query and L2-normalize it the same way the indexed vectors were
    query_vector = numpy.expand_dims(model.encode(query), axis=0).astype(numpy.float32)
    faiss.normalize_L2(query_vector)
    _, neighbor_ids = index.search(query_vector, k)
    # FAISS pads the result with -1 when the index holds fewer than k vectors;
    # those would otherwise make data.loc fail
    hits = data.loc[[int(row) for row in neighbor_ids[0] if row >= 0]]

    abstracts = []
    references = []
    for rank, (bib_dict, author_pub_id) in enumerate(
        zip(hits["bib_dict"].values, hits["author_pub_id"].values), start=1
    ):
        abstracts.append(bib_dict["abstract"] + "\n")
        references.append(_format_reference(rank, bib_dict, author_pub_id))

    search_results = (
        "You are an AI assistant who delights in helping people learn about research from the Design "
        "Research Collective. Here are several abstracts from really cool, and really relevant, "
        "papers:\n\n"
        + "".join(abstracts)
        + "\nUsing the information provided above, respond to this query: "
    )
    return search_results, "\n\n## References\n\n" + "".join(references)
# Load the tokenizer and causal-LM chat model that answer queries; dtype and device
# placement are left to transformers ("auto")
tokenizer = transformers.AutoTokenizer.from_pretrained(LLM_MODEL_NAME)
chatmodel = transformers.AutoModelForCausalLM.from_pretrained(
    LLM_MODEL_NAME,
    device_map="auto",
    torch_dtype="auto",
)
# Iterator of decoded text that omits the prompt and special tokens.
# NOTE(review): this single module-level streamer is shared by every generation
# that uses it.
streamer = transformers.TextIteratorStreamer(
    tokenizer,
    skip_special_tokens=True,
    skip_prompt=True,
)
def preprocess(message: str) -> tuple[str, str]:
    """
    Augments the user's message with retrieved context before the LLM receives it

    Args:
        message (str): The user's message

    Returns:
        tuple[str, str]: The message prefixed with the retrieval prompt built from
        the 5 closest abstracts, plus the formatted reference list to carry
        through to postprocessing untouched
    """
    retrieval_prompt, references = search(message, 5)
    augmented_message = "".join((retrieval_prompt, message))
    return augmented_message, references
def postprocess(response: str, bypass_from_preprocessing: str) -> str:
    """
    Finalizes the LLM's response before the user receives it

    Args:
        response (str): The LLM's response
        bypass_from_preprocessing (str): Text produced during preprocessing that
            skipped the LLM (in this file, preprocess() supplies the reference list)

    Returns:
        str: The response immediately followed by the bypass text
    """
    return "".join((response, bypass_from_preprocessing))
@spaces.GPU
def reply(message: str, history: list[list[str]]) -> Iterator[str]:
    """
    Streams a retrieval-augmented response to the user's message

    Args:
        message (str): The user's message
        history (list[list[str]]): The conversation history as [user, assistant]
            pairs; entries may be None (e.g. the opening greeting has no user turn)

    Yields:
        str: The response generated so far; the final value has the reference
        list from preprocessing appended
    """
    # Apply preprocessing: prepend retrieved abstracts, keep the references for the end
    augmented_message, bypass = preprocess(message)

    # Convert the [user, assistant] pairs into chat-template messages, skipping
    # missing turns, then add the augmented message as the latest user turn
    conversation = [
        {"role": role, "content": content}
        for message_pair in history
        for role, content in zip(("user", "assistant"), message_pair)
        if content is not None
    ]
    conversation.append({"role": "user", "content": augmented_message})

    text = tokenizer.apply_chat_template(
        conversation, tokenize=False, add_generation_prompt=True
    )
    model_inputs = tokenizer([text], return_tensors="pt").to("cuda:0")

    # Use a fresh streamer per request. A shared module-level streamer would leak
    # leftover text (and the end-of-stream signal) from an abandoned or concurrent
    # generation into the next request's output.
    streamer = transformers.TextIteratorStreamer(
        tokenizer, skip_prompt=True, skip_special_tokens=True
    )
    generate_kwargs = dict(model_inputs, streamer=streamer, max_new_tokens=512)
    # Run generation in a background thread so text can be consumed from the
    # streamer and yielded while it is still being produced
    generation_thread = threading.Thread(
        target=chatmodel.generate, kwargs=generate_kwargs
    )
    generation_thread.start()

    partial_message = ""
    for new_token in streamer:
        # Skip text chunks that are exactly "<". NOTE(review): presumably a stray
        # markup fragment from the model; confirm intent.
        if new_token != "<":
            partial_message += new_token
            time.sleep(0.05)  # pace the stream for readability
            yield partial_message
    generation_thread.join()
    yield postprocess(partial_message, bypass)
# Create and run the gradio interface.
# The chat window opens with the greeting as the assistant's first turn; the user
# side of that opening pair is None.
chat_window = gradio.Chatbot(
    value=[[None, GREETING]], show_label=False, show_copy_button=True
)
interface = gradio.ChatInterface(
    fn=reply,
    chatbot=chat_window,
    examples=EXAMPLE_QUERIES,
    cache_examples=True,
    fill_height=True,
    # Passing None for these removes the retry/undo/clear buttons (Gradio 4.x API)
    retry_btn=None,
    undo_btn=None,
    clear_btn=None,
)
interface.launch(debug=True)