Petr Tsvetkov
Add labels highlighting the total # of samples and the commit message textbox
039521b
raw
history blame
No virus
7.44 kB
import html
import json
import os
import random
import uuid
from datetime import datetime
from difflib import ndiff

import gradio as gr

from data_loader import load_data
from hf_dataset_saver_builder import get_dataset_saver
# Hugging Face access token and target dataset for the saver, taken from the
# environment (None when the variable is unset).
HF_TOKEN = os.getenv('HF_REWRITING_TOKEN')
HF_DATASET = os.getenv('HF_REWRITING_DATASET')

# Samples are loaded once at import time; every handler below reads this
# module-level list.
data = load_data()
n_samples = len(data)

saver = get_dataset_saver(HF_TOKEN, HF_DATASET, private=True, separate_dirs=True)
def convert_diff_to_unified(diff_string):
    """Turn a JSON list of per-file changes into one unified-diff text.

    Args:
        diff_string: JSON-encoded list of dicts, each with the keys
            ``"old_path"``, ``"new_path"`` and ``"diff"`` (the hunk text).

    Returns:
        For every file, a ``--- old`` / ``+++ new`` header followed by its
        hunks; the per-file chunks are joined with a newline. An empty list
        yields ``""``.

    A null path (e.g. the missing side of an added or deleted file) is
    written as ``/dev/null``, the unified-diff convention that diff viewers
    recognise, rather than the literal string ``None``.
    """
    def _header_path(path):
        # json null -> None; map it to the conventional placeholder.
        return "/dev/null" if path is None else path

    modified_files = json.loads(diff_string)
    return "\n".join(
        f'--- {_header_path(modified_file["old_path"])}\n'
        f'+++ {_header_path(modified_file["new_path"])}\n'
        f'{modified_file["diff"]}'
        for modified_file in modified_files
    )
def get_diff2html_view(raw_diff):
    """Wrap a unified diff in the HTML scaffold used for diff rendering.

    The diff text is stored in a hidden ``#diff-raw`` element, next to an
    empty ``#diff-view`` container inside a ``d2h-view-wrapper``. The script
    that reads the one and fills the other is not part of this function
    (presumably injected through ``head.html`` — verify).

    The diff is HTML-escaped before embedding. Diffs routinely contain markup
    (changes to .html/.jsx/.xml files); inserted verbatim, a line such as
    ``</div>`` would close the hidden container early and spill diff content
    into the visible page, and tags like ``<img onerror=...>`` would be live.
    After escaping, the element's ``textContent``/``innerText`` equals the
    original diff text.
    NOTE(review): if the client script reads ``innerHTML`` instead, switch it
    to ``textContent``.

    Args:
        raw_diff: unified diff text.

    Returns:
        The HTML snippet as a string.
    """
    escaped_diff = html.escape(raw_diff)
    return f"""
<div style='width:100%; height:1400px; overflow:auto; position: relative'>
<div id='diff-raw' hidden>{escaped_diff}</div>
<div class="d2h-view-wrapper">
<div id='diff-view'></div>
</div>
</div>
"""
def get_github_link_md(repo, hash):
    """Return a Markdown link to commit ``hash`` of GitHub repository ``repo``.

    ``repo`` is expected in ``owner/name`` form, since it is placed directly
    after ``https://github.com/``.
    """
    commit_url = f'https://github.com/{repo}/commit/{hash}'
    return f'[See the commit on Github]({commit_url})'
def char_diff_obj(change_type, pos, character, timestamp):
    """Pack one character-level edit record into a dict.

    Keys are single letters: ``t`` change type, ``p`` position,
    ``c`` character, ``ts`` timestamp (in that insertion order).
    """
    return dict(t=change_type, p=pos, c=character, ts=timestamp)
def update_commit_view(sample_ind):
    """Compute the commit-view widget values for ``data[sample_ind]``.

    Returns ``None`` when ``sample_ind >= n_samples``. Otherwise returns a
    10-tuple, in order:
    GitHub link Markdown, diff HTML, repo, commit hash, load timestamp
    (``datetime.now().isoformat()``, local naive time), summary text, then the
    record's ``prediction`` three times (start value, editable value,
    previous-value snapshot) and an empty edit history list.
    """
    if sample_ind >= n_samples:
        return None

    record = data[sample_ind]
    diff_html = get_diff2html_view(convert_diff_to_unified(record['mods']))
    repo, commit_hash = record['repo'], record['hash']
    link_md = get_github_link_md(repo, commit_hash)
    loaded_at = datetime.now().isoformat()
    summary = f"{record['summary']}"
    message = record['prediction']

    # The same starting message seeds the "start", "current" and "previous"
    # slots; the edit history starts empty for every newly shown sample.
    return (link_md, diff_html, repo, commit_hash, loaded_at, summary,
            message, message, message, [])
def next_sample(current_sample_ind, shuffled_idx):
    """Advance the session to the next sample in its shuffled order.

    Args:
        current_sample_ind: Slider value: how many samples this session has
            labeled/skipped so far. It is also the position in
            ``shuffled_idx`` of the sample currently on screen.
        shuffled_idx: This session's permutation of ``range(n_samples)``.

    Returns:
        A tuple ``(new_slider_value, *commit_view_values)`` for the outputs
        ``[current_sample_sld] + commit_view``. Once the last sample has been
        handled, the slider is set to ``n_samples``. Each commit-view output
        then gets a no-op ``gr.update()``, so the page stays as it is.
    """
    # Number of values update_commit_view returns (== len(commit_view));
    # keep in sync if the view changes.
    n_view_outputs = 10

    next_ind = current_sample_ind + 1
    if next_ind >= n_samples:
        # No sample left. Indexing shuffled_idx[next_ind] here would raise
        # IndexError; the old code did exactly that right after the last
        # sample. A bare None is not a per-output tuple either.
        # Use a fresh update object per output rather than one shared dict.
        return (n_samples,) + tuple(gr.update() for _ in range(n_view_outputs))
    return (next_ind,) + update_commit_view(shuffled_idx[next_ind])
with open("head.html") as head_file:
head_html = head_file.read()
force_light_theme_js_func = """
function refresh() {
const url = new URL(window.location);
if (url.searchParams.get('__theme') !== 'light') {
url.searchParams.set('__theme', 'light');
window.location.href = url.href;
}
}
"""
# Build the survey UI, wire its events, and start the server.
with gr.Blocks(theme=gr.themes.Soft(), head=head_html, css="style_overrides.css",
               js=force_light_theme_js_func) as application:
    # Hidden state: identifiers of the commit on screen and this session's
    # shuffled sample order.
    repo_val = gr.Textbox(interactive=False, label='repo', visible=False)
    hash_val = gr.Textbox(interactive=False, label='hash', visible=False)
    shuffled_idx_val = gr.JSON(visible=False)

    with gr.Row():
        with gr.Accordion("Help"):
            # Explicit encoding: open()'s default is locale-dependent.
            with open("survey_guide.md", encoding="utf-8") as content_file:
                gr.Markdown(content_file.read())

    # Progress row: a read-only slider counting labeled/skipped samples
    # (advanced by the handlers below), the total, and the skip button.
    with gr.Row():
        current_sample_sld = gr.Slider(minimum=0, maximum=n_samples, step=1,
                                       value=0,
                                       interactive=False,
                                       label='sample_ind',
                                       info="Samples labeled/skipped",
                                       show_label=False,
                                       container=False,
                                       scale=5)
        with gr.Column(scale=1):
            gr.Markdown(value=f"#### Total number of samples: {n_samples}")
        with gr.Column(scale=1):
            skip_btn = gr.Button("Skip the current sample")

    with gr.Row():
        # Left: commit link and rendered diff.
        with gr.Column(scale=2):
            github_link = gr.Markdown()
            diff_view = gr.HTML()
        # Right: summary, the editable message, and submission controls.
        with gr.Column(scale=1):
            with gr.Accordion("Commit summary (AI generated)", open=False):
                commit_summary = gr.Markdown()
            commit_msg_start = gr.TextArea(label="commit_msg_start", visible=False)
            gr.Markdown(value="#### Please, edit the message in the text box below")
            commit_msg = gr.TextArea(label="commit_msg_end", show_label=False,
                                     info="Commit message (can be scrollable)")
            # Hidden: message value before the latest edit, and the
            # accumulated edit log built by on_commit_msg_changed.
            commit_msg_prev = gr.TextArea(visible=False)
            commit_msg_history = gr.JSON(label="commit_msg_history", visible=False)
            submit_btn = gr.Button("Submit")
            session_val = gr.Textbox(info='Session', interactive=False, container=True, show_label=False,
                                     label='session')

    with gr.Row(visible=False):
        sample_loaded_timestamp = gr.Textbox(info="Sample loaded", label='loaded_ts', interactive=False,
                                             container=True, show_label=False)
        # Refreshed every 0.1 s, so the value read when Submit is pressed
        # approximates the submission time (stored as 'submitted_ts').
        now_timestamp = gr.Textbox(info="Current time",
                                   interactive=False, container=True, show_label=False,
                                   value=lambda: datetime.now().isoformat(), every=0.1,
                                   label='submitted_ts')

    # Outputs filled by update_commit_view, in the order of its return tuple.
    commit_view = [
        github_link,
        diff_view,
        repo_val,
        hash_val,
        sample_loaded_timestamp,
        commit_summary,
        commit_msg_start,
        commit_msg,
        commit_msg_prev,
        commit_msg_history
    ]
    feedback_metadata = [
        session_val,
        repo_val,
        hash_val,
        sample_loaded_timestamp,
        now_timestamp
    ]
    feedback_form = [
        commit_msg_start,
        commit_msg,
        commit_msg_history
    ]

    # The saver is configured with the same component list (slider first)
    # whose values submit() passes to saver.flag().
    saver.setup([current_sample_sld] + feedback_metadata + feedback_form, "feedback")

    skip_btn.click(next_sample, inputs=[current_sample_sld, shuffled_idx_val],
                   outputs=[current_sample_sld] + commit_view)

    def submit(current_sample, shuffled_idx, *args):
        """Save the current answer with the saver, then move to the next sample.

        ``args`` are the values of ``feedback_metadata + feedback_form`` in
        order, so the flagged row ``(current_sample, *args)`` lines up with
        the component list given to ``saver.setup``.
        """
        saver.flag((current_sample,) + args)
        return next_sample(current_sample, shuffled_idx)

    submit_btn.click(
        submit,
        inputs=[current_sample_sld, shuffled_idx_val] + feedback_metadata + feedback_form,
        outputs=[current_sample_sld] + commit_view
    )

    def on_commit_msg_changed(message, prev_message, history):
        """Log the character-level edits between ``prev_message`` and ``message``.

        ``ndiff`` is given two strings, so it compares them character by
        character. Each added ('+') or removed ('-') character is appended as
        a ``char_diff_obj``. ``p`` is the index of that entry in the ndiff
        output stream, not an offset into either string. All entries from one
        call share one timestamp.

        Returns:
            ``(message, new_history)``. ``message`` becomes the next
            ``prev_message``.
        """
        timestamp = datetime.now().isoformat()
        # Build a new list (tolerating a not-yet-initialised None) instead of
        # mutating the incoming value in place.
        history = list(history) if history else []
        for i, s in enumerate(ndiff(prev_message, message)):
            # ndiff lines look like '+ x' / '- x' / '  x': tag first, char last.
            diff = char_diff_obj(s[0], i, s[-1], timestamp)
            if diff['t'] in ('+', '-'):
                history.append(diff)
        return message, history

    commit_msg.change(on_commit_msg_changed, inputs=[commit_msg, commit_msg_prev, commit_msg_history],
                      outputs=[commit_msg_prev, commit_msg_history])

    def init_session(current_sample):
        """Start a session on page load.

        Creates a random session id and a fresh shuffled order of all
        samples, then shows the sample at position ``current_sample`` (the
        slider value) of that order.
        """
        session = str(uuid.uuid4())
        shuffled_idx = list(range(n_samples))
        random.shuffle(shuffled_idx)
        return (session, shuffled_idx) + update_commit_view(shuffled_idx[current_sample])

    application.load(init_session,
                     inputs=[current_sample_sld],
                     outputs=[session_val, shuffled_idx_val] + commit_view)

application.launch()