ccm's picture
Update main.py
f37ef58 verified
raw
history blame
4.33 kB
import collections.abc # for generator return annotations
import json # to work with JSON
import threading

import datasets # to load the dataset
import faiss # to create an index
import gradio # for the interface
import numpy # to work with vectors
import pandas # to work with pandas
import sentence_transformers # to load an embedding model
import spaces # for GPU
import transformers # to load an LLM
# Load the dataset and convert to pandas
full_data = datasets.load_dataset("ccm/publications")["train"].to_pandas()

# Define the base URL for Google Scholar
SCHOLAR_URL = "https://scholar.google.com"

# Filter out any publications without an abstract. A record counts as
# "missing" when its JSON serialization contains a null "abstract" entry.
# (Named `missing_abstract` so the builtin `filter` is not shadowed.)
missing_abstract = pandas.Series(
    [
        '"abstract": null' in json.dumps(bibdict)
        for bibdict in full_data["bib_dict"].values
    ],
    index=full_data.index,
)
# reset_index() keeps the old labels in an "index" column and gives the
# surviving rows contiguous 0..n-1 labels, which line up with FAISS ids.
data = full_data[~missing_abstract].reset_index()

# Create a FAISS index for fast similarity search.
# Kept as a module-level name; IndexFlatIP below implements this metric.
metric = faiss.METRIC_INNER_PRODUCT
# FAISS operates on C-contiguous float32 matrices only.
vectors = numpy.ascontiguousarray(
    numpy.stack(data["embedding"].tolist(), axis=0),
    dtype=numpy.float32,
)
# Normalize in place so that inner product equals cosine similarity.
faiss.normalize_L2(vectors)
# A flat inner-product index: exact search, no training step required.
index = faiss.IndexFlatIP(vectors.shape[1])
index.add(vectors)

# Load the model for later use in embeddings
model = sentence_transformers.SentenceTransformer("allenai-specter")
# Define the search function
def search(query: str, k: int) -> tuple[str, str]:
    """Find the publications most similar to ``query`` and build an LLM prompt.

    The query is embedded with ``model``, L2-normalized, and looked up in the
    FAISS ``index``; the matching rows are read from ``data``.

    Args:
        query: Free-text user query.
        k: Maximum number of publications to retrieve.

    Returns:
        A ``(prompt, references)`` pair. ``prompt`` holds the instructions and
        the retrieved abstracts, ending with a blank line so the caller can
        append the user's query directly. ``references`` is a markdown
        numbered list linking each retrieved title to Google Scholar.
    """
    query_vector = numpy.expand_dims(model.encode(query), axis=0)
    faiss.normalize_L2(query_vector)
    _, neighbor_ids = index.search(query_vector, k)

    # FAISS pads the result with -1 when fewer than k vectors are available;
    # drop those so we never index past the real hits.
    hit_positions = [int(pos) for pos in neighbor_ids[0] if pos >= 0]
    hits = data.iloc[hit_positions]

    abstracts = []
    reference_lines = []
    publications = zip(hits["bib_dict"].values, hits["author_pub_id"].values)
    for rank, (bib, pub_id) in enumerate(publications, start=1):
        abstracts.append(bib["abstract"] + "\n")
        reference_lines.append(
            f"{rank}. [{bib['title']}]"
            f"({SCHOLAR_URL}/citations?view_op=view_citation"
            f"&citation_for_view={pub_id})\n"
        )

    # Note the trailing spaces: adjacent literals are joined verbatim.
    search_results = (
        "You are an AI assistant who delights in helping people "
        "learn about research from the Design Research Collective. Here are "
        "several really cool abstracts:\n\n"
        + "".join(abstracts)
        + "\nSummarize the above abstracts as you respond to the following query:\n\n"
    )
    references = "\n\n## References\n\n" + "".join(reference_lines)

    print(search_results)
    return search_results, references
# Load the chat LLM: its tokenizer, a token streamer, and the model weights
CHAT_MODEL_NAME = "Qwen/Qwen2-0.5B-Instruct"

tokenizer = transformers.AutoTokenizer.from_pretrained(CHAT_MODEL_NAME)

# Streamer that yields decoded text, omitting the prompt and special tokens
streamer = transformers.TextIteratorStreamer(
    tokenizer,
    skip_special_tokens=True,
    skip_prompt=True,
)

# Let transformers choose dtype and device placement
chatmodel = transformers.AutoModelForCausalLM.from_pretrained(
    CHAT_MODEL_NAME,
    device_map="auto",
    torch_dtype="auto",
)
def preprocess(message: str, k: int = 5) -> tuple[str, str]:
    """Augment the user's message with retrieved context before the LLM sees it.

    Args:
        message: The raw user message.
        k: Number of publications to retrieve via ``search`` (default 5).

    Returns:
        A ``(prompt, references)`` pair: the search prompt with ``message``
        appended, and the formatted reference list that bypasses the LLM and
        is attached to the final response.
    """
    block_search_results, formatted_search_results = search(message, k)
    return block_search_results + message, formatted_search_results
def postprocess(response: str, bypass_from_preprocessing: str) -> str:
    """Finalize the LLM's response before it is shown to the user.

    Appends the text that skipped the LLM (``bypass_from_preprocessing``)
    to the end of ``response`` and returns the combined string.
    """
    finalized = response + bypass_from_preprocessing
    return finalized
def _last_turn_messages(history) -> list[dict]:
    """Convert the most recent chat turn in ``history`` to chat-template messages."""
    if not history:
        return []
    last = history[-1]
    if isinstance(last, dict):
        # NOTE(review): assumes dict-style history entries carry "role" and
        # "content" keys (gradio "messages" format) -- confirm against the
        # gradio version in use. Keep only the trailing user/assistant pair.
        return [
            {"role": entry["role"], "content": entry["content"]}
            for entry in history[-2:]
            if isinstance(entry, dict) and isinstance(entry.get("content"), str)
        ]
    # Pair-style history: the last entry is (user_message, assistant_message).
    # Even positions are the user, odd positions the assistant; skip empty
    # slots (e.g. a missing assistant reply) so the template gets only text.
    return [
        {"role": "assistant" if idx & 1 else "user", "content": msg}
        for idx, msg in enumerate(last)
        if msg is not None
    ]


def predict(message: str, history: list) -> collections.abc.Iterator[str]:
    """Stream a response to ``message`` for the chat interface.

    The message is first augmented with retrieved abstracts via
    ``preprocess``; only the most recent turn of ``history`` is passed to the
    model as context. Generation runs in a background thread while this
    generator yields the growing response text.

    Args:
        message: The user's latest message.
        history: Previous conversation turns as supplied by the chat UI.

    Yields:
        The partial response after each streamed chunk, and finally the full
        response with the reference list appended.
    """
    # Apply preprocessing
    message, bypass = preprocess(message)

    conversation = _last_turn_messages(history) + [
        {"role": "user", "content": message}
    ]
    text = tokenizer.apply_chat_template(
        conversation,
        tokenize=False,
        add_generation_prompt=True,
    )
    # The model is loaded with device_map="auto", so the inputs must be moved
    # to whatever device it ended up on.
    model_inputs = tokenizer([text], return_tensors="pt").to(chatmodel.device)

    # Use a fresh streamer per request: a shared one would mix tokens from
    # overlapping or abandoned generations into this response.
    token_streamer = transformers.TextIteratorStreamer(
        tokenizer, skip_prompt=True, skip_special_tokens=True
    )
    generate_kwargs = dict(
        model_inputs,
        streamer=token_streamer,
        max_new_tokens=512,
    )
    worker = threading.Thread(target=chatmodel.generate, kwargs=generate_kwargs)
    worker.start()

    partial_message = ""
    for new_token in token_streamer:
        # Skip chunks that are a lone "<"
        if new_token != "<":
            partial_message += new_token
            yield partial_message

    # The streamer is exhausted once generation ends, so this returns promptly.
    worker.join()
    yield postprocess(partial_message, bypass)
# Wrap predict in a gradio chat UI, then start serving it with debug enabled
chat_interface = gradio.ChatInterface(predict)
chat_interface.launch(debug=True)