# Hugging Face Space page residue, kept as a comment so this file parses:
# SJTUCL's picture | Update app.py | commit 7d8479e | raw | history blame | No virus | 2.82 kB
import nltk
nltk.download('punkt')
import pandas as pd
import gradio as gr
from nltk import sent_tokenize
from transformers import pipeline
from gradio.themes.utils import red, green
# Sentence-level text classifier, built once when the module is imported.
detector = pipeline(
    task='text-classification',
    model='SJTU-CL/RoBERTa-large-ArguGPT-sent',
)

# Highlight colours keyed by the decile buckets that predict_doc emits:
# green shades c400 -> c50 for '0%'..'40%', red shades c50 -> c500 for '50%'..'100%'.
_BUCKET_KEYS = (
    '0%', '10%', '20%', '30%', '40%',
    '50%', '60%', '70%', '80%', '90%', '100%',
)
_BUCKET_SHADES = (
    green.c400, green.c300, green.c200, green.c100, green.c50,
    red.c50, red.c100, red.c200, red.c300, red.c400, red.c500,
)
color_map = dict(zip(_BUCKET_KEYS, _BUCKET_SHADES))
# Lower bounds of the decile buckets '10%'..'100%' used for highlighting.
_BUCKET_BOUNDS = (0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9, 1)


def _probability_bucket(prob):
    """Map a probability to its decile key, '0%' through '100%'.

    A probability below 0.1 maps to '0%'. One that satisfies 0.1 <= prob < 0.2
    maps to '10%', and so on. A probability of 1 (or more) maps to '100%'.
    """
    # Bounds are sorted, so counting the bounds <= prob gives the bucket index.
    steps = sum(1 for bound in _BUCKET_BOUNDS if prob >= bound)
    return f'{10 * steps}%'


def predict_doc(doc):
    """Score every sentence of *doc* and summarise the whole essay.

    The text is split with NLTK's ``sent_tokenize``. Each sentence is scored
    by ``predict_one_sent``, which returns the probability that the sentence
    is machine generated.

    Args:
        doc: Essay text from the input textbox. ``None`` and whitespace-only
            strings are treated as an empty essay.

    Returns:
        A 4-tuple ``(summary, highlights, table, csv_path)``:

        - ``summary``: a one-line verdict. It is based on the mean of the
          rounded per-sentence scores, or is an explanatory message when no
          sentences were found.
        - ``highlights``: a list of ``(sentence, bucket)`` pairs. ``bucket``
          is a decile key such as ``'30%'``.
        - ``table``: a DataFrame with columns ``sentence``, ``label`` and
          ``score`` (scores rounded to 4 decimals).
        - ``csv_path``: ``'result.csv'``, which is rewritten with the table on
          every call.
    """
    text = doc or ''
    # Blank input has nothing to score, so skip the tokenizer entirely.
    sents = sent_tokenize(text) if text.strip() else []

    data = {'sentence': [], 'label': [], 'score': []}
    res = []
    for sent in sents:
        prob = predict_one_sent(sent)
        data['sentence'].append(sent)
        data['score'].append(round(prob, 4))
        data['label'].append('Human' if prob <= 0.5 else 'Machine')
        res.append((sent, _probability_bucket(prob)))

    df = pd.DataFrame(data)
    # Always rewrite the CSV, so a previous request's result is never offered
    # for download.
    df.to_csv('result.csv')

    if df.empty:
        # The mean of an empty column is NaN, and ``NaN <= 0.5`` is False.
        # Without this guard, blank input would be reported as "Machine"
        # with probability nan.
        sum_str = 'No sentences were found in the input. Please enter an essay.'
        return sum_str, res, df, 'result.csv'

    overall_score = df.score.mean()
    overall_label = 'Human' if overall_score <= 0.5 else 'Machine'
    sum_str = f'The essay is probably written by {overall_label}. The probability of being generated by AI is {overall_score}'
    return sum_str, res, df, 'result.csv'
def predict_one_sent(sent):
    """Return the probability that *sent* is machine generated.

    The pipeline call returns a list whose first entry holds the predicted
    label and that label's score. A ``LABEL_0`` prediction is flipped so the
    result is always the ``LABEL_1`` probability:

        LABEL_1, 0.66 -> 0.66
        LABEL_0, 0.66 -> 0.34

    Any label other than ``LABEL_0`` is returned unchanged. Text longer than
    the model accepts is truncated instead of raising.
    """
    # truncation=True is forwarded to the tokenizer. Without it, a very long
    # "sentence" (e.g. text with no end punctuation) can exceed the model's
    # maximum input length (presumably 512 tokens for RoBERTa) and make the
    # pipeline raise.
    res = detector(sent, truncation=True)[0]
    org_label, prob = res['label'], res['score']
    if org_label == 'LABEL_0':
        prob = 1 - prob
    return prob
# Gradio UI. The essay textbox and predict button sit beside the highlighted
# sentences. The summary and CSV download sit below them, followed by the
# per-sentence table.
# NOTE(review): nesting was reconstructed from a whitespace-stripped copy of
# this file; confirm the layout.
with gr.Blocks() as demo:
    with gr.Row():
        with gr.Column():
            text_in = gr.Textbox(
                lines=5,
                label='Essay input',
                info='Please enter the essay in the textbox'
            )
            btn = gr.Button('Predict who writes this essay!')
        # Colours come from color_map, keyed by the decile labels that
        # predict_doc returns.
        # NOTE(review): gr.Highlight and Component.style() look like Gradio
        # 3.x APIs (deprecated there, removed later); confirm the pinned
        # gradio version.
        sent_res = gr.Highlight(
            label='Labeled Result'
        ).style(color_map=color_map)
    with gr.Row():
        summary = gr.Text(
            label='Result summary'
        )
        csv_f = gr.File(
            label='CSV file storing data with all sentences.'
        )
    # NOTE(review): newer Gradio DataFrame versions may not accept max_rows;
    # confirm.
    tab = gr.DataFrame(
        label='Table with Probability Score',
        max_rows=100
    )
    # The outputs order mirrors predict_doc's return tuple:
    # (summary, highlights, table, csv path).
    btn.click(predict_doc, inputs=[text_in], outputs=[summary, sent_res, tab, csv_f], api_name='predict_doc')
demo.launch()