import gradio # for the interface
import transformers # to load an LLM
import sentence_transformers # to load an embedding model
import faiss # to create an index
import numpy # to work with vectors
import pandas # to work with dataframes
import json # to work with JSON
import datasets # to load the dataset
import spaces # for ZeroGPU
import threading # to stream generation in the background
# Load the dataset and convert to pandas
full_data = datasets.load_dataset("ccm/publications")["train"].to_pandas()
# Define the base URL for Google Scholar
SCHOLAR_URL = "https://scholar.google.com"
# Filter out any publications without an abstract
no_abstract = [
    '"abstract": null' in json.dumps(bibdict)
    for bibdict in full_data["bib_dict"].values
]
data = full_data[~pandas.Series(no_abstract)]
data.reset_index(drop=True, inplace=True)
# Create a FAISS index for fast similarity search
metric = faiss.METRIC_INNER_PRODUCT
vectors = numpy.stack(data["embedding"].tolist(), axis=0)
index = faiss.IndexFlatL2(len(data["embedding"][0]))
index.metric_type = metric
faiss.normalize_L2(vectors)
index.train(vectors)
index.add(vectors)
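# Note: because the vectors are L2-normalized and the index metric is inner
# product, the scores returned by index.search() below are cosine similarities.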
# Load the model for later use in embeddings
model = sentence_transformers.SentenceTransformer("allenai-specter")
# Define the search function
def search(query: str, k: int) -> tuple[str, str]:
    """Searches the publication dataset for the k abstracts most similar to the query"""
    query = numpy.expand_dims(model.encode(query), axis=0)
    faiss.normalize_L2(query)
    D, I = index.search(query, k)
    top_k = data.loc[I[0]]
    search_results = "You are an AI assistant who delights in helping people " \
        + "learn about research from the Design Research Collective. Here are " \
        + "several really cool abstracts:\n\n"
    references = "\n\n## References\n\n"
    for i in range(k):
        search_results += top_k["bib_dict"].values[i]["abstract"] + "\n"
        references += str(i + 1) + ". [" + top_k["bib_dict"].values[i]["title"] + "]" \
            + "(" + SCHOLAR_URL + "/citations?view_op=view_citation&citation_for_view=" \
            + top_k["author_pub_id"].values[i] + ")\n"
    search_results += "\nSummarize the above abstracts as you respond to the following query:"
    print(search_results)
    return search_results, references
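# For illustration only (hypothetical query): calling
#   prompt, refs = search("generative design of heat exchangers", 5)
# returns `prompt`, the retrieved abstracts wrapped in instructions for the LLM,
# and `refs`, a Markdown reference list linking each result back to Google Scholar.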
# Create an LLM pipeline that we can send queries to
tokenizer = transformers.AutoTokenizer.from_pretrained("Qwen/Qwen2-0.5B-Instruct")
streamer = transformers.TextIteratorStreamer(tokenizer, skip_prompt=True, skip_special_tokens=True)
chatmodel = transformers.AutoModelForCausalLM.from_pretrained(
    "Qwen/Qwen2-0.5B-Instruct",
    torch_dtype="auto",
    device_map="auto"
)
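# The TextIteratorStreamer yields decoded tokens as soon as the model produces
# them, so generation can run in a background thread while the interface streams
# the partial response back to the user.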
def preprocess(message: str) -> tuple[str, str]:
    """Applies a preprocessing step to the user's message before the LLM receives it"""
    block_search_results, formatted_search_results = search(message, 5)
    return block_search_results + message, formatted_search_results

def postprocess(response: str, bypass_from_preprocessing: str) -> str:
    """Applies a postprocessing step to the LLM's response before the user receives it"""
    return response + bypass_from_preprocessing
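# Together these wrap the retrieval-augmented generation step: preprocess()
# prepends the retrieved abstracts to the user's message before it reaches the
# LLM, and postprocess() appends the matching reference list to the finished reply.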
@spaces.GPU  # request a ZeroGPU device while a response is being generated
def predict(message: str, history: list[str]) -> str:
"""This function is responsible for crafting a response""" | |
# Apply preprocessing | |
message, bypass = preprocess(message) | |
# This is some handling that is applied to the history variable to put it in a good format | |
if isinstance(history, list): | |
if len(history) > 0: | |
history = history[-1] | |
history_transformer_format = [ | |
{"role": "assistant" if idx&1 else "user", "content": msg} | |
for idx, msg in enumerate(history) | |
] + [{"role": "user", "content": message}] | |
    # Stream a response from the model
    text = tokenizer.apply_chat_template(
        history_transformer_format,
        tokenize=False,
        add_generation_prompt=True
    )
    model_inputs = tokenizer([text], return_tensors="pt").to(chatmodel.device)
    generate_kwargs = dict(
        model_inputs,
        streamer=streamer,
        max_new_tokens=512
    )
    t = threading.Thread(target=chatmodel.generate, kwargs=generate_kwargs)
    t.start()
    partial_message = ""
    for new_token in streamer:
        if new_token != "<":  # drop stray special-token markers
            partial_message += new_token
            yield partial_message
    yield postprocess(partial_message, bypass)
# Create and run the gradio interface
gradio.ChatInterface(predict).launch(debug=True)