Spaces:
Runtime error
Runtime error
import re | |
import os | |
import logging | |
import gradio as gr | |
from typing import Set, List, Tuple | |
from huggingface_hub import InferenceClient | |
from langchain_openai import AzureChatOpenAI | |
from langchain.chains import LLMChain | |
from langchain.prompts import PromptTemplate | |
from langchain.chains import SimpleSequentialChain | |
from langchain.chains import LLMSummarizationCheckerChain | |
# NOTE(review): no explicit Hugging Face login is performed in this file; the
# InferenceClient presumably authenticates from the environment — confirm for
# the deployed Space.

# Write log records of level DEBUG and above to factchecking.log, each one
# prefixed with a timestamp and the level name.
logging.basicConfig(
    level=logging.DEBUG,
    filename='factchecking.log',
    format='%(asctime)s - %(levelname)s - %(message)s',
)
class FactChecking:
    """Gradio app that answers a question with Mixtral and fact-checks the answer.

    The answer is generated through a Hugging Face ``InferenceClient`` and then
    passed through a three-step LangChain pipeline on an Azure OpenAI chat model:
    extract facts -> check each fact -> rewrite the answer sentence by sentence
    with a ``(True)`` / ``(False)`` label. The labels are then mapped back onto
    the sentences of the original answer for display.
    """

    # Records go wherever the root logging configuration sends them.
    _logger = logging.getLogger(__name__)

    # A trailing truth label such as " (True)" or " (False)." at the end of a sentence.
    _LABEL_SUFFIX = re.compile(r"\s*\((?:True|False|Undetermined)\)\.?\s*$", re.IGNORECASE)

    def __init__(self, llm=None, client=None):
        """Create the model clients.

        Args:
            llm: Optional LangChain chat model used for fact checking. Defaults
                to ``AzureChatOpenAI`` on the ``"GPT-3"`` deployment.
            client: Optional Hugging Face ``InferenceClient`` used to generate
                the answer. Defaults to ``mistralai/Mixtral-8x7B-Instruct-v0.1``.
        """
        # NOTE(review): AzureChatOpenAI presumably reads its endpoint, key and API
        # version from environment variables — confirm they are set in the Space.
        self.llm = llm if llm is not None else AzureChatOpenAI(azure_deployment="GPT-3")
        self.client = (
            client
            if client is not None
            else InferenceClient("mistralai/Mixtral-8x7B-Instruct-v0.1")
        )

    def format_prompt(self, question: str) -> str:
        """Wrap *question* in Mixtral's ``[INST] ... [/INST]`` instruction format.

        The instruction and the question share a single instruction turn rather
        than two consecutive ``[INST]`` turns with no assistant reply in between.

        Args:
            question (str): The user's question.

        Returns:
            str: The prompt to send to the model.
        """
        instruction = "You are an AI assistant. Your task is to answer the user's question."
        return f"[INST] {instruction}\n\n{question} [/INST]"

    def mixtral_response(self, prompt, temperature=0.9, max_new_tokens=5000, top_p=0.95,
                         repetition_penalty=1.0):
        """Generate the model's answer to *prompt* via streaming text generation.

        Args:
            prompt (str): The user's question (wrapped with ``format_prompt``).
            temperature (float): Sampling temperature; values below 0.01 are raised to 0.01.
            max_new_tokens (int): Maximum number of tokens to generate.
            top_p (float): Nucleus-sampling parameter.
            repetition_penalty (float): Penalty for repeating tokens.

        Returns:
            str: The concatenated generated text with ``</s>`` removed.
        """
        temperature = max(float(temperature), 1e-2)
        generate_kwargs = dict(
            temperature=temperature,
            max_new_tokens=max_new_tokens,
            top_p=float(top_p),
            repetition_penalty=repetition_penalty,
            do_sample=True,
            seed=42,  # fixed sampling seed
        )
        stream = self.client.text_generation(
            self.format_prompt(prompt),
            **generate_kwargs,
            stream=True,
            details=True,
            return_full_text=False,
        )
        # Each streamed item exposes the newly generated token as ``.token.text``.
        output = "".join(chunk.token.text for chunk in stream)
        return output.replace("</s>", "")

    @staticmethod
    def _escape_braces(text: str) -> str:
        """Escape ``{`` / ``}`` so *text* stays literal inside a PromptTemplate."""
        return text.replace("{", "{{").replace("}", "}}")

    @staticmethod
    def _extract_truth_values(chain_answer: str) -> List[str]:
        """Return the text inside the last parenthesis of every labelled line.

        The fact-checking chain is asked for ``sentence (True or False)``. Lines
        without a ``(`` (blank lines, headings) are skipped, so they do not
        shift the positional pairing with the answer's sentences.
        """
        values = []
        for line in chain_answer.splitlines():
            _, sep, tail = line.strip().rpartition("(")
            if not sep:
                continue
            values.append(tail.split(")")[0].strip())
        return values

    def find_different_sentences(self, chain_answer, llm_answer):
        """Tag each sentence of *llm_answer* as ``"factual"`` or ``"hallucinated"``.

        Args:
            chain_answer (str): Fact-checking output with one ``(True)``/``(False)``
                style label per line.
            llm_answer (str): The original answer; split into sentences on ``". "``.

        Returns:
            List[Tuple[str, str]]: ``(sentence, tag)`` pairs for ``gr.HighlightedText``.
            A label counts as factual only if it contains ``"True"``, so
            ``Undetermined`` and ``False`` are tagged hallucinated. Pairing is
            positional and stops at the shorter of the two sequences. An answer
            with no labels yields an empty list.
        """
        tags = [
            "factual" if "True" in value else "hallucinated"
            for value in self._extract_truth_values(chain_answer)
        ]
        tagged_sentences = []
        for sentence, tag in zip(llm_answer.split(". "), tags):
            # Drop only a trailing truth label; keep legitimate parentheticals.
            sentence_text = self._LABEL_SUFFIX.sub("", sentence)
            # Splitting on ". " removes the period from every sentence but the last.
            if not sentence_text.endswith("."):
                sentence_text += "."
            tagged_sentences.append((sentence_text, tag))
        return tagged_sentences

    def _build_fact_check_chain(self, mixtral_answer: str):
        """Build the extract -> check -> relabel LangChain pipeline for one answer."""
        # Step 1: extract a bulleted list of facts from the answer.
        extract_template = (
            "Given some text, extract a list of facts from the text.\n"
            "Format your output as a bulleted list.\n"
            "Text:\n"
            "{question}\n"
            "Facts:"
        )
        extract_chain = LLMChain(
            llm=self.llm,
            prompt=PromptTemplate(input_variables=["question"], template=extract_template),
        )

        # Step 2: judge each extracted fact.
        check_template = (
            "You are an expert fact checker. You have been hired by a major news "
            "organization to fact check a very important story.\n"
            "Here is a bullet point list of facts:\n"
            "{statement}\n"
            "For each fact, determine whether it is true or false about the subject. "
            "If you are unable to determine whether the fact is true or false, "
            "output \"Undetermined\".\n"
            "If the fact is false, explain why."
        )
        check_chain = LLMChain(
            llm=self.llm,
            prompt=PromptTemplate(input_variables=["statement"], template=check_template),
        )

        # Step 3: rewrite the original answer with a label per sentence. The answer
        # is embedded in the template text, so its braces must be escaped or they
        # would be parsed as template variables.
        summary = self._escape_braces(mixtral_answer)
        relabel_template = (
            "Below are some assertions that have been fact checked and are labeled as "
            "true of false. If the answer is false, a suggestion is given for a correction.\n"
            "Checked Assertions:\n"
            "{assertions}\n"
            f" Original Summary:{summary}"
            " Using these checked assertions to write the original summary with true or"
            " false in sentence wised. For each fact, determine whether it is true or false"
            " about the subject. If you are unable to determine whether the fact is true or"
            " false, output 'Undetermined'."
            "***format: sentence (True or False) in parentheses.***"
        )
        relabel_chain = LLMChain(
            llm=self.llm,
            prompt=PromptTemplate(input_variables=["assertions"], template=relabel_template),
        )

        return SimpleSequentialChain(
            chains=[extract_chain, check_chain, relabel_chain], verbose=True
        )

    def find_hallucinatted_sentence(self, question: str) -> Tuple[str, List[Tuple[str, str]], str]:
        """Answer *question* with Mixtral and label each sentence of the answer.

        Args:
            question (str): The input question.

        Returns:
            Tuple[str, List[Tuple[str, str]], str]: the Mixtral answer, the
            ``(sentence, "factual" | "hallucinated")`` pairs, and the raw
            fact-checking output. On any error the traceback is logged and
            ``("", [], "")`` is returned, so all three Gradio outputs still
            receive a value.
        """
        try:
            mixtral_answer = self.mixtral_response(question)
            fact_check_chain = self._build_fact_check_chain(mixtral_answer)
            answer = fact_check_chain.run(mixtral_answer)
            prediction_list = self.find_different_sentences(answer, mixtral_answer)
            return mixtral_answer, prediction_list, answer
        except Exception as e:
            # UI boundary: keep the app responsive, but record the failure.
            self._logger.exception("Error occurred in find_hallucinatted_sentence")
            print(f"Error occurred in find_hallucinatted_sentence: {e}")
            return "", [], ""

    def interface(self):
        """Build the Gradio Blocks UI for hallucination detection and launch it."""
        css = (
            ".gradio-container {background: rgb(157,228,255);\n"
            "background: radial-gradient(circle, rgba(157,228,255,1) 0%, "
            "rgba(18,115,106,1) 100%);}"
        )
        with gr.Blocks(css=css) as demo:
            gr.HTML('\n<center><h1 style="color:#fff">Detect Hallucination</h1></center>')
            with gr.Row():
                question = gr.Textbox(label="Question")
            with gr.Row():
                button = gr.Button(value="Submit")
            with gr.Row():
                llm_answer_box = gr.Textbox(label="llm answer")
            with gr.Row():
                fact_checking_result = gr.Textbox(label="hallucinated detection result")
            with gr.Row():
                highlighted_prediction = gr.HighlightedText(
                    label="Sentence Hallucination detection",
                    combine_adjacent=True,
                    color_map={"hallucinated": "red", "factual": "green"},
                    show_legend=True,
                )
            # Output order must match find_hallucinatted_sentence's return tuple:
            # (answer, tagged sentences, raw fact-check text).
            button.click(
                self.find_hallucinatted_sentence,
                question,
                [llm_answer_box, highlighted_prediction, fact_checking_result],
            )
        demo.launch()
if __name__ == "__main__":
    # Build the model clients and launch the Gradio UI only when run as a script,
    # so importing this module (e.g. from tests) does not start the app.
    hallucination_detection = FactChecking()
    hallucination_detection.interface()